Add comprehensive FP8 quantized model support for models like Qwen3-FP8. This enables loading and running FP8 models with per-block scale factors.

Changes:

bumblebee.ex:
- Add :preserve_source_types option to load_model/2 to keep FP8 types

pytorch_params.ex:
- Pass preserve_source_types through param loading pipeline
- Modify ensure_type/3 to preserve FP8 types when option is set

layers.ex:
- Add fp8_aware_dense/3 layer that handles FP8 quantized weights
- Implements block-wise dequantization using scale_inv parameter
- Automatically falls back to identity scaling for non-FP8 models

layers/transformer.ex:
- Add :attention_dense option to blocks/2, block/2, multi_head_attention/4
- Allows custom dense function for Q, K, V, and output projections

text/qwen3.ex:
- Update decoder to use fp8_aware_dense for attention via attention_dense
- Update gated_ffn to use fp8_aware_dense for FFN layers
- Add scale_inv to params_mapping for all attention and FFN layers

The implementation supports both:
- Pre-dequantization: convert FP8 -> F32 before loading
- Native FP8: load FP8 weights directly, apply scale_inv at runtime

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
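For context, the block-wise dequantization this commit describes amounts to multiplying each 128x128 block of the FP8 weight by its inverse scale. Below is a minimal Nx sketch of that step only, assuming both weight dimensions are exact multiples of the block size; the module and function names are illustrative, not the actual fp8_aware_dense/3 implementation from the PR.

# Illustrative sketch (not the Bumblebee implementation). Dequantizes an FP8
# weight of shape {out, in} using inverse scales of shape
# {div(out, 128), div(in, 128)}. Assumes exact multiples of the block size.
defmodule FP8Sketch do
  @block_size 128

  def dequantize_blockwise(weight, scale_inv) do
    {out, inn} = Nx.shape(weight)
    {out_blocks, in_blocks} = {div(out, @block_size), div(inn, @block_size)}

    weight
    # Cast the FP8 values up to f32 before scaling
    |> Nx.as_type(:f32)
    # Split each axis into {blocks, block_size} so the scales broadcast per block
    |> Nx.reshape({out_blocks, @block_size, in_blocks, @block_size})
    |> Nx.multiply(Nx.reshape(scale_inv, {out_blocks, 1, in_blocks, 1}))
    |> Nx.reshape({out, inn})
  end
end

A dense layer can then use the dequantized weight as usual; per the commit, non-FP8 models fall back to identity scaling.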
Update dependencies to use git versions of nx and safetensors which
have the new FP8 type representation: {:f8_e4m3fn, 8} instead of
{:f, 8, :e4m3fn}.
Changes:
- Update mix.exs to use git deps for nx, exla, torchx, and safetensors
- Update FP8 type detection pattern in pytorch_params.ex
- Add TODO comments noting deps should be switched back to hex when released
Tested with Qwen/Qwen3-4B-Instruct-2507-FP8 model - loads and generates
correctly with preserve_source_types: true.
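The mix.exs change described above would look roughly like the following; the sparse checkout paths and the exact options are assumptions based on the elixir-nx repository layout, not the actual PR contents.

# Rough shape of the git dependencies (assumption, not the real diff).
# TODO: switch back to hex releases once the new FP8 types ship.
defp deps do
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
    {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true},
    {:safetensors, github: "elixir-nx/safetensors", override: true}
    # ... remaining deps unchanged
  ]
end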
Add a new section demonstrating how to load and use FP8 quantized Qwen3 models with the preserve_source_types: true option. Update the introduction and summary to reflect the new capability.
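The snippet the new docs section revolves around is along these lines (the repo id is the model the PR was tested with):

# Keep FP8 weights in their source type instead of casting them on load.
{:ok, model_info} =
  Bumblebee.load_model({:hf, "Qwen/Qwen3-4B-Instruct-2507-FP8"},
    preserve_source_types: true
  )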
To generate the FP8 tiny model:

"""Generate a tiny FP8 Qwen3 model for testing Bumblebee's FP8 support.
This creates a minimal model with:
- FP8 E4M3FN weights for linear layers
- Corresponding weight_scale_inv tensors (128x128 block scaling)
- Saved in safetensors format
Usage:
python generate_fp8_qwen3.py
# Then upload to HuggingFace: huggingface-cli upload roulis/tiny-fp8-qwen3 ./tiny-fp8-qwen3
"""
import torch
import json
import os
from safetensors.torch import save_file
# Tiny model config matching existing tiny-random-Qwen3ForCausalLM
CONFIG = {
    "architectures": ["Qwen3ForCausalLM"],
    "hidden_size": 32,
    "intermediate_size": 64,
    "num_attention_heads": 4,
    "num_hidden_layers": 2,
    "num_key_value_heads": 2,
    "vocab_size": 1024,
    "head_dim": 8,  # hidden_size / num_attention_heads
    "rms_norm_eps": 1e-6,
    "rope_theta": 1000000.0,
    "max_position_embeddings": 512,
    "torch_dtype": "float8_e4m3fn",
    "model_type": "qwen3",
    "use_qk_norm": True,
    "tie_word_embeddings": True,
    "quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128]
    }
}
BLOCK_SIZE = 128
def create_fp8_weight(shape, seed=42):
    """Create a random FP8 E4M3FN weight tensor."""
    torch.manual_seed(seed)
    # Create random values in valid FP8 E4M3FN range (-448 to 448)
    weight_f32 = torch.randn(shape) * 0.1
    weight_fp8 = weight_f32.to(torch.float8_e4m3fn)
    return weight_fp8

def create_scale_inv(weight_shape):
    """Create scale_inv tensor for block-wise dequantization.

    Shape: [ceil(out_features/128), ceil(in_features/128)]
    For testing, use scale of 1.0 (identity) so dequantized = original.
    """
    out_features, in_features = weight_shape
    out_blocks = (out_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    in_blocks = (in_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    # Use 1.0 for identity scaling (easier to verify in tests)
    return torch.ones(out_blocks, in_blocks, dtype=torch.float32)
def generate_model():
    hidden_size = CONFIG["hidden_size"]
    intermediate_size = CONFIG["intermediate_size"]
    num_heads = CONFIG["num_attention_heads"]
    num_kv_heads = CONFIG["num_key_value_heads"]
    head_dim = CONFIG["head_dim"]
    vocab_size = CONFIG["vocab_size"]
    num_layers = CONFIG["num_hidden_layers"]
    tensors = {}
    seed = 0
    # Embedding (not quantized)
    tensors["model.embed_tokens.weight"] = torch.randn(vocab_size, hidden_size)
    for layer_idx in range(num_layers):
        prefix = f"model.layers.{layer_idx}"
        # Self-attention projections (FP8 quantized)
        q_size = num_heads * head_dim
        kv_size = num_kv_heads * head_dim
        # Q projection
        tensors[f"{prefix}.self_attn.q_proj.weight"] = create_fp8_weight((q_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.q_proj.weight_scale_inv"] = create_scale_inv((q_size, hidden_size))
        # K projection
        tensors[f"{prefix}.self_attn.k_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.k_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))
        # V projection
        tensors[f"{prefix}.self_attn.v_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.v_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))
        # O projection
        tensors[f"{prefix}.self_attn.o_proj.weight"] = create_fp8_weight((hidden_size, q_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.o_proj.weight_scale_inv"] = create_scale_inv((hidden_size, q_size))
        # QK norms (not quantized)
        tensors[f"{prefix}.self_attn.q_norm.weight"] = torch.ones(head_dim)
        tensors[f"{prefix}.self_attn.k_norm.weight"] = torch.ones(head_dim)
        # MLP (FP8 quantized)
        tensors[f"{prefix}.mlp.gate_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.gate_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))
        tensors[f"{prefix}.mlp.up_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.up_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))
        tensors[f"{prefix}.mlp.down_proj.weight"] = create_fp8_weight((hidden_size, intermediate_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.down_proj.weight_scale_inv"] = create_scale_inv((hidden_size, intermediate_size))
        # Layer norms (not quantized)
        tensors[f"{prefix}.input_layernorm.weight"] = torch.ones(hidden_size)
        tensors[f"{prefix}.post_attention_layernorm.weight"] = torch.ones(hidden_size)
    # Final norm (not quantized)
    tensors["model.norm.weight"] = torch.ones(hidden_size)
    # LM head is tied to the embeddings (tie_word_embeddings), so no separate
    # lm_head tensor is saved; it is not quantized since it shares the
    # embedding weights
    return tensors
def main():
    output_dir = "tiny-fp8-qwen3"
    os.makedirs(output_dir, exist_ok=True)
    # Generate model tensors
    tensors = generate_model()
    # Save as safetensors
    save_file(tensors, os.path.join(output_dir, "model.safetensors"))
    # Save config
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(CONFIG, f, indent=2)
    print(f"Model saved to {output_dir}/")
    print(f"Total tensors: {len(tensors)}")
    print("\nTo upload to HuggingFace:")
    print(f"  huggingface-cli upload roulis/tiny-fp8-qwen3 {output_dir}")

if __name__ == "__main__":
    main()
- Add fp8_aware_dense layer unit tests
- Add FP8 Qwen3 model loading test using roulis/tiny-fp8-qwen3
- Include Python script to generate tiny FP8 test models
Comment on lines +535 to +536:

    # Preserve FP8 E4M3FN types when preserve_source_types is enabled
    {_expected, {:f8_e4m3fn, 8}, true} -> tensor
Member
We likely don't want to do this here, because Axon may cast and it can lead to inconsistent behaviour (see #311). Ideally we want to apply an Axon.MixedPrecision policy, but we cannot determine it upfront. Also, Axon policies apply per layer, but in this case we may have a layer where each param has a different type. I need to think about the best way to address it and the loading API we should have.
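For reference, Axon mixed-precision policies are applied per layer, roughly as sketched below; this is a sketch based on my recollection of Axon's public API, not code from this PR. It illustrates why a per-layer policy cannot express an FP8 weight sitting next to an f32 scale_inv within the same layer.

# Sketch of a per-layer policy. All params of a matched layer get one type,
# which is exactly the limitation discussed in the comment above.
policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:f, 32},
    output: {:f, 32}
  )

model = Axon.MixedPrecision.apply_policy(model, policy)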
Summary
This PR adds support for loading and running FP8 (8-bit floating point) quantized models natively in Bumblebee. FP8 models use approximately half the memory
of BF16 models while maintaining good inference quality.
Changes
Core FP8 Support
- preserve_source_types option to Bumblebee.load_model/2 to keep FP8 weights in their native format
- dequantize_kernel/3 function in Bumblebee.Layers for runtime FP8 → F32 conversion using scale_inv tensors
- FP8 type detection for {:f8_e4m3fn, 8}
Qwen3 FP8 Integration
- params_mapping entries for FP8 weight scales (weight_scale_inv) in the Qwen3 architecture
Dependencies
- Git versions of nx, exla, torchx, and safetensors for FP8 type support
Documentation
- New section demonstrating how to load and use FP8 quantized Qwen3 models
Usage
Loading an FP8 Model
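A plausible end-to-end example for this section, assuming the standard Bumblebee serving API; only the preserve_source_types option is new in this PR, and the repo id is the model the PR was tested with.

repo = {:hf, "Qwen/Qwen3-4B-Instruct-2507-FP8"}

# preserve_source_types: true keeps the FP8 weights (and their scale_inv
# tensors) in their native type instead of casting them on load.
{:ok, model_info} = Bumblebee.load_model(repo, preserve_source_types: true)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Explain FP8 quantization in one sentence.")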
Supported FP8 Models
Notes