xiaoxi-wangfj (Collaborator) commented Dec 26, 2025

This PR introduces blockwise, scaling-aware FP8 transpose optimizations for FP8 MoE that enable a casting-free, FP8-centric MoE dataflow in TransformerEngine. By eliminating unnecessary cast and re-quantization steps, it keeps the expert path in FP8 while maintaining numerical stability in existing FP8 training workflows.

This PR is designed to be used in conjunction with PR 021ai/Megatron-LM#1

Description

The design and theoretical background of this PR are described in the paper:
FP8-Flow-MoE: A Casting-Free FP8 Recipe without Double Quantization Error

The following figure illustrates the optimized MoE dataflow and highlights the key optimization points (marked ①–⑤).

[Figure: FP8-Flow-MoE optimized MoE dataflow, with optimization points ①–⑤ marked]

FP8 Quantization Before Dispatch (DeepEP → GroupedLinear)

Quantization is performed before DeepEP dispatch, and row-wise FP8 tensors are directly fed into GroupedLinear.

  • Keeps dispatch → permute → expert computation entirely in FP8
  • Float8BlockwiseQTensor is propagated with a COMPACT layout (for _rowwise_scale_inv) along the dispatch → permute → GroupedLinear path, avoiding layout-induced .T.contiguous() calls and reducing unnecessary memory copies.

(Shown as marker ① in the figure)
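
A minimal sketch (plain PyTorch, not the fused TE/Triton path) of the intent, assuming the common 1×128 row-wise block size; the helper names and the call sequence at the bottom are illustrative, not the actual APIs:

```python
# Row-wise blockwise FP8 quantization applied *before* DeepEP dispatch, so the
# dispatch -> permute -> GroupedLinear path can stay in FP8.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3

def quantize_rowwise_blockwise(x: torch.Tensor, block: int = 128):
    """Quantize [tokens, hidden] to FP8 with one scale per 1 x `block` tile."""
    t, h = x.shape  # assumes hidden is a multiple of `block`
    xb = x.view(t, h // block, block).float()
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = FP8_MAX / amax                                 # per-block scale
    q = (xb * scale).to(torch.float8_e4m3fn).view(t, h)    # row-wise FP8 data
    scale_inv = (1.0 / scale).squeeze(-1)                  # [tokens, hidden // block], compact layout
    return q, scale_inv

# Intended order of operations on the optimized path (names are illustrative):
#   q, s_inv = quantize_rowwise_blockwise(hidden_states)
#   q, s_inv = deep_ep_dispatch(q, s_inv)     # FP8 bytes + compact scales cross the A2A
#   out      = grouped_linear(q, s_inv, ...)  # consumes row-wise FP8 directly
```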

Scaling-Aware FP8 Transpose for Wgrad

GroupedLinear requires:

  • row-wise FP8 for Fprop/Dgrad
  • column-wise FP8 for Wgrad

To avoid dequantize → transpose → requantize, this PR introduces scaling_aware_fp8_transpose, which:

  • Converts row-wise FP8 to column-wise FP8 via exponent manipulation only
  • Preserves scale consistency across layouts

(Shown as marker ④ in the figure)
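
A conceptual sketch of the technique, assuming power-of-two block scales (as in the blockwise recipe): moving a value from its row-wise block scale to its column-wise block scale multiplies it by a power of two, so only the FP8 exponent changes and the mantissa, and therefore the original quantization error, is untouched. The Triton kernel applies that shift directly to the exponent bits; the sketch below takes a float round trip purely for brevity, and the names and block size are illustrative:

```python
import torch

def po2_scale_inv(amax: torch.Tensor, fp8_max: float) -> torch.Tensor:
    """Smallest power-of-two scale_inv with amax / scale_inv <= fp8_max."""
    return torch.exp2(torch.ceil(torch.log2(amax.clamp(min=1e-12) / fp8_max)))

def scaling_aware_transpose(q: torch.Tensor, row_sinv: torch.Tensor, B: int = 128):
    """q: [T, H] FP8 with 1 x B row-wise blocks; row_sinv: [T, H // B], power of two."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    T, H = q.shape  # assumes T and H are multiples of B
    # Value each FP8 code represents under its row-wise block scale.
    vals = (q.float().view(T, H // B, B) * row_sinv.unsqueeze(-1)).view(T, H)
    # Column-wise (B x 1) power-of-two scales for the transposed layout.
    col_blocks = vals.t().reshape(H, T // B, B)
    col_sinv = po2_scale_inv(col_blocks.abs().amax(dim=-1), fp8_max)   # [H, T // B]
    # Rescaling by a power-of-two ratio only shifts the exponent, so the FP8
    # mantissa bits are preserved (ignoring saturation/subnormal edge cases).
    qt = (col_blocks / col_sinv.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return qt.view(H, T), col_sinv
```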

Fused Permute + Padding / Unpermute + Unpadding

We fuse two memory movement operators along the MoE path:

  • permute + pad in the forward pass
  • unpermute + unpad in the backward pass

For details of this optimization, please refer to PR NVIDIA#1921

(Shown as marker ② in the figure)
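
For orientation, a minimal unfused reference for the two memory-movement steps that NVIDIA#1921 fuses into a single pass; the helper name, top-1 routing, and the alignment value are illustrative:

```python
import torch

def permute_and_pad(tokens: torch.Tensor, expert_idx: torch.Tensor,
                    num_experts: int, align: int = 16):
    """tokens: [T, H]; expert_idx: [T] expert id per token (top-1 for brevity)."""
    order = torch.argsort(expert_idx, stable=True)      # token order grouped by expert
    permuted = tokens[order]                            # memory pass 1: permute
    counts = torch.bincount(expert_idx, minlength=num_experts)
    padded = ((counts + align - 1) // align) * align    # pad each expert segment
    out = tokens.new_zeros(int(padded.sum()), tokens.shape[1])
    src_off = torch.cumsum(counts, 0) - counts
    dst_off = torch.cumsum(padded, 0) - padded
    for e in range(num_experts):                        # memory pass 2: padding copy
        n = int(counts[e])
        out[int(dst_off[e]):int(dst_off[e]) + n] = permuted[int(src_off[e]):int(src_off[e]) + n]
    return out, order, counts, padded
```

The fused kernel performs both steps in one data pass and avoids materializing the intermediate permuted tensor; the backward pass fuses unpermute + unpad symmetrically.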

Fused Activation + Quantization

Activation and FP8 quantization are fused into a single kernel that produces FP8 outputs directly, keeping the activation path persistent in FP8.

(Shown as marker ③ in the figure)
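
An unfused reference of what the fused kernel computes, assuming a SwiGLU activation with [gate | up] packing and a 1×128 block size (both are assumptions for illustration); in the fused version the high-precision activation stays in registers and only FP8 data plus scales are written to memory:

```python
import torch
import torch.nn.functional as F

def swiglu_then_quant(x: torch.Tensor, block: int = 128):
    """x: [T, 2*H] packed as [gate | up]; returns row-wise blockwise FP8 + scale_inv."""
    gate, up = x.chunk(2, dim=-1)
    act = F.silu(gate) * up                         # SwiGLU; stays in registers when fused
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    t, h = act.shape
    blocks = act.view(t, h // block, block).float()
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax
    q = (blocks * scale).to(torch.float8_e4m3fn).view(t, h)
    return q, (1.0 / scale).squeeze(-1)             # scale_inv: [T, H // block]
```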

Fine-Grained Recompute at the moe_expert Level

Because the entire dispatch → permute → GroupedLinear path stays in FP8, we enable fine-grained recomputation at the moe_expert level:

  • Saves ~50% of peak activation memory and avoids recomputing the router, compared to recomputing the whole moe module

(Shown as marker ⑤ in the figure)
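
A sketch of the idea using torch.utils.checkpoint, with illustrative module and argument names: only the expert block is checkpointed, so the router runs exactly once and the tensors saved for backward are the already-quantized FP8 activations rather than their high-precision counterparts:

```python
import torch
from torch.utils.checkpoint import checkpoint

class MoELayer(torch.nn.Module):
    def __init__(self, router, expert_block, recompute_moe_expert: bool = True):
        super().__init__()
        self.router, self.expert_block = router, expert_block
        self.recompute_moe_expert = recompute_moe_expert

    def forward(self, hidden_states):
        probs, routing_map = self.router(hidden_states)   # router is NOT recomputed
        if self.recompute_moe_expert and self.training:
            # Recompute dispatch -> permute -> grouped GEMMs -> unpermute in backward.
            return checkpoint(self.expert_block, hidden_states, probs, routing_map,
                              use_reentrant=False)
        return self.expert_block(hidden_states, probs, routing_map)
```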

Performance Results

We evaluate FP8-Flow-MoE on DeepSeek-V3 (671B) to validate scalability and robustness under realistic large-scale training conditions.

Throughput

Measured throughput (TGS, tokens/GPU/s) at different expert-parallel (EP) sizes on DeepSeek-V3 (671B):

  • vs. BF16
    +6% (EP8), +8% (EP16), +16% (EP32)

  • vs. TransformerEngine blockwise FP8 recipe
    +3% (EP8), +8% (EP16), up to +21% (EP32)

Memory Efficiency

With activation checkpointing (AC) set to selective and recompute-modules = moe_expert:

  • At EP8:
    • ~8 GB lower peak memory vs. BF16
    • ~16.5 GB lower peak memory vs. blockwise FP8

Numerical Accuracy

We trained for >200B tokens. The loss deviation of FP8-Flow-MoE stays within 0.19% compared to both BF16 baselines, with no observed instability or divergence.

Limitations

  • Currently validated on the NVIDIA Hopper architecture with the blockwise FP8 recipe

Type of Change

  • New feature (non-breaking change which adds functionality)
  • Bug fix
  • Breaking change
  • Infra/Build change
  • Code refactoring

Summary of Code Changes

Megatron-LM

  • Added fused FP8 kernels for activation + quantization in fused_bias_swiglu.py and fused_weighted_swiglu_quant.py
  • Integrated FP8 dispatch and expert recomputation support in Megatron-LM fused_a2a.py

TransformerEngine

  • Added support for Float8BlockwiseQTensor inputs in grouped_linear.py
  • Added scaling_aware_fp8_transpose operator in triton/blockwise_scaling_aware_fp8_transpose.py

…ization

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
1. add fp8 row-wise scaling-aware transpose op for wgrad column-wise.
2. support Float8BlockwiseQTensor input in grouped_linear.
3. _rowwise_scale_inv is propagated with a COMPACT layout along the `dispatch → permute → GroupedLinear` path.

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Co-authored-by: dantesuu@gmail.com
Co-authored-by: xzhu@zhejianglab.org
Co-authored-by: 123sssmmm@gmail.com
xiaoxi-wangfj changed the title from "Fp8 flow moe" to "[PyTorch]Add Casting-Free FP8-Flow-MoE Blockwise Optimizations" on Dec 26, 2025