Llama-3.1-8B-FlashNorm
FlashNorm-prepared compatibility checkpoint of meta-llama/Llama-3.1-8B, derived from Meta's original weights obtained via the unsloth/Llama-3.1-8B ungated mirror (bit-identical to the upstream release).
The FlashNorm transformation is mathematically exact. This checkpoint loads in stock transformers and vLLM without any code changes.
What is FlashNorm?
An exact reformulation of RMSNorm → Linear that (i) folds the per-channel normalization weights into the following linear layer (W* = W · diag(g)) and (ii) defers the scalar 1/RMS(x) normalization to after the matmul. On hardware with distinct vector and matrix units, the matrix multiplication and the RMS reduction can execute in parallel.
See the paper and the transformer-tricks repo for details.
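The equivalence described above can be checked numerically. What follows is a minimal, dependency-free sketch, not the reference implementation: the function names, the toy sizes, and the `eps` value are assumptions made for illustration, and plain Python lists stand in for tensors.

```python
import math
import random


def rms(x, eps):
    """Root mean square of x, with eps added inside the square root."""
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)


def rmsnorm_then_linear(x, g, W, eps=1e-5):
    """Baseline: normalize x, scale by g, then apply the linear layer W."""
    r = rms(x, eps)
    normed = [v / r * gi for v, gi in zip(x, g)]
    return [sum(w * n for w, n in zip(row, normed)) for row in W]


def fold_norm_weights(W, g):
    """Offline step: W* = W · diag(g), i.e. scale column j of W by g[j]."""
    return [[w * gj for w, gj in zip(row, g)] for row in W]


def flashnorm_linear(x, W_star, eps=1e-5):
    """Matmul on the raw input first, then divide the outputs by RMS(x)."""
    y = [sum(w * v for w, v in zip(row, x)) for row in W_star]
    r = rms(x, eps)  # independent of y, so it can be computed in parallel
    return [v / r for v in y]


random.seed(0)
n_in, n_out = 8, 4  # toy sizes, assumed for illustration
x = [random.gauss(0, 1) for _ in range(n_in)]
g = [random.uniform(0.5, 1.5) for _ in range(n_in)]
W = [[random.gauss(0, 1) for _ in range(n_in)] for _ in range(n_out)]

reference = rmsnorm_then_linear(x, g, W)
folded = flashnorm_linear(x, fold_norm_weights(W, g))
max_diff = max(abs(a - b) for a, b in zip(reference, folded))
print(max_diff)  # double-precision rounding error only
```

Both paths use the same `eps`, so the identity still holds when `eps` is non-zero. The deferred division is applied to the `n_out` outputs and does not depend on the matmul result, which is what allows the two computations to overlap.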
What’s different from the source checkpoint?
| Tensor | Source | This checkpoint |
|---|---|---|
| `model.layers.*.input_layernorm.weight` | learned per-channel g | all ones |
| `model.layers.*.self_attn.{q,k,v}_proj.weight` | W | W · diag(g_input_layernorm) |
| `model.layers.*.post_attention_layernorm.weight` | learned per-channel g | all ones |
| `model.layers.*.mlp.{gate,up}_proj.weight` | W | W · diag(g_post_attention_layernorm) |
| `model.norm.weight` | learned per-channel g | all ones |
| `lm_head.weight` | W | W · diag(g_model_norm) |
Llama-3.1-8B has untied embeddings, so model.norm is also folded into lm_head. All tensors are stored in the source dtype (bfloat16); merged products are computed in float32 internally before casting back.
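The mapping in the table can be expressed as a small folding routine. This is an illustrative sketch, not the script used to produce this checkpoint: `fold`, `fold_rules`, and `fold_state_dict` are hypothetical names, nested lists stand in for tensors, and the float32 accumulation and bfloat16 cast described above are not modeled.

```python
def fold(W, g):
    """Return W · diag(g): scale column j of W by g[j]."""
    return [[w * gj for w, gj in zip(row, g)] for row in W]


def fold_rules(num_layers):
    """Yield (norm key, keys of the linear layers that consume that norm)."""
    for i in range(num_layers):
        p = f"model.layers.{i}."
        yield p + "input_layernorm.weight", [
            p + f"self_attn.{n}_proj.weight" for n in ("q", "k", "v")
        ]
        yield p + "post_attention_layernorm.weight", [
            p + f"mlp.{n}_proj.weight" for n in ("gate", "up")
        ]
    # Untied embeddings: the final norm folds into lm_head.
    yield "model.norm.weight", ["lm_head.weight"]


def fold_state_dict(sd, num_layers):
    """Return a copy of sd with norm weights folded into the next linears."""
    out = dict(sd)
    for norm_key, linear_keys in fold_rules(num_layers):
        g = sd[norm_key]
        for key in linear_keys:
            out[key] = fold(sd[key], g)
        out[norm_key] = [1.0] * len(g)  # the norm weight becomes all ones
    return out


# Toy checkpoint (assumed shapes): one layer, hidden size 2.
W = [[1.0, 2.0], [3.0, 4.0]]
toy = {
    "model.layers.0.input_layernorm.weight": [2.0, 0.5],
    "model.layers.0.post_attention_layernorm.weight": [1.0, 3.0],
    "model.norm.weight": [0.5, 2.0],
    "lm_head.weight": W,
}
for name in ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
             "self_attn.o_proj", "mlp.gate_proj", "mlp.up_proj",
             "mlp.down_proj"):
    toy[f"model.layers.0.{name}.weight"] = W

folded_sd = fold_state_dict(toy, num_layers=1)
print(folded_sd["model.layers.0.self_attn.q_proj.weight"])
```

Tensors that do not directly follow a norm (`o_proj`, `down_proj`, the embeddings) are passed through unchanged, which matches the table.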
Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the tokenizer and the FlashNorm-prepared weights; no custom code is needed.
tok = AutoTokenizer.from_pretrained('open-machine/Llama-3.1-8B-FlashNorm')
model = AutoModelForCausalLM.from_pretrained(
    'open-machine/Llama-3.1-8B-FlashNorm',
    dtype=torch.float16,
).cuda().eval()

# Greedy decoding (do_sample=False) keeps the output deterministic.
ids = tok('Once upon a time there was', return_tensors='pt').input_ids.cuda()
out = model.generate(ids, max_new_tokens=50, do_sample=False)
print(tok.decode(out[0], skip_special_tokens=True))
```
With vLLM:
```shell
vllm serve open-machine/Llama-3.1-8B-FlashNorm
```
Framework behavior
The transformation is exact in real arithmetic, so any difference from the source model's output comes from finite-precision inference kernels:
- Hugging Face Transformers at fp32: greedy generation is bit-identical to the source.
- Hugging Face Transformers at fp16 and vLLM (any precision): a one-token argmax flip is possible where the top two logits are nearly tied, and greedy decoding then amplifies it because every later token is conditioned on the flipped one. Reason: precomputed merged weights round differently in lossy inference kernels than the runtime product x·g·W would.
This is a general property of precomputing weight-folded tensors for lossy-inference kernels, not specific to FlashNorm. A native fused RMSNorm + QKV kernel (deferring g to runtime) eliminates the framework dependency and is in progress for vLLM / FlashInfer.
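The rounding effect can be reproduced with three scalars. This is a hedged sketch: `to_bf16` is a hypothetical helper that emulates bfloat16 rounding (round to nearest, ties to even) for finite values, every intermediate is rounded to bfloat16, and the inputs are hand-picked to expose the effect. Real kernels accumulate in higher precision and differ in detail.

```python
import struct


def to_bf16(v):
    """Round a finite Python float to the nearest bfloat16 value."""
    bits = struct.unpack("<I", struct.pack("<f", v))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Hand-picked inputs, each exactly representable in bfloat16.
x, g, w = 179 / 128, 129 / 128, 129 / 128

# In double precision both orderings give the same value.
exact_runtime = (x * g) * w
exact_folded = x * (w * g)

# Runtime ordering: apply g to the activation, round, then multiply by w.
bf16_runtime = to_bf16(to_bf16(x * g) * w)
# Precomputed ordering: fold g into w offline, round, then multiply by x.
bf16_folded = to_bf16(x * to_bf16(w * g))

print(exact_runtime == exact_folded)  # True
print(bf16_runtime, bf16_folded)      # 1.4140625 1.421875
```

The two low-precision results differ by one bfloat16 step although the underlying math is identical, and neither ordering is inherently more accurate. A difference of this size only changes the generated text when it moves a nearly tied logit pair across the argmax boundary.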
License
Llama 3.1 Community License, inherited from the source model.