Add native FP8 model support. #448

Open
nyo16 wants to merge 4 commits into elixir-nx:main from nyo16:feat/fp8-native-support

Conversation

@nyo16
Contributor

nyo16 commented Jan 28, 2026

Summary

This PR adds support for loading and running FP8 (8-bit floating point) quantized models natively in Bumblebee. FP8 models use approximately half the memory
of BF16 models while maintaining good inference quality.
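
As a rough back-of-the-envelope check: a 4B-parameter model stored at one byte per weight needs on the order of 4 GB for the weights alone, versus roughly 8 GB at two bytes per weight in BF16 (activations and the KV cache come on top of that).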

Changes

Core FP8 Support

  • Add preserve_source_types option to Bumblebee.load_model/2 to keep FP8 weights in their native format
  • Add dequantize_kernel/3 function in Bumblebee.Layers for runtime FP8 → F32 conversion using scale_inv tensors (see the sketch after this list)
  • Update FP8 type detection to use the new nx/safetensors format {:f8_e4m3fn, 8}
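
For illustration, here is a minimal sketch of what block-wise dequantization can look like in plain Nx. It is illustrative rather than the exact dequantize_kernel/3 added in this PR; it assumes 128x128 weight blocks and a scale_inv tensor shaped {ceil(out / 128), ceil(in / 128)}:

defmodule FP8DequantSketch do
  @block 128

  # weight: {out, in} FP8 tensor (e.g. {:f8_e4m3fn, 8} with the git nx deps),
  # scale_inv: per-block f32 scales shaped {ceil(out / 128), ceil(in / 128)}.
  def dequantize(weight, scale_inv) do
    {out, inp} = Nx.shape(weight)
    {out_blocks, in_blocks} = Nx.shape(scale_inv)

    scales =
      scale_inv
      # repeat each block scale across its 128x128 block
      |> Nx.reshape({out_blocks, 1, in_blocks, 1})
      |> Nx.broadcast({out_blocks, @block, in_blocks, @block})
      |> Nx.reshape({out_blocks * @block, in_blocks * @block})
      # trim padding when out/in are not exact multiples of 128
      |> Nx.slice([0, 0], [out, inp])

    weight
    |> Nx.as_type(:f32)
    |> Nx.multiply(scales)
  end
end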

Qwen3 FP8 Integration

  • Add params_mapping for FP8 weight scales (weight_scale_inv) in Qwen3 architecture
  • Update transformer layers to apply dequantization when FP8 weights are detected

Dependencies

  • Update to git versions of nx, exla, torchx, and safetensors for FP8 type support
  • Note: These can be switched back to hex versions once released

Documentation

  • Add FP8 quantization section to the Qwen3 notebook with usage examples

Usage

Loading an FP8 Model

repo = {:hf, "Qwen/Qwen3-4B-Instruct-2507-FP8"}

{:ok, model_info} = Bumblebee.load_model(repo,
  backend: EXLA.Backend,
  preserve_source_types: true
)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

Running Inference

generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 256,
    temperature: 0.7,
    strategy: %{type: :multinomial_sampling, top_p: 0.8, top_k: 20}
  )

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 1024],
    defn_options: [compiler: EXLA]
  )

# Using chat template format
prompt = """
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What are the benefits of quantized models?<|im_end|>
<|im_start|>assistant
"""

Nx.Serving.run(serving, prompt)

Supported FP8 Models

  • Qwen/Qwen3-4B-Instruct-2507-FP8
  • Other FP8 variants following the same weight format

Notes

  • FP8 support requires a GPU backend (EXLA with CUDA) for optimal performance
  • The preserve_source_types: true option is required to keep weights in FP8 format
  • Dequantization happens automatically during inference using the weight_scale_inv tensors

nyo16 and others added 3 commits January 8, 2026 11:30
Add comprehensive FP8 quantized model support for models like Qwen3-FP8.
This enables loading and running FP8 models with per-block scale factors.

Changes:

bumblebee.ex:
- Add :preserve_source_types option to load_model/2 to keep FP8 types

pytorch_params.ex:
- Pass preserve_source_types through param loading pipeline
- Modify ensure_type/3 to preserve FP8 types when option is set

layers.ex:
- Add fp8_aware_dense/3 layer that handles FP8 quantized weights
- Implements block-wise dequantization using scale_inv parameter
- Automatically falls back to identity scaling for non-FP8 models

layers/transformer.ex:
- Add :attention_dense option to blocks/2, block/2, multi_head_attention/4
- Allows custom dense function for Q, K, V, and output projections

text/qwen3.ex:
- Update decoder to use fp8_aware_dense for attention via attention_dense
- Update gated_ffn to use fp8_aware_dense for FFN layers
- Add scale_inv to params_mapping for all attention and FFN layers

The implementation supports both:
- Pre-dequantization: Convert FP8->F32 before loading
- Native FP8: Load FP8 weights directly, apply scale_inv at runtime

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Update dependencies to use git versions of nx and safetensors which
have the new FP8 type representation: {:f8_e4m3fn, 8} instead of
{:f, 8, :e4m3fn}.

Changes:
- Update mix.exs to use git deps for nx, exla, torchx, and safetensors
- Update FP8 type detection pattern in pytorch_params.ex
- Add TODO comments noting deps should be switched back to hex when released

Tested with Qwen/Qwen3-4B-Instruct-2507-FP8 model - loads and generates
correctly with preserve_source_types: true.
Add a new section demonstrating how to load and use FP8 quantized
Qwen3 models with preserve_source_types: true option. Updated
introduction and summary to reflect the new capability.
nyo16 force-pushed the feat/fp8-native-support branch from e8e7f67 to d6d5f62 on January 28, 2026 19:03
@nyo16
Contributor Author

nyo16 commented Jan 28, 2026

To generate the FP8 tiny model:

  """Generate a tiny FP8 Qwen3 model for testing Bumblebee's FP8 support.

  This creates a minimal model with:
  - FP8 E4M3FN weights for linear layers
  - Corresponding weight_scale_inv tensors (128x128 block scaling)
  - Saved in safetensors format

  Usage:
    python generate_fp8_qwen3.py
    # Then upload to HuggingFace: huggingface-cli upload roulis/tiny-fp8-qwen3 ./tiny-fp8-qwen3
  """

  import torch
  import json
  import os
  from safetensors.torch import save_file

  # Tiny model config matching existing tiny-random-Qwen3ForCausalLM
  CONFIG = {
      "architectures": ["Qwen3ForCausalLM"],
      "hidden_size": 32,
      "intermediate_size": 64,
      "num_attention_heads": 4,
      "num_hidden_layers": 2,
      "num_key_value_heads": 2,
      "vocab_size": 1024,
      "head_dim": 8,  # hidden_size / num_attention_heads
      "rms_norm_eps": 1e-6,
      "rope_theta": 1000000.0,
      "max_position_embeddings": 512,
      "torch_dtype": "float8_e4m3fn",
      "model_type": "qwen3",
      "use_qk_norm": True,
      "tie_word_embeddings": True,
      "quantization_config": {
          "quant_method": "fp8",
          "weight_block_size": [128, 128]
      }
  }

  BLOCK_SIZE = 128


  def create_fp8_weight(shape, seed=42):
      """Create a random FP8 E4M3FN weight tensor."""
      torch.manual_seed(seed)
      # Create random values in valid FP8 E4M3FN range (-448 to 448)
      weight_f32 = torch.randn(shape) * 0.1
      weight_fp8 = weight_f32.to(torch.float8_e4m3fn)
      return weight_fp8


  def create_scale_inv(weight_shape):
      """Create scale_inv tensor for block-wise dequantization.

      Shape: [ceil(out_features/128), ceil(in_features/128)]
      For testing, use scale of 1.0 (identity) so dequantized = original.
      """
      out_features, in_features = weight_shape
      out_blocks = (out_features + BLOCK_SIZE - 1) // BLOCK_SIZE
      in_blocks = (in_features + BLOCK_SIZE - 1) // BLOCK_SIZE
      # Use 1.0 for identity scaling (easier to verify in tests)
      return torch.ones(out_blocks, in_blocks, dtype=torch.float32)


  def generate_model():
      hidden_size = CONFIG["hidden_size"]
      intermediate_size = CONFIG["intermediate_size"]
      num_heads = CONFIG["num_attention_heads"]
      num_kv_heads = CONFIG["num_key_value_heads"]
      head_dim = CONFIG["head_dim"]
      vocab_size = CONFIG["vocab_size"]
      num_layers = CONFIG["num_hidden_layers"]

      tensors = {}
      seed = 0

      # Embedding (not quantized)
      tensors["model.embed_tokens.weight"] = torch.randn(vocab_size, hidden_size)

      for layer_idx in range(num_layers):
          prefix = f"model.layers.{layer_idx}"

          # Self-attention projections (FP8 quantized)
          q_size = num_heads * head_dim
          kv_size = num_kv_heads * head_dim

          # Q projection
          tensors[f"{prefix}.self_attn.q_proj.weight"] = create_fp8_weight((q_size, hidden_size), seed)
          seed += 1
          tensors[f"{prefix}.self_attn.q_proj.weight_scale_inv"] = create_scale_inv((q_size, hidden_size))

          # K projection
          tensors[f"{prefix}.self_attn.k_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
          seed += 1
          tensors[f"{prefix}.self_attn.k_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

          # V projection
          tensors[f"{prefix}.self_attn.v_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
          seed += 1
          tensors[f"{prefix}.self_attn.v_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

          # O projection
          tensors[f"{prefix}.self_attn.o_proj.weight"] = create_fp8_weight((hidden_size, q_size), seed)
          seed += 1
          tensors[f"{prefix}.self_attn.o_proj.weight_scale_inv"] = create_scale_inv((hidden_size, q_size))

          # QK norms (not quantized)
          tensors[f"{prefix}.self_attn.q_norm.weight"] = torch.ones(head_dim)
          tensors[f"{prefix}.self_attn.k_norm.weight"] = torch.ones(head_dim)

          # MLP (FP8 quantized)
          tensors[f"{prefix}.mlp.gate_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
          seed += 1
          tensors[f"{prefix}.mlp.gate_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

          tensors[f"{prefix}.mlp.up_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
          seed += 1
          tensors[f"{prefix}.mlp.up_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

          tensors[f"{prefix}.mlp.down_proj.weight"] = create_fp8_weight((hidden_size, intermediate_size), seed)
          seed += 1
          tensors[f"{prefix}.mlp.down_proj.weight_scale_inv"] = create_scale_inv((hidden_size, intermediate_size))

          # Layer norms (not quantized)
          tensors[f"{prefix}.input_layernorm.weight"] = torch.ones(hidden_size)
          tensors[f"{prefix}.post_attention_layernorm.weight"] = torch.ones(hidden_size)

      # Final norm (not quantized)
      tensors["model.norm.weight"] = torch.ones(hidden_size)

      # LM head: tied to the embeddings (tie_word_embeddings=True), so no
      # separate lm_head.weight tensor is saved and it is not quantized

      return tensors


  def main():
      output_dir = "tiny-fp8-qwen3"
      os.makedirs(output_dir, exist_ok=True)

      # Generate model tensors
      tensors = generate_model()

      # Save as safetensors
      save_file(tensors, os.path.join(output_dir, "model.safetensors"))

      # Save config
      with open(os.path.join(output_dir, "config.json"), "w") as f:
          json.dump(CONFIG, f, indent=2)

      print(f"Model saved to {output_dir}/")
      print(f"Total tensors: {len(tensors)}")
      print("\nTo upload to HuggingFace:")
      print(f"  huggingface-cli upload roulis/tiny-fp8-qwen3 {output_dir}")


  if __name__ == "__main__":
      main()
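
Once uploaded, a quick smoke test from Elixir could look like this (a sketch mirroring the usage in the PR description, assuming the roulis/tiny-fp8-qwen3 repo produced by the script above):

repo = {:hf, "roulis/tiny-fp8-qwen3"}

{:ok, model_info} =
  Bumblebee.load_model(repo,
    backend: EXLA.Backend,
    preserve_source_types: true
  )

# The FP8 linear weights load alongside their *_scale_inv tensors; since the
# generated block scales are all 1.0, the dequantized values should equal the
# stored FP8 values, which keeps the test outputs easy to verify.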

@nyo16 nyo16 marked this pull request as ready for review January 28, 2026 19:05
- Add fp8_aware_dense layer unit tests
- Add FP8 Qwen3 model loading test using roulis/tiny-fp8-qwen3
- Include Python script to generate tiny FP8 test models
nyo16 force-pushed the feat/fp8-native-support branch from d6d5f62 to 6893058 on January 28, 2026 19:08
Comment on lines +535 to +536
# Preserve FP8 E4M3FN types when preserve_source_types is enabled
{_expected, {:f8_e4m3fn, 8}, true} -> tensor
Member

We likely don't want to do this here, because Axon may cast and it can lead to inconsistent behaviour (see #311). Ideally we want to apply an Axon.MixedPrecision policy, but we cannot determine it upfront. Also Axon policies apply per layer, but in this case we may have a layer where each param has different type. I need to think about the best way to address it and the loading API we should have.
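
For context, Axon's mixed-precision policies are created and applied per layer, roughly as in the sketch below (assuming Axon.MixedPrecision.create_policy/1 and apply_policy/2; as noted above, this does not cover a layer whose parameters have different types):

policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:bf, 16},
    output: {:f, 32}
  )

# applies the same policy to every layer in the model graph
model = Axon.MixedPrecision.apply_policy(model, policy)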
