Llama-3.1-8B-FlashNorm

A FlashNorm-prepared compatibility checkpoint of meta-llama/Llama-3.1-8B, derived from Meta's original weights obtained via the ungated unsloth/Llama-3.1-8B mirror (bit-identical to the upstream release).

The FlashNorm transformation is mathematically exact. This checkpoint loads in stock transformers and vLLM without any code changes.

What is FlashNorm?

An exact reformulation of RMSNorm → Linear that (i) folds the per-channel normalization weights into the following linear layer (W* = W · diag(g)) and (ii) defers the scalar 1/RMS(x) normalization to after the matmul. On hardware with distinct vector and matrix units, the matrix multiplication and the RMS reduction can execute in parallel.
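The algebra behind the reformulation can be checked numerically. Below is a minimal NumPy sketch with toy dimensions (not the model's real shapes): the standard RMSNorm-then-Linear path is compared against the FlashNorm path, where the per-channel weights g are folded into the linear layer and the scalar 1/RMS(x) is applied after the matmul.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 64, 16
x = rng.standard_normal((n, d))      # activations (toy batch)
g = rng.standard_normal(d)           # RMSNorm per-channel weights
W = rng.standard_normal((4 * d, d))  # the following linear layer

def rms(x, eps=1e-6):
    # Root-mean-square over the feature dimension, one scalar per row
    return np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)

# Standard path: RMSNorm, then Linear
y_ref = ((x / rms(x)) * g) @ W.T

# FlashNorm path: fold g into the weights (W* = W . diag(g)),
# defer the scalar 1/RMS(x) to after the matmul
W_star = W * g                       # broadcasting scales column j by g[j]
y_flash = (x @ W_star.T) / rms(x)

assert np.allclose(y_ref, y_flash)
```

Because 1/RMS(x) is a per-row scalar, it commutes with the matmul, and scaling column j of W by g[j] is identical to scaling feature j of the input; in exact arithmetic the two paths coincide.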

See the paper and the transformer-tricks repo for details.

What’s different from the source checkpoint?

| Tensor | Source | This checkpoint |
|--------|--------|-----------------|
| model.layers.*.input_layernorm.weight | learned per-channel g | all ones |
| model.layers.*.self_attn.{q,k,v}_proj.weight | W | W · diag(g_input_layernorm) |
| model.layers.*.post_attention_layernorm.weight | learned per-channel g | all ones |
| model.layers.*.mlp.{gate,up}_proj.weight | W | W · diag(g_post_attention_layernorm) |
| model.norm.weight | learned per-channel g | all ones |
| lm_head.weight | W | W · diag(g_model_norm) |

Llama-3.1-8B has untied embeddings, so model.norm is also folded into lm_head. All tensors are stored in the source dtype (bfloat16); merged products are computed in float32 internally before casting back.
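A hypothetical sketch of the fold for a single tensor pair is shown below. NumPy has no native bfloat16, so float16 stands in for the storage dtype here; the variable names are illustrative, not the actual conversion script. The point is the dtype discipline: the merged product is formed in float32 before casting back.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 32
g = rng.standard_normal(d).astype(np.float16)         # norm weight (toy)
W_q = rng.standard_normal((d, d)).astype(np.float16)  # q_proj weight (toy)

# Merge in float32 to avoid compounding half-precision rounding,
# then cast back to the storage dtype
W_q_folded = (W_q.astype(np.float32) * g.astype(np.float32)).astype(np.float16)

# The checkpoint stores all-ones in place of the folded norm weights,
# so a stock RMSNorm layer becomes a pure 1/RMS(x) scaling
g_stored = np.ones_like(g)
```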

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained('open-machine/Llama-3.1-8B-FlashNorm')
model = AutoModelForCausalLM.from_pretrained(
    'open-machine/Llama-3.1-8B-FlashNorm',
    torch_dtype=torch.float16,
).cuda().eval()

ids = tok('Once upon a time there was', return_tensors='pt').input_ids.cuda()
out = model.generate(ids, max_new_tokens=50, do_sample=False)
print(tok.decode(out[0], skip_special_tokens=True))

With vLLM:

vllm serve open-machine/Llama-3.1-8B-FlashNorm

Framework behavior

The FlashNorm transformation is mathematically exact.

  • Hugging Face Transformers at fp32: greedy generation is bit-identical to the source model.
  • Hugging Face Transformers at fp16, and vLLM at any precision: a one-token argmax flip is possible at tight decision points, and downstream greedy decoding then amplifies the divergence. The reason is that precomputed merged weights round differently inside lossy inference kernels than the runtime product x·g·W would.

This is a general property of precomputing weight-folded tensors for lossy-inference kernels, not specific to FlashNorm. A native fused RMSNorm + QKV kernel (deferring g to runtime) eliminates the framework dependency and is in progress for vLLM / FlashInfer.
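The rounding effect described above can be illustrated in NumPy with toy dimensions (half-precision here is only a stand-in for GPU inference kernels): applying g to the activations at runtime and pre-merging g into the weights are mathematically identical, but the float16 rounding happens at different points, so the resulting logits differ slightly. Near a tight argmax margin, such a difference is enough to flip one token.

```python
import numpy as np

rng = np.random.default_rng(2)
d, v = 256, 512
x = rng.standard_normal(d).astype(np.float16)       # final hidden state (toy)
g = rng.standard_normal(d).astype(np.float16)       # norm weights (toy)
W = rng.standard_normal((v, d)).astype(np.float16)  # lm_head weights (toy)

# Runtime path: round x*g to fp16, then matmul
logits_runtime = (x * g) @ W.T
# Pre-merged path: round W*g to fp16, then matmul
logits_merged = x @ (W * g).T

# Same math, different rounding points -> small numeric disagreement
diff = np.abs(logits_runtime.astype(np.float32)
              - logits_merged.astype(np.float32)).max()
```

In float32 or float64 the two paths agree to machine precision, which matches the fp32 bit-identical behavior noted above.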

License

Llama 3.1 Community License, inherited from the source model.
