[Feat] Add FP8 training support #758
base: main
Changes from all commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -393,6 +393,12 @@ class DistributedDataParallelConfig: | |
| bucket_size: int | None = None | ||
| average_in_collective: bool = False | ||
| fp8_param_gather: bool = False | ||
| data_parallel_sharding_strategy: str = field( | ||
| default="no_shard", | ||
| metadata={ | ||
| "help": "Sharding strategy for FSDP. Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'." | ||
| }, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -446,6 +452,115 @@ class MegatronEngineConfig: | |
| distribute_saved_activations: bool | None = None | ||
| recompute_modules: list[str] | None = None | ||
|
|
||
| # MoE | ||
| moe_router_dtype: str | None = None | ||
|
Collaborator: default to
||
| moe_shared_expert_overlap: bool = field( | ||
| default=False, | ||
| metadata={ | ||
| "help": "Enable overlapping between shared expert computations and dispatcher communications. " | ||
| "Without this, the shared epxerts execute after the routed experts." | ||
| }, | ||
| ) | ||
| moe_enable_deepep: bool = False | ||
| moe_token_dispatcher_type: str = field( | ||
| default="alltoall", | ||
| metadata={ | ||
| "help": "Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'." | ||
| }, | ||
| ) | ||
| moe_permute_fusion: bool = field( | ||
| default=False, | ||
| metadata={"help": "Fuse token rearrangement ops during token dispatching."}, | ||
| ) | ||
|
|
||
| # FP8 Training Configuration | ||
| fp8: str | None = field( | ||
| default=None, | ||
| metadata={ | ||
| "help": "Enable FP8 precision training. Options: " | ||
| "'e4m3' (uniform e4m3), " | ||
| "'hybrid' (e4m3 for activations/weights, e5m2 for output activation gradients)." | ||
| }, | ||
| ) | ||
|
Collaborator comment on lines +476 to +484: Can we provide an example yaml config for fp8 qwen3 training? We'd better provide a learning curve with the config (fp8 vs bf16 training curve).
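No example config is included in this diff; as a non-authoritative sketch, the new fields could be set roughly like this. Key paths mirror `config.megatron.<field>` as used in the engine changes below; all values are illustrative, not from this PR.

```python
# Hypothetical FP8 overrides for the MegatronEngineConfig fields added above.
# Key paths follow `self.config.megatron.<field>`; values are examples only.
fp8_overrides = {
    "megatron.fp8": "hybrid",                 # e4m3 for activations/weights, e5m2 for output grads
    "megatron.fp8_recipe": "delayed",         # delayed scaling with an amax history window
    "megatron.fp8_amax_history_len": 1024,    # longer window than the default of 1
    "megatron.fp8_amax_compute_algo": "max",  # use the largest amax in the window
    "megatron.fp8_param": True,               # keep (most) parameters in FP8 to save memory
    "megatron.first_last_layers_bf16": True,  # keep boundary layers in BF16
    "megatron.num_layers_at_start_in_bf16": 1,
    "megatron.num_layers_at_end_in_bf16": 1,
}
```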
||
|
|
||
| fp8_recipe: str = field( | ||
| default="delayed", | ||
| metadata={ | ||
| "help": "FP8 scaling recipe. Options: 'tensorwise', 'delayed', 'mxfp8' (Blackwell only), 'blockwise'." | ||
| }, | ||
| ) | ||
|
|
||
| fp8_param: bool = field( | ||
| default=False, | ||
| metadata={ | ||
| "help": "Keep parameters in FP8 precision to save memory. " | ||
| "Must be used together with fp8 mode. " | ||
| "Not all parameters will be converted to fp8; for example, biases will remain unchanged." | ||
| }, | ||
| ) | ||
|
|
||
| fp8_margin: int = field( | ||
| default=0, | ||
| metadata={"help": "Margin for FP8 scaling factor computation."}, | ||
| ) | ||
|
|
||
| fp8_amax_history_len: int = field( | ||
| default=1, | ||
| metadata={ | ||
| "help": "Length of amax history window for scaling factor computation." | ||
| }, | ||
| ) | ||
|
|
||
| fp8_amax_compute_algo: str = field( | ||
| default="most_recent", | ||
| metadata={ | ||
| "help": "Algorithm for choosing amax value. Options: 'max' (largest in history window), 'most_recent'." | ||
| }, | ||
| ) | ||
|
|
||
| fp8_wgrad: bool = field( | ||
| default=True, | ||
| metadata={ | ||
| "help": "When False, override FP8 config and compute weight gradients in higher precision." | ||
| }, | ||
| ) | ||
|
|
||
| fp8_dot_product_attention: bool = field( | ||
| default=False, | ||
| metadata={"help": "Use FP8 implementation of Dot Product Attention."}, | ||
| ) | ||
|
|
||
| fp8_multi_head_attention: bool = field( | ||
| default=False, | ||
| metadata={"help": "Use FP8 implementation of Multi Head Attention."}, | ||
| ) | ||
|
|
||
| tp_only_amax_red: bool = field( | ||
| default=False, | ||
| metadata={"help": "Reduce FP8 AMAX only in TP or TP-CP domain."}, | ||
| ) | ||
|
|
||
| first_last_layers_bf16: bool = field( | ||
| default=False, | ||
| metadata={ | ||
| "help": "Retain first and last N TransformerBlocks in BF16 instead of FP8." | ||
| }, | ||
| ) | ||
|
|
||
| num_layers_at_start_in_bf16: int = field( | ||
| default=1, | ||
| metadata={ | ||
| "help": "Number of layers at start to keep in BF16 when first_last_layers_bf16 is True." | ||
| }, | ||
| ) | ||
|
|
||
| num_layers_at_end_in_bf16: int = field( | ||
| default=1, | ||
| metadata={ | ||
| "help": "Number of layers at end to keep in BF16 when first_last_layers_bf16 is True." | ||
| }, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class SchedulingStrategy: | ||
|
|
@@ -959,6 +1074,7 @@ class SGLangConfig: | |
| # and passed as `model_loader_extra_config` to SGLang. | ||
| enable_multithread_load: bool = False | ||
| enable_fast_load: bool = False | ||
| quantization: str | None = None | ||
|
|
||
| # Use staticmethod to make OmegaConf happy. | ||
| @staticmethod | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import dataclasses | ||
| import functools | ||
| import gc | ||
| import math | ||
| import os | ||
| from collections.abc import Callable, Iterator | ||
| from concurrent.futures import Future | ||
|
|
@@ -15,6 +16,7 @@ | |
| from megatron.core import tensor_parallel | ||
| from megatron.core.distributed import DistributedDataParallel as DDP | ||
| from megatron.core.distributed import finalize_model_grads | ||
| from megatron.core.fp8_utils import is_float8tensor | ||
| from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig | ||
| from megatron.core.optimizer import get_megatron_optimizer | ||
| from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler | ||
|
|
@@ -26,7 +28,11 @@ | |
| from torchdata.stateful_dataloader import StatefulDataLoader | ||
| from transformers import PretrainedConfig | ||
|
|
||
| from areal.api.alloc_mode import MegatronParallelStrategy, ParallelStrategy | ||
| from areal.api.alloc_mode import ( | ||
| AllocationMode, | ||
| MegatronParallelStrategy, | ||
| ParallelStrategy, | ||
| ) | ||
| from areal.api.cli_args import MicroBatchSpec, TrainEngineConfig | ||
| from areal.api.engine_api import InferenceEngine, TrainEngine | ||
| from areal.api.io_struct import FinetuneSpec, ParamSpec, SaveLoadMeta, WeightUpdateMeta | ||
|
|
@@ -124,6 +130,9 @@ def __init__(self, config: TrainEngineConfig): | |
| self.seed: int = 0 | ||
| self.own_global_group: bool = False | ||
| self.is_offload: bool = False | ||
| self.enable_fp8: bool = self.config.megatron.fp8 is not None | ||
| self.fp8_align_size: int = 16 | ||
| self.quantization_config: dict[str, int | str | list[str]] | None = None | ||
|
|
||
| def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): | ||
| if parallel_strategy is None: | ||
|
|
@@ -189,6 +198,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): | |
| f"update_weight_group_{mpu.get_pipeline_model_parallel_rank()}" | ||
| ) | ||
| self.engine_lock = DistributedLock("train_engine_lock") | ||
| self.alloc_mode: AllocationMode | None = kwargs.get("alloc_mode", None) | ||
|
|
||
| self.tokenizer = load_hf_tokenizer(self.config.path) | ||
| self.bridge = mbridge.AutoBridge.from_pretrained(self.config.path) | ||
|
|
@@ -214,6 +224,12 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): | |
| self.parallel_strategy, self.hf_config, self.tf_config | ||
| ) | ||
|
|
||
| # Get quantization_config from hf_config if available (for FP8 weight updates) | ||
| self.quantization_config = getattr(self.hf_config, "quantization_config", None) | ||
fishcrap marked this conversation as resolved.
|
||
|
|
||
| self._check_and_apply_fp8_config() | ||
| self._validate_fp8_consistency() | ||
|
|
||
| # initialize mcore (DDP Wrapped) GPTModel | ||
| with self.device: | ||
| models = make_mcore_model( | ||
|
|
@@ -229,6 +245,18 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): | |
| with self.device: | ||
| self._load_model_from_hf(self.config.path) | ||
|
|
||
| # NOTE: When using distributed optimizer, megatron will use the | ||
| # high precision init val to initialize the main parameters for optimizer. | ||
| # However, the high precision init val does not exist for FP8 models. | ||
| # (The high precision init val is random initialization for FP8 models.) | ||
| # So we need to clear the high precision init val here. | ||
| for model in self.model: | ||
| for _, param in model.named_parameters(): | ||
| if hasattr(param, "get_high_precision_init_val"): | ||
| param.clear_high_precision_init_val() | ||
| delattr(param, "get_high_precision_init_val") | ||
| delattr(param, "clear_high_precision_init_val") | ||
|
|
||
| assert self.model, "Megatron models failed to initialize." | ||
| modules = [m.module if isinstance(m, DDP) else m for m in self.model] | ||
| total_params = sum( | ||
|
|
@@ -687,6 +715,60 @@ def onload(self) -> None: | |
| def clear_batches(self, *args): | ||
| """Placeholder method of single-controller API.""" | ||
|
|
||
| def _check_and_apply_fp8_config(self): | ||
|
Collaborator: Should also check
Collaborator: should also revert the above change
||
| if self.mcore_config.fp8 is not None: | ||
| self.tf_config.fp8 = self.mcore_config.fp8 | ||
| self.tf_config.fp8_recipe = self.mcore_config.fp8_recipe | ||
| self.tf_config.fp8_param = self.mcore_config.fp8_param | ||
| self.tf_config.fp8_margin = self.mcore_config.fp8_margin | ||
| self.tf_config.fp8_amax_history_len = self.mcore_config.fp8_amax_history_len | ||
| self.tf_config.fp8_amax_compute_algo = ( | ||
| self.mcore_config.fp8_amax_compute_algo | ||
| ) | ||
| self.tf_config.fp8_wgrad = self.mcore_config.fp8_wgrad | ||
| self.tf_config.fp8_dot_product_attention = ( | ||
| self.mcore_config.fp8_dot_product_attention | ||
| ) | ||
| self.tf_config.fp8_multi_head_attention = ( | ||
| self.mcore_config.fp8_multi_head_attention | ||
| ) | ||
| self.tf_config.tp_only_amax_red = self.mcore_config.tp_only_amax_red | ||
| self.tf_config.first_last_layers_bf16 = ( | ||
| self.mcore_config.first_last_layers_bf16 | ||
| ) | ||
| self.tf_config.num_layers_at_start_in_bf16 = ( | ||
| self.mcore_config.num_layers_at_start_in_bf16 | ||
| ) | ||
| self.tf_config.num_layers_at_end_in_bf16 = ( | ||
| self.mcore_config.num_layers_at_end_in_bf16 | ||
| ) | ||
| self.logger.info( | ||
| f"FP8 training enabled: fp8={self.mcore_config.fp8}, " | ||
| f"fp8_recipe={self.mcore_config.fp8_recipe}, " | ||
| f"fp8_param={self.mcore_config.fp8_param}" | ||
| ) | ||
| # fp8_param_gather is passed from make_mcore_model() | ||
|
|
||
| def _validate_fp8_consistency(self): | ||
| """Validate that training and inference precision are consistent. | ||
|
|
||
| If FP8 training is enabled, inference must also use FP8. | ||
| If FP8 training is disabled, inference must not use FP8. | ||
| """ | ||
| train_fp8 = self.mcore_config.fp8 is not None | ||
| inference_fp8 = ( | ||
| self.quantization_config is not None | ||
| and self.quantization_config.get("quant_method", None) == "fp8" | ||
| ) | ||
|
|
||
| if not train_fp8 and inference_fp8 or train_fp8 and not inference_fp8: | ||
| raise RuntimeError( | ||
| "Inconsistent FP8 configuration: " | ||
| "Training and inference must both use FP8 or both not use FP8. " | ||
| f"Training fp8={train_fp8}, " | ||
| f"Inference fp8={inference_fp8}" | ||
| ) | ||
|
|
||
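For context on the check above: `self.quantization_config` is read from the Hugging Face `config.json` of the serving checkpoint, and only `quant_method == "fp8"` is inspected. A rough sketch of what that dict typically looks like for an FP8 checkpoint (field names are common conventions, not taken from this PR):

```python
# Illustrative quantization_config of an FP8-quantized HF checkpoint; only
# "quant_method" is consulted by _validate_fp8_consistency.
quantization_config = {
    "quant_method": "fp8",            # the key the consistency check looks at
    "fmt": "e4m3",                    # storage format of the quantized weights
    "activation_scheme": "dynamic",   # activation scales computed at serve time
    "weight_block_size": [128, 128],  # blockwise scaling granularity
}

train_fp8 = True                      # e.g. megatron.fp8 = "hybrid"
inference_fp8 = quantization_config.get("quant_method") == "fp8"
assert train_fp8 == inference_fp8     # the invariant the validation enforces
```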
| def _make_parallel_strategy( | ||
| self, parallel_strategy: ParallelStrategy | ||
| ) -> MegatronParallelStrategy: | ||
|
|
@@ -750,6 +832,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: | |
| use_distributed_optimizer=self.mcore_config.ddp.use_distributed_optimizer, | ||
| params_dtype=self.dtype, | ||
| clip_grad=self.optimizer_config.gradient_clipping, | ||
| fp8_recipe=self.mcore_config.fp8_recipe, | ||
| ) | ||
| mcore_opt_config.overlap_param_gather_with_optimizer_step = ( | ||
| self.mcore_config.overlap_param_gather_with_optimizer_step | ||
|
|
@@ -812,6 +895,18 @@ def _check_rollout_engine_connected(self) -> None: | |
| " before using rollout/update_weight methods." | ||
| ) | ||
|
|
||
| def _get_inference_ep_config(self) -> dict[str, bool]: | ||
| inference_enable_ep_moe = False | ||
|
|
||
| if self.alloc_mode is not None: | ||
| gen_parallel = self.alloc_mode.gen | ||
| if gen_parallel is not None: | ||
| inference_enable_ep_moe = gen_parallel.ep_size > 1 | ||
|
|
||
| return { | ||
| "inference_enable_ep_moe": inference_enable_ep_moe, | ||
| } | ||
|
|
||
| def _ensure_ready(self) -> None: | ||
| if self.is_offload: | ||
| self.onload() | ||
|
|
@@ -869,16 +964,33 @@ def _impl_update_weight_from_distributed( | |
| param = all_gather_param(name, param) | ||
| param = remove_padding(name, param, self.hf_config.vocab_size) | ||
|
|
||
| if is_float8tensor(param): | ||
| # FP8 is stored as uint8, so element_size is 1 byte | ||
| param_size = param.numel() * 1 | ||
| # Convert TE FP8 to bf16 before convert_to_hf (which will convert to PyTorch FP8) | ||
| param = param.dequantize(dtype=self.dtype) | ||
| else: | ||
| param_size = param.numel() * param.element_size() | ||
|
|
||
| if not self.is_pipeline_parallel_head(): | ||
| return buffer_size | ||
|
|
||
| param_size = param.numel() * param.element_size() | ||
| if buffer_size + param_size > weight_chunked_mem_size: | ||
| self._update_bucket_weights_from_distributed(meta, converted_named_tensors) | ||
| buffer_size = 0 | ||
|
|
||
| # Get inference EP configuration | ||
| inference_ep_config = self._get_inference_ep_config() | ||
|
|
||
| converted_named_tensors.extend( | ||
| convert_to_hf(self.tf_config, self.hf_config.model_type, name, param) | ||
| convert_to_hf( | ||
| self.tf_config, | ||
| self.hf_config.model_type, | ||
| name, | ||
| param, | ||
| quantization_config=self.quantization_config, | ||
| **inference_ep_config, | ||
| ) | ||
| ) | ||
| buffer_size += param_size | ||
| return buffer_size | ||
|
|
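The size accounting above hinges on one detail: a TE Float8Tensor stores its payload as uint8, so it contributes `numel()` bytes to the chunk budget, while a BF16 parameter contributes `2 * numel()`. A toy sketch of the flush logic (names, sizes, and the budget are made up for illustration):

```python
# Toy version of the chunked weight-update accounting; is_fp8 stands in for is_float8tensor().
weight_chunked_mem_size = 64 * 1024**2          # hypothetical 64 MiB chunk budget

params = [
    ("w_fp8",  4096 * 4096,     True),          # FP8:  1 byte/elem  -> 16 MiB
    ("w_bf16", 4096 * 4096,     False),         # BF16: 2 bytes/elem -> 32 MiB
    ("w_big",  2 * 4096 * 4096, False),         # BF16: 64 MiB, would overflow -> flush first
]

buffer_size, bucket = 0, []
for name, numel, is_fp8 in params:
    param_size = numel * (1 if is_fp8 else 2)
    if buffer_size + param_size > weight_chunked_mem_size:
        bucket, buffer_size = [], 0             # flush, cf. _update_bucket_weights_from_distributed
    bucket.append(name)
    buffer_size += param_size
```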
@@ -940,10 +1052,20 @@ def _update_bucket_expert_weights_from_distributed( | |
|
|
||
| gathered_params = sum(gathered_params, []) | ||
|
|
||
| # Get inference EP configuration | ||
| inference_ep_config = self._get_inference_ep_config() | ||
|
|
||
| converted_hf_tensors = [] | ||
| for name, param in gathered_params: | ||
| converted_hf_tensors.extend( | ||
| convert_to_hf(self.tf_config, self.hf_config.model_type, name, param) | ||
| convert_to_hf( | ||
| self.tf_config, | ||
| self.hf_config.model_type, | ||
| name, | ||
| param, | ||
| quantization_config=self.quantization_config, | ||
| **inference_ep_config, | ||
| ) | ||
| ) | ||
|
|
||
| self._update_bucket_weights_from_distributed(meta, converted_hf_tensors) | ||
|
|
@@ -960,7 +1082,14 @@ def _impl_update_expert_weight_from_distributed( | |
| param = all_gather_param(name, param) | ||
| param = remove_padding(name, param, self.hf_config.vocab_size) | ||
|
|
||
| param_size = param.numel() * param.element_size() | ||
| if is_float8tensor(param): | ||
| # FP8 is stored as uint8, so element_size is 1 byte | ||
| param_size = param.numel() * 1 | ||
| # Convert TE FP8 to bf16 (will be converted to PyTorch FP8 later in convert_to_hf) | ||
| param = param.dequantize(dtype=self.dtype) | ||
| else: | ||
| param_size = param.numel() * param.element_size() | ||
|
|
||
| if ( | ||
| buffer_size + param_size | ||
| ) * mpu.get_expert_model_parallel_world_size() > weight_chunked_mem_size: | ||
|
|
@@ -1155,6 +1284,11 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: | |
| # 2. Align sequence lengths to integer multiples of `align_to_multiple_of=tp_size*cp_size*2` | ||
| # to satisfy the requirement of Megatron parallelism. | ||
| align_to_multiple_of = tp_size * cp_size * 2 if cp_size > 1 else tp_size | ||
| align_to_multiple_of = ( | ||
| math.lcm(align_to_multiple_of, self.fp8_align_size) | ||
| if self.enable_fp8 | ||
| else align_to_multiple_of | ||
| ) | ||
| mb_list = pad_mb_list( | ||
| mb_list, | ||
| pad_value=0.0, | ||
|
|
||
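Worked example of the alignment arithmetic added in the hunk above, assuming illustrative sizes tp_size=2 and cp_size=2: Megatron needs sequence lengths divisible by tp_size*cp_size*2, and FP8 GEMMs generally require dimensions that are multiples of 16 (the fp8_align_size above), hence the lcm.

```python
import math

tp_size, cp_size, fp8_align_size = 2, 2, 16   # example values, not from this PR
align_to_multiple_of = tp_size * cp_size * 2 if cp_size > 1 else tp_size  # 8
align_to_multiple_of = math.lcm(align_to_multiple_of, fp8_align_size)     # lcm(8, 16) = 16

seqlen = 1000
padded = -(-seqlen // align_to_multiple_of) * align_to_multiple_of        # ceil to 1008
print(align_to_multiple_of, padded)
```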
Review comments on data_parallel_sharding_strategy:
Is it for FSDP or DDP? Does no_shard mean no sharding for optimizer states or parameters?
delete this field