Draft
43 commits
be694c3
add megatron training args
fishcrap Dec 2, 2025
398b4e0
fix for dsv3
fishcrap Dec 3, 2025
d160507
fp8 align 16 for training input
fishcrap Dec 3, 2025
c9fa040
add fp8 update weight which needs quantize to fp8 first
fishcrap Dec 4, 2025
4bb22b9
add sglang online quant
fishcrap Dec 4, 2025
24fda85
Merge remote-tracking branch 'github/main' into sxj/fp8_train
fishcrap Dec 5, 2025
42c5844
add online dequant and quant in megatron save load
fishcrap Dec 8, 2025
e949d0c
fix shape
fishcrap Dec 8, 2025
731b11d
convert pytorch fp8 to transformer_engine fp8
fishcrap Dec 10, 2025
2e3ac85
fix load
fishcrap Dec 10, 2025
cc2bf1e
fix fp8 scale_inv and weight not in same bin
fishcrap Dec 11, 2025
ed48b0e
fix fp8 load
fishcrap Dec 17, 2025
ce8e6e0
fix hf save
fishcrap Dec 17, 2025
8e203be
fix fp8 save
fishcrap Dec 17, 2025
e968de7
add fp8 tests
fishcrap Dec 17, 2025
55e36a3
fix fp8_param weight update
fishcrap Dec 17, 2025
dc5b71d
fix hf_load
fishcrap Dec 22, 2025
384cbaf
add fp8_recipe in optimizer
fishcrap Dec 22, 2025
176bd26
default scale_inv dtype bfloat16
fishcrap Dec 22, 2025
0edd0a4
fix megatron distributed
fishcrap Dec 23, 2025
b683eb0
Merge remote-tracking branch 'origin/main' into sxj/fp8_train
fishcrap Dec 23, 2025
1b81d61
refactor fp8 tests
fishcrap Dec 24, 2025
31df0ef
fix test names
fishcrap Dec 24, 2025
ca7c973
use refactored forward in tests
fishcrap Dec 24, 2025
18ddcbb
use refactored train in tests
fishcrap Dec 24, 2025
5ae8bd6
fix and refactor fp8 tests
fishcrap Dec 24, 2025
25650c1
fix
fishcrap Dec 24, 2025
07074fd
fix
fishcrap Dec 24, 2025
ad45cb3
fix
fishcrap Dec 24, 2025
65ee2e6
fix tests
fishcrap Dec 24, 2025
eba74ea
fix megatron engine
fishcrap Dec 24, 2025
803981f
fix test fp8 conversion
fishcrap Dec 24, 2025
0ce7194
fix test fp8 conversion
fishcrap Dec 24, 2025
1af4220
add explanation for fixing distributed optimizer
fishcrap Dec 24, 2025
7a92a3f
fix import and comments
fishcrap Dec 24, 2025
058fd75
del useless comments
fishcrap Dec 24, 2025
f189bb9
fix inference ep for megatron engine
fishcrap Dec 24, 2025
da9108f
del comment
fishcrap Dec 24, 2025
174bcad
use engine fixture
fishcrap Dec 24, 2025
c692f69
del __init__.py
fishcrap Dec 24, 2025
89d3f03
del pytorch fp8 to te fp8
fishcrap Dec 24, 2025
437ae40
add comments
fishcrap Dec 24, 2025
500c010
add fp8 consistency check
fishcrap Dec 24, 2025
116 changes: 116 additions & 0 deletions areal/api/cli_args.py
@@ -393,6 +393,12 @@ class DistributedDataParallelConfig:
bucket_size: int | None = None
average_in_collective: bool = False
fp8_param_gather: bool = False
data_parallel_sharding_strategy: str = field(
Collaborator: Is it for FSDP or DDP? Does "no_shard" mean no sharding for optimizer states or for parameters?

Collaborator: Delete this field.

default="no_shard",
metadata={
"help": "Sharding strategy for FSDP. Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'."
},
)


@dataclass
@@ -446,6 +452,115 @@ class MegatronEngineConfig:
distribute_saved_activations: bool | None = None
recompute_modules: list[str] | None = None

# MoE
moe_router_dtype: str | None = None
Collaborator: default to float32?

moe_shared_expert_overlap: bool = field(
default=False,
metadata={
"help": "Enable overlapping between shared expert computations and dispatcher communications. "
"Without this, the shared epxerts execute after the routed experts."
},
)
moe_enable_deepep: bool = False
moe_token_dispatcher_type: str = field(
default="alltoall",
metadata={
"help": "Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'."
},
)
moe_permute_fusion: bool = field(
default=False,
metadata={"help": "Fuse token rearrangement ops during token dispatching."},
)

# FP8 Training Configuration
fp8: str | None = field(
default=None,
metadata={
"help": "Enable FP8 precision training. Options: "
"'e4m3' (uniform e4m3), "
"'hybrid' (e4m3 for activations/weights, e5m2 for output activation gradients)."
},
)
Collaborator (on lines +476 to +484): Can we provide an example YAML config for FP8 Qwen3 training? We should also provide a learning curve for that config (FP8 vs. BF16 training curves).


fp8_recipe: str = field(
default="delayed",
metadata={
"help": "FP8 scaling recipe. Options: 'tensorwise', 'delayed', 'mxfp8' (Blackwell only), 'blockwise'."
},
)

fp8_param: bool = field(
default=False,
metadata={
"help": "Keep parameters in FP8 precision to save memory. "
"Must be used together with fp8 mode. "
"Not all parameters will be converted to fp8; for example, biases will remain unchanged."
},
)

fp8_margin: int = field(
default=0,
metadata={"help": "Margin for FP8 scaling factor computation."},
)

fp8_amax_history_len: int = field(
default=1,
metadata={
"help": "Length of amax history window for scaling factor computation."
},
)

fp8_amax_compute_algo: str = field(
default="most_recent",
metadata={
"help": "Algorithm for choosing amax value. Options: 'max' (largest in history window), 'most_recent'."
},
)

fp8_wgrad: bool = field(
default=True,
metadata={
"help": "When False, override FP8 config and compute weight gradients in higher precision."
},
)

fp8_dot_product_attention: bool = field(
default=False,
metadata={"help": "Use FP8 implementation of Dot Product Attention."},
)

fp8_multi_head_attention: bool = field(
default=False,
metadata={"help": "Use FP8 implementation of Multi Head Attention."},
)

tp_only_amax_red: bool = field(
default=False,
metadata={"help": "Reduce FP8 AMAX only in TP or TP-CP domain."},
)

first_last_layers_bf16: bool = field(
default=False,
metadata={
"help": "Retain first and last N TransformerBlocks in BF16 instead of FP8."
},
)

num_layers_at_start_in_bf16: int = field(
default=1,
metadata={
"help": "Number of layers at start to keep in BF16 when first_last_layers_bf16 is True."
},
)

num_layers_at_end_in_bf16: int = field(
default=1,
metadata={
"help": "Number of layers at end to keep in BF16 when first_last_layers_bf16 is True."
},
)


@dataclass
class SchedulingStrategy:
@@ -959,6 +1074,7 @@ class SGLangConfig:
# and passed as `model_loader_extra_config` to SGLang.
enable_multithread_load: bool = False
enable_fast_load: bool = False
quantization: str | None = None

# Use staticmethod to make OmegaConf happy.
@staticmethod
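A reviewer asks above for an example FP8 config. As a rough sketch only (not taken from this PR), constructing the new fields directly in Python might look like the following; it assumes every other MegatronEngineConfig field keeps its default and leaves the YAML-loading path untouched:

    from areal.api.cli_args import MegatronEngineConfig

    megatron_cfg = MegatronEngineConfig(
        fp8="hybrid",                  # e4m3 activations/weights, e5m2 output-activation grads
        fp8_recipe="delayed",          # delayed scaling with an amax history window
        fp8_param=True,                # keep (most) parameters in FP8 to save memory
        fp8_amax_history_len=16,       # longer window than the default of 1
        fp8_amax_compute_algo="max",   # use the largest amax in the window
        first_last_layers_bf16=True,   # keep the boundary TransformerBlocks in BF16
    )

On the inference side, the new SGLangConfig.quantization field above presumably requests matching online FP8 quantization from SGLang (e.g. "fp8"); the consistency check added in megatron_engine.py separately verifies that the HF checkpoint's quantization_config agrees with the training-side FP8 setting.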
144 changes: 139 additions & 5 deletions areal/engine/megatron_engine.py
@@ -1,6 +1,7 @@
import dataclasses
import functools
import gc
import math
import os
from collections.abc import Callable, Iterator
from concurrent.futures import Future
@@ -15,6 +16,7 @@
from megatron.core import tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import finalize_model_grads
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig
from megatron.core.optimizer import get_megatron_optimizer
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
@@ -26,7 +28,11 @@
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PretrainedConfig

from areal.api.alloc_mode import MegatronParallelStrategy, ParallelStrategy
from areal.api.alloc_mode import (
AllocationMode,
MegatronParallelStrategy,
ParallelStrategy,
)
from areal.api.cli_args import MicroBatchSpec, TrainEngineConfig
from areal.api.engine_api import InferenceEngine, TrainEngine
from areal.api.io_struct import FinetuneSpec, ParamSpec, SaveLoadMeta, WeightUpdateMeta
@@ -124,6 +130,9 @@ def __init__(self, config: TrainEngineConfig):
self.seed: int = 0
self.own_global_group: bool = False
self.is_offload: bool = False
self.enable_fp8: bool = self.config.megatron.fp8 is not None
self.fp8_align_size: int = 16
self.quantization_config: dict[str, int | str | list[str]] | None = None

def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
if parallel_strategy is None:
@@ -189,6 +198,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
f"update_weight_group_{mpu.get_pipeline_model_parallel_rank()}"
)
self.engine_lock = DistributedLock("train_engine_lock")
self.alloc_mode: AllocationMode | None = kwargs.get("alloc_mode", None)

self.tokenizer = load_hf_tokenizer(self.config.path)
self.bridge = mbridge.AutoBridge.from_pretrained(self.config.path)
@@ -214,6 +224,12 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
self.parallel_strategy, self.hf_config, self.tf_config
)

# Get quantization_config from hf_config if available (for FP8 weight updates)
self.quantization_config = getattr(self.hf_config, "quantization_config", None)

self._check_and_apply_fp8_config()
self._validate_fp8_consistency()

# initialize mcore (DDP Wrapped) GPTModel
with self.device:
models = make_mcore_model(
@@ -229,6 +245,18 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
with self.device:
self._load_model_from_hf(self.config.path)

# NOTE: When using distributed optimizer, megatron will use the
# high precision init val to initialize the main parameters for optimizer.
# However, the high precision init val does not exist for FP8 models.
# (The high precision init val is random initialization for FP8 models.)
# So we need to clear the high precision init val here.
for model in self.model:
for _, param in model.named_parameters():
if hasattr(param, "get_high_precision_init_val"):
param.clear_high_precision_init_val()
delattr(param, "get_high_precision_init_val")
delattr(param, "clear_high_precision_init_val")

assert self.model, "Megatron models failed to initialize."
modules = [m.module if isinstance(m, DDP) else m for m in self.model]
total_params = sum(
@@ -687,6 +715,60 @@ def onload(self) -> None:
def clear_batches(self, *args):
"""Placeholder method of single-controller API."""

def _check_and_apply_fp8_config(self):
Collaborator: Should also check for a transformer_engine installation here; if transformer_engine is not installed (e.g., in a uv pip install environment), a runtime error should be raised.

Collaborator: Should also revert the above change.

if self.mcore_config.fp8 is not None:
self.tf_config.fp8 = self.mcore_config.fp8
self.tf_config.fp8_recipe = self.mcore_config.fp8_recipe
self.tf_config.fp8_param = self.mcore_config.fp8_param
self.tf_config.fp8_margin = self.mcore_config.fp8_margin
self.tf_config.fp8_amax_history_len = self.mcore_config.fp8_amax_history_len
self.tf_config.fp8_amax_compute_algo = (
self.mcore_config.fp8_amax_compute_algo
)
self.tf_config.fp8_wgrad = self.mcore_config.fp8_wgrad
self.tf_config.fp8_dot_product_attention = (
self.mcore_config.fp8_dot_product_attention
)
self.tf_config.fp8_multi_head_attention = (
self.mcore_config.fp8_multi_head_attention
)
self.tf_config.tp_only_amax_red = self.mcore_config.tp_only_amax_red
self.tf_config.first_last_layers_bf16 = (
self.mcore_config.first_last_layers_bf16
)
self.tf_config.num_layers_at_start_in_bf16 = (
self.mcore_config.num_layers_at_start_in_bf16
)
self.tf_config.num_layers_at_end_in_bf16 = (
self.mcore_config.num_layers_at_end_in_bf16
)
self.logger.info(
f"FP8 training enabled: fp8={self.mcore_config.fp8}, "
f"fp8_recipe={self.mcore_config.fp8_recipe}, "
f"fp8_param={self.mcore_config.fp8_param}"
)
# fp8_param_gather is passed from make_mcore_model()

def _validate_fp8_consistency(self):
"""Validate that training and inference precision are consistent.

If FP8 training is enabled, inference must also use FP8.
If FP8 training is disabled, inference must not use FP8.
"""
train_fp8 = self.mcore_config.fp8 is not None
inference_fp8 = (
self.quantization_config is not None
and self.quantization_config.get("quant_method", None) == "fp8"
)

if train_fp8 != inference_fp8:
raise RuntimeError(
"Inconsistent FP8 configuration: "
"Training and inference must both use FP8 or both not use FP8. "
f"Training fp8={train_fp8}, "
f"Inference fp8={inference_fp8}"
)

def _make_parallel_strategy(
self, parallel_strategy: ParallelStrategy
) -> MegatronParallelStrategy:
@@ -750,6 +832,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
use_distributed_optimizer=self.mcore_config.ddp.use_distributed_optimizer,
params_dtype=self.dtype,
clip_grad=self.optimizer_config.gradient_clipping,
fp8_recipe=self.mcore_config.fp8_recipe,
)
mcore_opt_config.overlap_param_gather_with_optimizer_step = (
self.mcore_config.overlap_param_gather_with_optimizer_step
@@ -812,6 +895,18 @@ def _check_rollout_engine_connected(self) -> None:
" before using rollout/update_weight methods."
)

def _get_inference_ep_config(self) -> dict[str, bool]:
inference_enable_ep_moe = False

if self.alloc_mode is not None:
gen_parallel = self.alloc_mode.gen
if gen_parallel is not None:
inference_enable_ep_moe = gen_parallel.ep_size > 1

return {
"inference_enable_ep_moe": inference_enable_ep_moe,
}

def _ensure_ready(self) -> None:
if self.is_offload:
self.onload()
@@ -869,16 +964,33 @@ def _impl_update_weight_from_distributed(
param = all_gather_param(name, param)
param = remove_padding(name, param, self.hf_config.vocab_size)

if is_float8tensor(param):
# FP8 is stored as uint8, so element_size is 1 byte
param_size = param.numel() * 1
# Convert TE FP8 to bf16 before convert_to_hf (which will convert to PyTorch FP8)
param = param.dequantize(dtype=self.dtype)
else:
param_size = param.numel() * param.element_size()

if not self.is_pipeline_parallel_head():
return buffer_size

param_size = param.numel() * param.element_size()
if buffer_size + param_size > weight_chunked_mem_size:
self._update_bucket_weights_from_distributed(meta, converted_named_tensors)
buffer_size = 0

# Get inference EP configuration
inference_ep_config = self._get_inference_ep_config()

converted_named_tensors.extend(
convert_to_hf(self.tf_config, self.hf_config.model_type, name, param)
convert_to_hf(
self.tf_config,
self.hf_config.model_type,
name,
param,
quantization_config=self.quantization_config,
**inference_ep_config,
)
)
buffer_size += param_size
return buffer_size
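For intuition on the bucket accounting above: a TE FP8 tensor is stored as uint8, so it counts one byte per element toward buffer_size; the bf16 dequantization is only an intermediate step, since convert_to_hf re-quantizes the weight to PyTorch FP8 per the comment. A toy calculation with made-up numbers:

    numel = 4096 * 4096                  # one hypothetical 4096 x 4096 weight
    param_size = numel * 1               # 16 MiB: FP8 counts one byte per element
    weight_chunked_mem_size = 1 << 30    # hypothetical 1 GiB flush threshold
    weights_per_bucket = weight_chunked_mem_size // param_size   # about 64 such weights per bucket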
@@ -940,10 +1052,20 @@ def _update_bucket_expert_weights_from_distributed(

gathered_params = sum(gathered_params, [])

# Get inference EP configuration
inference_ep_config = self._get_inference_ep_config()

converted_hf_tensors = []
for name, param in gathered_params:
converted_hf_tensors.extend(
convert_to_hf(self.tf_config, self.hf_config.model_type, name, param)
convert_to_hf(
self.tf_config,
self.hf_config.model_type,
name,
param,
quantization_config=self.quantization_config,
**inference_ep_config,
)
)

self._update_bucket_weights_from_distributed(meta, converted_hf_tensors)
@@ -960,7 +1082,14 @@ def _impl_update_expert_weight_from_distributed(
param = all_gather_param(name, param)
param = remove_padding(name, param, self.hf_config.vocab_size)

param_size = param.numel() * param.element_size()
if is_float8tensor(param):
# FP8 is stored as uint8, so element_size is 1 byte
param_size = param.numel() * 1
# Convert TE FP8 to bf16 (will be converted to PyTorch FP8 later in convert_to_hf)
param = param.dequantize(dtype=self.dtype)
else:
param_size = param.numel() * param.element_size()

if (
buffer_size + param_size
) * mpu.get_expert_model_parallel_world_size() > weight_chunked_mem_size:
@@ -1155,6 +1284,11 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
# 2. Align sequence lengths to integer multiples of `align_to_multiple_of=tp_size*cp_size*2`
# to satisfy the requirement of Megatron parallelism.
align_to_multiple_of = tp_size * cp_size * 2 if cp_size > 1 else tp_size
align_to_multiple_of = (
math.lcm(align_to_multiple_of, self.fp8_align_size)
if self.enable_fp8
else align_to_multiple_of
)
mb_list = pad_mb_list(
mb_list,
pad_value=0.0,
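To make the padding rule in _prepare_mb_list concrete, here is a small worked example; the tp_size/cp_size values are illustrative, and fp8_align_size is the constant 16 set in __init__ above:

    import math

    tp_size, cp_size = 2, 2        # illustrative parallel layout
    fp8_align_size = 16            # self.fp8_align_size from __init__

    # Megatron parallelism alone requires multiples of tp_size * cp_size * 2 = 8.
    align_to_multiple_of = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    # The FP8 path in this PR additionally pads sequence lengths to a multiple of 16,
    # so with FP8 enabled the effective alignment is lcm(8, 16) = 16.
    align_to_multiple_of = math.lcm(align_to_multiple_of, fp8_align_size)
    assert align_to_multiple_of == 16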