diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index e24b20d66..7e164e619 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -393,6 +393,12 @@ class DistributedDataParallelConfig: bucket_size: int | None = None average_in_collective: bool = False fp8_param_gather: bool = False + data_parallel_sharding_strategy: str = field( + default="no_shard", + metadata={ + "help": "Sharding strategy for FSDP. Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'." + }, + ) @dataclass @@ -446,6 +452,115 @@ class MegatronEngineConfig: distribute_saved_activations: bool | None = None recompute_modules: list[str] | None = None + # MoE + moe_router_dtype: str | None = None + moe_shared_expert_overlap: bool = field( + default=False, + metadata={ + "help": "Enable overlapping between shared expert computations and dispatcher communications. " + "Without this, the shared experts execute after the routed experts." + }, + ) + moe_enable_deepep: bool = False + moe_token_dispatcher_type: str = field( + default="alltoall", + metadata={ + "help": "Type of token dispatcher. Options: 'allgather', 'alltoall' and 'flex'." + }, + ) + moe_permute_fusion: bool = field( + default=False, + metadata={"help": "Fuse token rearrangement ops during token dispatching."}, + ) + + # FP8 Training Configuration + fp8: str | None = field( + default=None, + metadata={ + "help": "Enable FP8 precision training. Options: " + "'e4m3' (uniform e4m3), " + "'hybrid' (e4m3 for activations/weights, e5m2 for output activation gradients)." + }, + ) + + fp8_recipe: str = field( + default="delayed", + metadata={ + "help": "FP8 scaling recipe. Options: 'tensorwise', 'delayed', 'mxfp8' (Blackwell only), 'blockwise'." + }, + ) + + fp8_param: bool = field( + default=False, + metadata={ + "help": "Keep parameters in FP8 precision to save memory. " + "Must be used together with fp8 mode. " + "Not all parameters will be converted to fp8; for example, biases will remain unchanged." + }, + ) + + fp8_margin: int = field( + default=0, + metadata={"help": "Margin for FP8 scaling factor computation."}, + ) + + fp8_amax_history_len: int = field( + default=1, + metadata={ + "help": "Length of amax history window for scaling factor computation." + }, + ) + + fp8_amax_compute_algo: str = field( + default="most_recent", + metadata={ + "help": "Algorithm for choosing amax value. Options: 'max' (largest in history window), 'most_recent'." + }, + ) + + fp8_wgrad: bool = field( + default=True, + metadata={ + "help": "When False, override FP8 config and compute weight gradients in higher precision." + }, + ) + + fp8_dot_product_attention: bool = field( + default=False, + metadata={"help": "Use FP8 implementation of Dot Product Attention."}, + ) + + fp8_multi_head_attention: bool = field( + default=False, + metadata={"help": "Use FP8 implementation of Multi Head Attention."}, + ) + + tp_only_amax_red: bool = field( + default=False, + metadata={"help": "Reduce FP8 AMAX only in TP or TP-CP domain."}, + ) + + first_last_layers_bf16: bool = field( + default=False, + metadata={ + "help": "Retain first and last N TransformerBlocks in BF16 instead of FP8." + }, + ) + + num_layers_at_start_in_bf16: int = field( + default=1, + metadata={ + "help": "Number of layers at start to keep in BF16 when first_last_layers_bf16 is True." + }, + ) + + num_layers_at_end_in_bf16: int = field( + default=1, + metadata={ + "help": "Number of layers at end to keep in BF16 when first_last_layers_bf16 is True."
+ }, + ) + @dataclass class SchedulingStrategy: @@ -959,6 +1074,7 @@ class SGLangConfig: # and passed as `model_loader_extra_config` to SGLang. enable_multithread_load: bool = False enable_fast_load: bool = False + quantization: str | None = None # Use staticmethod to make OmegaConf happy. @staticmethod diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 2c1c3a94d..9c460356f 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1,6 +1,7 @@ import dataclasses import functools import gc +import math import os from collections.abc import Callable, Iterator from concurrent.futures import Future @@ -15,6 +16,7 @@ from megatron.core import tensor_parallel from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import finalize_model_grads +from megatron.core.fp8_utils import is_float8tensor from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig from megatron.core.optimizer import get_megatron_optimizer from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler @@ -26,7 +28,11 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import PretrainedConfig -from areal.api.alloc_mode import MegatronParallelStrategy, ParallelStrategy +from areal.api.alloc_mode import ( + AllocationMode, + MegatronParallelStrategy, + ParallelStrategy, +) from areal.api.cli_args import MicroBatchSpec, TrainEngineConfig from areal.api.engine_api import InferenceEngine, TrainEngine from areal.api.io_struct import FinetuneSpec, ParamSpec, SaveLoadMeta, WeightUpdateMeta @@ -124,6 +130,9 @@ def __init__(self, config: TrainEngineConfig): self.seed: int = 0 self.own_global_group: bool = False self.is_offload: bool = False + self.enable_fp8: bool = self.config.megatron.fp8 is not None + self.fp8_align_size: int = 16 + self.quantization_config: dict[str, int | str | list[str]] | None = None def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): if parallel_strategy is None: @@ -189,6 +198,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): f"update_weight_group_{mpu.get_pipeline_model_parallel_rank()}" ) self.engine_lock = DistributedLock("train_engine_lock") + self.alloc_mode: AllocationMode | None = kwargs.get("alloc_mode", None) self.tokenizer = load_hf_tokenizer(self.config.path) self.bridge = mbridge.AutoBridge.from_pretrained(self.config.path) @@ -214,6 +224,12 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.parallel_strategy, self.hf_config, self.tf_config ) + # Get quantization_config from hf_config if available (for FP8 weight updates) + self.quantization_config = getattr(self.hf_config, "quantization_config", None) + + self._check_and_apply_fp8_config() + self._validate_fp8_consistency() + # initialize mcore (DDP Wrapped) GPTModel with self.device: models = make_mcore_model( @@ -229,6 +245,18 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): with self.device: self._load_model_from_hf(self.config.path) + # NOTE: When using distributed optimizer, megatron will use the + # high precision init val to initialize the main parameters for optimizer. + # However, the high precision init val does not exist for FP8 models. + # (The high precision init val is random initialization for FP8 models.) + # So we need to clear the high precision init val here. 
+ for model in self.model: + for _, param in model.named_parameters(): + if hasattr(param, "get_high_precision_init_val"): + param.clear_high_precision_init_val() + delattr(param, "get_high_precision_init_val") + delattr(param, "clear_high_precision_init_val") + assert self.model, "Megatron models failed to initialize." modules = [m.module if isinstance(m, DDP) else m for m in self.model] total_params = sum( @@ -687,6 +715,60 @@ def onload(self) -> None: def clear_batches(self, *args): """Placeholder method of single-controller API.""" + def _check_and_apply_fp8_config(self): + if self.mcore_config.fp8 is not None: + self.tf_config.fp8 = self.mcore_config.fp8 + self.tf_config.fp8_recipe = self.mcore_config.fp8_recipe + self.tf_config.fp8_param = self.mcore_config.fp8_param + self.tf_config.fp8_margin = self.mcore_config.fp8_margin + self.tf_config.fp8_amax_history_len = self.mcore_config.fp8_amax_history_len + self.tf_config.fp8_amax_compute_algo = ( + self.mcore_config.fp8_amax_compute_algo + ) + self.tf_config.fp8_wgrad = self.mcore_config.fp8_wgrad + self.tf_config.fp8_dot_product_attention = ( + self.mcore_config.fp8_dot_product_attention + ) + self.tf_config.fp8_multi_head_attention = ( + self.mcore_config.fp8_multi_head_attention + ) + self.tf_config.tp_only_amax_red = self.mcore_config.tp_only_amax_red + self.tf_config.first_last_layers_bf16 = ( + self.mcore_config.first_last_layers_bf16 + ) + self.tf_config.num_layers_at_start_in_bf16 = ( + self.mcore_config.num_layers_at_start_in_bf16 + ) + self.tf_config.num_layers_at_end_in_bf16 = ( + self.mcore_config.num_layers_at_end_in_bf16 + ) + self.logger.info( + f"FP8 training enabled: fp8={self.mcore_config.fp8}, " + f"fp8_recipe={self.mcore_config.fp8_recipe}, " + f"fp8_param={self.mcore_config.fp8_param}" + ) + # fp8_param_gather is passed from make_mcore_model() + + def _validate_fp8_consistency(self): + """Validate that training and inference precision are consistent. + + If FP8 training is enabled, inference must also use FP8. + If FP8 training is disabled, inference must not use FP8. + """ + train_fp8 = self.mcore_config.fp8 is not None + inference_fp8 = ( + self.quantization_config is not None + and self.quantization_config.get("quant_method", None) == "fp8" + ) + + if not train_fp8 and inference_fp8 or train_fp8 and not inference_fp8: + raise RuntimeError( + "Inconsistent FP8 configuration: " + "Training and inference must both use FP8 or both not use FP8. " + f"Training fp8={train_fp8}, " + f"Inference fp8={inference_fp8}" + ) + def _make_parallel_strategy( self, parallel_strategy: ParallelStrategy ) -> MegatronParallelStrategy: @@ -750,6 +832,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: use_distributed_optimizer=self.mcore_config.ddp.use_distributed_optimizer, params_dtype=self.dtype, clip_grad=self.optimizer_config.gradient_clipping, + fp8_recipe=self.mcore_config.fp8_recipe, ) mcore_opt_config.overlap_param_gather_with_optimizer_step = ( self.mcore_config.overlap_param_gather_with_optimizer_step @@ -812,6 +895,18 @@ def _check_rollout_engine_connected(self) -> None: " before using rollout/update_weight methods." 
) + def _get_inference_ep_config(self) -> dict[str, bool]: + inference_enable_ep_moe = False + + if self.alloc_mode is not None: + gen_parallel = self.alloc_mode.gen + if gen_parallel is not None: + inference_enable_ep_moe = gen_parallel.ep_size > 1 + + return { + "inference_enable_ep_moe": inference_enable_ep_moe, + } + def _ensure_ready(self) -> None: if self.is_offload: self.onload() @@ -869,16 +964,33 @@ def _impl_update_weight_from_distributed( param = all_gather_param(name, param) param = remove_padding(name, param, self.hf_config.vocab_size) + if is_float8tensor(param): + # FP8 is stored as uint8, so element_size is 1 byte + param_size = param.numel() * 1 + # Convert TE FP8 to bf16 before convert_to_hf (which will convert to PyTorch FP8) + param = param.dequantize(dtype=self.dtype) + else: + param_size = param.numel() * param.element_size() + if not self.is_pipeline_parallel_head(): return buffer_size - param_size = param.numel() * param.element_size() if buffer_size + param_size > weight_chunked_mem_size: self._update_bucket_weights_from_distributed(meta, converted_named_tensors) buffer_size = 0 + # Get inference EP configuration + inference_ep_config = self._get_inference_ep_config() + converted_named_tensors.extend( - convert_to_hf(self.tf_config, self.hf_config.model_type, name, param) + convert_to_hf( + self.tf_config, + self.hf_config.model_type, + name, + param, + quantization_config=self.quantization_config, + **inference_ep_config, + ) ) buffer_size += param_size return buffer_size @@ -940,10 +1052,20 @@ def _update_bucket_expert_weights_from_distributed( gathered_params = sum(gathered_params, []) + # Get inference EP configuration + inference_ep_config = self._get_inference_ep_config() + converted_hf_tensors = [] for name, param in gathered_params: converted_hf_tensors.extend( - convert_to_hf(self.tf_config, self.hf_config.model_type, name, param) + convert_to_hf( + self.tf_config, + self.hf_config.model_type, + name, + param, + quantization_config=self.quantization_config, + **inference_ep_config, + ) ) self._update_bucket_weights_from_distributed(meta, converted_hf_tensors) @@ -960,7 +1082,14 @@ def _impl_update_expert_weight_from_distributed( param = all_gather_param(name, param) param = remove_padding(name, param, self.hf_config.vocab_size) - param_size = param.numel() * param.element_size() + if is_float8tensor(param): + # FP8 is stored as uint8, so element_size is 1 byte + param_size = param.numel() * 1 + # Convert TE FP8 to bf16 (will be converted to PyTorch FP8 later in convert_to_hf) + param = param.dequantize(dtype=self.dtype) + else: + param_size = param.numel() * param.element_size() + if ( buffer_size + param_size ) * mpu.get_expert_model_parallel_world_size() > weight_chunked_mem_size: @@ -1155,6 +1284,11 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: # 2. Align sequence lengths to integer multiples of `align_to_multiple_of=tp_size*cp_size*2` # to satisfy the requirement of Megatron parallelism. 
align_to_multiple_of = tp_size * cp_size * 2 if cp_size > 1 else tp_size + align_to_multiple_of = ( + math.lcm(align_to_multiple_of, self.fp8_align_size) + if self.enable_fp8 + else align_to_multiple_of + ) mb_list = pad_mb_list( mb_list, pad_value=0.0, diff --git a/areal/models/mcore/hf_load.py b/areal/models/mcore/hf_load.py index 85f6d3cf9..b7c67da28 100644 --- a/areal/models/mcore/hf_load.py +++ b/areal/models/mcore/hf_load.py @@ -12,7 +12,9 @@ from safetensors import safe_open from areal.models.mcore.registry import unwrap_to_gpt_model +from areal.platforms import current_platform from areal.utils import logging +from areal.utils.fp8_utils import dequantize_params logger = logging.getLogger("HF WeightsLoader") @@ -24,6 +26,15 @@ def _get_tp_slice(shape, dim, tp_rank, tp_size) -> tuple: return tuple(res) +def _get_shape(obj) -> list: + """Get shape from either a tensor or PySafeSlice object.""" + if isinstance(obj, torch.Tensor): + return list(obj.shape) + else: + # PySafeSlice object + return obj.get_shape() + + def _weight_to_mcore_tp( hf_config, mcore_weights_name: str, @@ -32,6 +43,7 @@ def _weight_to_mcore_tp( tp_rank: int, tp_size: int, dtype: torch.dtype | None = None, + weight_block_size: list[int, int] | None = None, ) -> torch.Tensor: if ( "self_attention.linear_qkv." in mcore_weights_name @@ -46,7 +58,7 @@ def _weight_to_mcore_tp( group_dim = head_dim * num_attention_heads // num_key_value_heads q, k, v = hf_weights_safe_slice # q k v might be tp split - real_num_key_value_heads = q.get_shape()[0] // group_dim + real_num_key_value_heads = _get_shape(q)[0] // group_dim s = _get_tp_slice((real_num_key_value_heads * group_dim,), 0, tp_rank, tp_size) q = q[s].reshape( real_num_key_value_heads // tp_size, @@ -67,32 +79,36 @@ def _weight_to_mcore_tp( gate, up = hf_weights_safe_slice # chunk 0 for TP split gate = gate[ - _get_tp_slice(gate.get_shape(), dim=0, tp_rank=tp_rank, tp_size=tp_size) + _get_tp_slice(_get_shape(gate), dim=0, tp_rank=tp_rank, tp_size=tp_size) ] - up = up[_get_tp_slice(up.get_shape(), dim=0, tp_rank=tp_rank, tp_size=tp_size)] + up = up[_get_tp_slice(_get_shape(up), dim=0, tp_rank=tp_rank, tp_size=tp_size)] res = torch.cat([gate, up], dim=0) elif "mlp.experts.linear_fc2.weight" in mcore_weights_name: # moe assert len(hf_weights_safe_slice) == 1 x = hf_weights_safe_slice[0] - shape = x.get_shape() + shape = _get_shape(x) # dim 1 chunk - res = x[_get_tp_slice(shape, dim=1, tp_rank=tp_rank, tp_size=tp_size)] + partition_dim = 1 + res = x[ + _get_tp_slice(shape, dim=partition_dim, tp_rank=tp_rank, tp_size=tp_size) + ] else: assert len(hf_weights_safe_slice) == 1 x = hf_weights_safe_slice[0] - if mcore_param_shape == x.get_shape(): - res = x[:] + x_shape = _get_shape(x) + partition_dim = None + if mcore_param_shape == x_shape: + res = x[:] if not isinstance(x, torch.Tensor) else x else: - assert len(x.get_shape()) == len(mcore_param_shape) - for partition_dim, (s1, s2) in enumerate( - zip(x.get_shape(), mcore_param_shape) - ): + assert len(x_shape) == len(mcore_param_shape) + for dim, (s1, s2) in enumerate(zip(x_shape, mcore_param_shape)): if s1 != s2: + partition_dim = dim break # chunk on `partition_dim` res = x[ _get_tp_slice( - x.get_shape(), dim=partition_dim, tp_rank=tp_rank, tp_size=tp_size + x_shape, dim=partition_dim, tp_rank=tp_rank, tp_size=tp_size ) ] if dtype is not None: @@ -115,27 +131,71 @@ def _load_weight_with_bridge_worker( for name in f.keys(): all_slices[name] = f.get_slice(name) + quantization_config = getattr(bridge.hf_config, 
"quantization_config", None) + for local_name in local_names: hf_names = local_to_hf_map[local_name] param = state_dict[local_name] - if "experts" in local_name: + if "experts" in local_name and "shared_experts" not in local_name: tp_size = mpu.get_expert_tensor_parallel_world_size() tp_rank = mpu.get_expert_tensor_parallel_rank() else: tp_size = mpu.get_tensor_model_parallel_world_size() tp_rank = mpu.get_tensor_model_parallel_rank() + # Get weight_block_size from quantization_config + weight_block_size = None + if quantization_config is not None: + weight_block_size = quantization_config.get("weight_block_size", None) + assert ( + isinstance(weight_block_size, (list, tuple)) + and len(weight_block_size) == 2 + ) + + # Check if any HF weight is FP8 (has _scale_inv suffix) + # If fp8 mode is not enabled in megatron, + # we need to dequantize FP8 weights before converting to mcore format + # Now only support FP8 dequantization + hf_weights_safe_slice = [] + + for hf_name in hf_names: + if "_scale_inv" in hf_name: + continue + hf_slice = all_slices[hf_name] + scale_inv_name = f"{hf_name}_scale_inv" + if scale_inv_name in all_slices: + # HF weight is FP8, dequantize to higher precision (bf16) + # TODO: convert pytorch fp8 to te fp8 directly + device = torch.device(current_platform.device_type) + weight = hf_slice[:].to(device) + scale_inv = all_slices[scale_inv_name][:].to(device) + dequantized_weight = dequantize_params( + weight, + scale_inv, + dst_dtype=bridge.dtype, + quantization_config=quantization_config, + ) + dequantized_weight = dequantized_weight.cpu() + hf_weights_safe_slice.append(dequantized_weight) + else: + hf_weights_safe_slice.append(hf_slice) + + # TODO: check fp type is matched between pytorch and te + param_to_load = _weight_to_mcore_tp( hf_config=bridge.hf_config, mcore_weights_name=local_name, mcore_param_shape=list(param.shape), - hf_weights_safe_slice=[all_slices[hf_name] for hf_name in hf_names], + hf_weights_safe_slice=hf_weights_safe_slice, tp_rank=tp_rank, tp_size=tp_size, dtype=bridge.dtype, + weight_block_size=weight_block_size, ) - # load + + # Load the parameter + # NOTE: for megatron FP8 param, `param.copy_` will do quantization internally param.copy_(param_to_load, non_blocking=True) @@ -272,9 +332,17 @@ def load_weights_from_hf_with_mbridge_fast( if is_critic and "output_layer" in local_name: continue for name in hf_names: + if "_scale_inv" in name: + continue filename = index[name] if filename not in local_to_file_map[local_name]: local_to_file_map[local_name].append(filename) + # Also include the scale_inv file if it exists + scale_inv_name = f"{name}_scale_inv" + if scale_inv_name in index: + scale_inv_filename = index[scale_inv_name] + if scale_inv_filename not in local_to_file_map[local_name]: + local_to_file_map[local_name].append(scale_inv_filename) grouped_local_names, grouped_filenames = make_filename_bins(local_to_file_map) diff --git a/areal/models/mcore/hf_save.py b/areal/models/mcore/hf_save.py index 357dc6d2c..dfc3fd84d 100644 --- a/areal/models/mcore/hf_save.py +++ b/areal/models/mcore/hf_save.py @@ -10,12 +10,14 @@ from mbridge.core import Bridge from mbridge.core.util import unwrap_model from megatron.core import parallel_state as mpu +from megatron.core.fp8_utils import is_float8tensor from safetensors.torch import save_file from torch.distributed._functional_collectives import all_gather_into_tensor_coalesced from areal.models.mcore.registry import unwrap_to_gpt_model from areal.platforms import current_platform from areal.utils import 
logging +from areal.utils.fp8_utils import quantize_params logger = logging.getLogger("HF WeightsSaver") @@ -250,6 +252,7 @@ def save_weights_to_hf_with_mbridge_fast( non_expert_sd = {} _all_gather_specs = [] all_gather_outputs = {} + quantization_config = getattr(bridge.hf_config, "quantization_config", None) for s in non_expert_specs: if s.tensor_model_parallel and mpu.get_tensor_model_parallel_world_size() > 1: _all_gather_specs.append(s) @@ -262,6 +265,7 @@ def save_weights_to_hf_with_mbridge_fast( all_gather_outputs[s.global_name] = gathered_param for s in non_expert_specs: param = s.param + if s.tensor_model_parallel: # allocate a new tensor with proper size if mpu.get_tensor_model_parallel_world_size() <= 1: @@ -270,14 +274,28 @@ def save_weights_to_hf_with_mbridge_fast( infer_params = all_gather_outputs[s.global_name].chunk( mpu.get_tensor_model_parallel_world_size(), dim=0 ) + infer_params = [ + p.dequantize(dtype=bridge.dtype) if is_float8tensor(p) else p + for p in infer_params + ] infer_params = bridge._weight_merge_across_tp( s.global_name, infer_params, param ) else: infer_params = param + if is_float8tensor(infer_params): + infer_params = infer_params.dequantize(dtype=bridge.dtype) converted_names, converted_params = bridge._weight_to_hf_format( s.global_name, infer_params ) + # Apply quantization if quantization_config is present + if quantization_config is not None: + converted_named_params = list(zip(converted_names, converted_params)) + quantized_named_params = quantize_params( + s.global_name, converted_named_params, quantization_config + ) + converted_names = [name for name, _ in quantized_named_params] + converted_params = [param for _, param in quantized_named_params] for n, p in zip(converted_names, converted_params): assert n not in non_expert_sd, n non_expert_sd[n] = p @@ -370,10 +388,23 @@ def _save_one_shard(x): params = all_gather_outputs[s.global_name].chunk(etp_size, dim=0) else: params = [param] + + params = [ + p.dequantize(dtype=bridge.dtype) if is_float8tensor(p) else p + for p in params + ] merge_params = bridge._weight_merge_across_tp(s.global_name, params, param) converted_names, converted_params = bridge._weight_to_hf_format( s.global_name, merge_params ) + # Apply quantization if quantization_config is present + if quantization_config is not None: + converted_named_params = list(zip(converted_names, converted_params)) + quantized_named_params = quantize_params( + s.global_name, converted_named_params, quantization_config + ) + converted_names = [name for name, _ in quantized_named_params] + converted_params = [param for _, param in quantized_named_params] for n, p in zip(converted_names, converted_params): assert n not in expert_sd, n expert_sd[n] = p diff --git a/areal/models/mcore/registry.py b/areal/models/mcore/registry.py index 59b1e7c2c..bfabf241e 100644 --- a/areal/models/mcore/registry.py +++ b/areal/models/mcore/registry.py @@ -1,6 +1,7 @@ import dataclasses import torch +from mbridge.core.bridge import Bridge from megatron.core import tensor_parallel from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallelConfig as MCoreDDPConfig @@ -131,7 +132,7 @@ def make_mcore_model( hf_config: PretrainedConfig, tf_config: TransformerConfig, mcore_config: MegatronEngineConfig | None = None, - bridge=None, + bridge: Bridge | None = None, is_critic: bool = False, ) -> list[GPTModel | DDP]: if bridge is not None: diff --git a/areal/tests/fp8/comparison_utils.py 
b/areal/tests/fp8/comparison_utils.py new file mode 100644 index 000000000..09aab231e --- /dev/null +++ b/areal/tests/fp8/comparison_utils.py @@ -0,0 +1,295 @@ +"""Comparison utilities for FP8/BF16 comparison tests. + +This module contains reusable functions for comparing tensors, activations, +gradients, and other model outputs between FP8 and BF16 models. +""" + +from collections import defaultdict +from typing import Any + +import torch +import torch.nn.functional as F + +from areal.tests.fp8.model_hooks import categorize_op_name +from areal.utils import logging + +logger = logging.getLogger("FP8 BF16 Comparison Utils") + + +def compare_tensors( + tensor_bf16: torch.Tensor, + tensor_fp8: torch.Tensor, + name: str = "tensor", + check_nan_inf: bool = False, + check_zero_norm: bool = False, +) -> dict[str, Any]: + """Compare two tensors and return statistics. + + Args: + tensor_bf16: BF16 tensor + tensor_fp8: FP8 tensor + name: Name identifier for logging + check_nan_inf: Whether to check for NaN/Inf values + check_zero_norm: Whether to check for zero norm + + Returns: + Dictionary with comparison statistics: + - max_diff: Maximum absolute difference + - mean_diff: Mean absolute difference + - cos_sim: Cosine similarity + - bf16_norm: Norm of BF16 tensor + - fp8_norm: Norm of FP8 tensor + - has_nan: Whether any tensor has NaN + - has_inf: Whether any tensor has Inf + - zero_norm: Whether any tensor has zero norm + """ + result = { + "name": name, + "shape_match": tensor_bf16.shape == tensor_fp8.shape, + } + + if not result["shape_match"]: + logger.warning( + f"{name} shapes don't match: BF16={tensor_bf16.shape}, FP8={tensor_fp8.shape}" + ) + return result + + # Calculate differences + diff = (tensor_bf16 - tensor_fp8).abs() + result["max_diff"] = diff.max().item() + result["mean_diff"] = diff.mean().item() + + # Calculate norms + bf16_norm = tensor_bf16.norm().item() + fp8_norm = tensor_fp8.norm().item() + result["bf16_norm"] = bf16_norm + result["fp8_norm"] = fp8_norm + + # Check for NaN/Inf + if check_nan_inf: + bf16_has_nan = torch.isnan(tensor_bf16).any().item() + bf16_has_inf = torch.isinf(tensor_bf16).any().item() + fp8_has_nan = torch.isnan(tensor_fp8).any().item() + fp8_has_inf = torch.isinf(tensor_fp8).any().item() + result["has_nan"] = bf16_has_nan or fp8_has_nan + result["has_inf"] = bf16_has_inf or fp8_has_inf + + if result["has_nan"] or result["has_inf"]: + logger.warning( + f"{name} has NaN/Inf: " + f"BF16 NaN={bf16_has_nan}, Inf={bf16_has_inf}, " + f"FP8 NaN={fp8_has_nan}, Inf={fp8_has_inf}" + ) + + # Check for zero norm + if check_zero_norm: + result["zero_norm"] = bf16_norm == 0.0 or fp8_norm == 0.0 + if result["zero_norm"]: + logger.warning( + f"{name} has zero norm: BF16 norm={bf16_norm:.6e}, FP8 norm={fp8_norm:.6e}" + ) + + # Calculate cosine similarity + if check_zero_norm and result.get("zero_norm", False): + result["cos_sim"] = 0.0 + else: + tensor_bf16_flat = tensor_bf16.flatten() + tensor_fp8_flat = tensor_fp8.flatten() + cos_sim = F.cosine_similarity( + tensor_bf16_flat.unsqueeze(0), tensor_fp8_flat.unsqueeze(0), dim=1 + ).item() + + if torch.isnan(torch.tensor(cos_sim)): + logger.warning(f"{name} cosine similarity is NaN, setting to 0.0") + cos_sim = 0.0 + + result["cos_sim"] = cos_sim + + return result + + +def compare_tensors_dict( + dict_bf16: dict[str, torch.Tensor], + dict_fp8: dict[str, torch.Tensor], + title: str = "Comparison", + check_nan_inf: bool = False, + check_zero_norm: bool = False, + group_by_op_type: bool = True, + name_width: int = 50, +) -> 
dict[str, Any]: + """Compare two dictionaries of tensors and return statistics grouped by operation type. + + Args: + dict_bf16: Dictionary of BF16 tensors + dict_fp8: Dictionary of FP8 tensors + title: Title for logging + check_nan_inf: Whether to check for NaN/Inf values + check_zero_norm: Whether to check for zero norm + group_by_op_type: Whether to group statistics by operation type + name_width: Width for name formatting in logs + + Returns: + Dictionary with comparison statistics: + - stats_by_type: Statistics grouped by operation type + - individual_stats: Individual tensor statistics + """ + logger.info("\n" + "=" * 80) + logger.info(f"{title} by Operation Type") + logger.info("=" * 80) + + stats_by_type = defaultdict( + lambda: {"max_diffs": [], "mean_diffs": [], "cos_sims": [], "names": []} + ) + individual_stats = {} + + common_names = set(dict_bf16.keys()) & set(dict_fp8.keys()) + for name in sorted(common_names): + tensor_bf16 = dict_bf16[name] + tensor_fp8 = dict_fp8[name] + + # Skip None values + if tensor_bf16 is None or tensor_fp8 is None: + continue + + # Compare tensors + comparison = compare_tensors( + tensor_bf16, + tensor_fp8, + name=name, + check_nan_inf=check_nan_inf, + check_zero_norm=check_zero_norm, + ) + + if not comparison["shape_match"]: + continue + + individual_stats[name] = comparison + + # Group by operation type if requested + if group_by_op_type: + op_type = categorize_op_name(name) + stats_by_type[op_type]["max_diffs"].append(comparison["max_diff"]) + stats_by_type[op_type]["mean_diffs"].append(comparison["mean_diff"]) + stats_by_type[op_type]["cos_sims"].append(comparison["cos_sim"]) + stats_by_type[op_type]["names"].append(name) + + # Format with fixed width for alignment + name_str = f"{name} ({op_type})" + logger.info( + f"{name_str:<{name_width}} " + f"max_diff={comparison['max_diff']:>12.6f}, " + f"mean_diff={comparison['mean_diff']:>12.6f}, " + f"cos_sim={comparison['cos_sim']:>10.6f}" + ) + else: + logger.info( + f"{name:<{name_width}} " + f"max_diff={comparison['max_diff']:>12.6f}, " + f"mean_diff={comparison['mean_diff']:>12.6f}, " + f"cos_sim={comparison['cos_sim']:>10.6f}" + ) + + # Summary by op type + if group_by_op_type and stats_by_type: + logger.info("\n" + "-" * 80) + logger.info(f"{title} Summary by Operation Type") + logger.info("-" * 80) + for op_type in sorted(stats_by_type.keys()): + stats = stats_by_type[op_type] + if stats["max_diffs"]: + max_diff_val = max(stats["max_diffs"]) + mean_diff_val = sum(stats["mean_diffs"]) / len(stats["mean_diffs"]) + cos_sim_val = sum(stats["cos_sims"]) / len(stats["cos_sims"]) + logger.info( + f"{op_type:<50} " + f"max_diff={max_diff_val:>12.6f}, " + f"mean_diff={mean_diff_val:>12.6f}, " + f"cos_sim={cos_sim_val:>10.6f}, " + f"n_ops={len(stats['names']):>4}" + ) + + return { + "stats_by_type": dict(stats_by_type), + "individual_stats": individual_stats, + } + + +def compare_logits( + logits_bf16: torch.Tensor, + logits_fp8: torch.Tensor, +) -> dict[str, Any]: + """Compare logits between BF16 and FP8 models. 
+ + Args: + logits_bf16: BF16 logits tensor + logits_fp8: FP8 logits tensor + + Returns: + Dictionary with comparison statistics + """ + logger.info("\n" + "=" * 80) + logger.info("Logits Comparison") + logger.info("=" * 80) + + comparison = compare_tensors(logits_bf16, logits_fp8, name="logits") + + if comparison["shape_match"]: + logger.info(f"Logits max diff: {comparison['max_diff']:.6f}") + logger.info(f"Logits mean diff: {comparison['mean_diff']:.6f}") + logger.info(f"Logits cosine similarity: {comparison['cos_sim']:.6f}") + else: + logger.warning( + f"Logits shapes don't match: BF16={logits_bf16.shape}, FP8={logits_fp8.shape}" + ) + + return comparison + + +def find_problematic_operations( + stats_by_type: dict[str, dict[str, list]], + threshold: float = 0.95, +) -> list[tuple[str, str, float, float]]: + """Find operations with cosine similarity below threshold. + + Args: + stats_by_type: Statistics grouped by operation type + threshold: Cosine similarity threshold + + Returns: + List of tuples: (op_type, name, cos_sim, max_diff) + """ + problematic = [] + for op_type, stats in stats_by_type.items(): + for i, (name, cos_sim) in enumerate(zip(stats["names"], stats["cos_sims"])): + if cos_sim < threshold: + problematic.append((op_type, name, cos_sim, stats["max_diffs"][i])) + return sorted(problematic, key=lambda x: x[2]) # Sort by cos_sim + + +def log_problematic_operations( + stats_by_type: dict[str, dict[str, list]], + threshold: float = 0.95, + title: str = "Problematic Operations", +): + """Log operations with cosine similarity below threshold. + + Args: + stats_by_type: Statistics grouped by operation type + threshold: Cosine similarity threshold + title: Title for logging + """ + problematic = find_problematic_operations(stats_by_type, threshold) + + logger.info("\n" + "=" * 80) + logger.info(f"{title} (low cosine similarity, threshold={threshold})") + logger.info("=" * 80) + + if problematic: + for op_type, name, cos_sim, max_diff in problematic: + logger.info( + f" {name} ({op_type}): cos_sim={cos_sim:.6f}, max_diff={max_diff:.6f}" + ) + else: + logger.info(f"No problematic operations found (all cos_sim >= {threshold})") + + logger.info("=" * 80) diff --git a/areal/tests/fp8/engine_utils.py b/areal/tests/fp8/engine_utils.py new file mode 100644 index 000000000..3566d33f0 --- /dev/null +++ b/areal/tests/fp8/engine_utils.py @@ -0,0 +1,459 @@ +"""Shared utilities for FP8/BF16 comparison tests. + +This module contains common helper functions, fixtures, and constants +used across multiple FP8/BF16 comparison test files. +""" + +import os +from collections import defaultdict +from typing import Any + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state as mpu + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import ( + MegatronEngineConfig, + OptimizerConfig, + TrainEngineConfig, +) +from areal.api.io_struct import FinetuneSpec +from areal.engine.core.train_engine import reorder_and_pad_outputs +from areal.engine.megatron_engine import MegatronEngine +from areal.utils import logging +from areal.utils.data import ( + broadcast_tensor, + pack_tensor_dict, +) +from areal.utils.functional import gather_logprobs + +logger = logging.getLogger("FP8 BF16 Comparison Utils") + + +def extract_gemm_kernels(profiler, phase: str = "forward"): + """Extract and summarize GEMM-related kernels from profiler output. 
+ + Args: + profiler: torch.profiler.profile instance + phase: Phase name ("forward" or "backward") + + Returns: + Dictionary with gemm kernel statistics + """ + gemm_keywords = ["gemm", "cublas", "cutlass", "matmul", "mm", "bmm"] + + gemm_events = [] + + # Get all events from profiler - iterate through all events to find CUDA kernels + try: + # Try to get events() which gives us raw events + all_events = list(profiler.events()) + except Exception: + # Fallback to key_averages() if events() is not available + all_events = list(profiler.key_averages()) + + for event in all_events: + # Get event name - try different attributes + event_name = None + if hasattr(event, "key"): + event_name = event.key + elif hasattr(event, "name"): + event_name = event.name + elif hasattr(event, "__str__"): + event_name = str(event) + else: + continue + + # Check if this is a CUDA kernel event + # CUDA kernels typically have specific attributes + is_cuda_kernel = False + if hasattr(event, "is_cuda") and event.is_cuda: + is_cuda_kernel = True + elif ( + hasattr(event, "device_type") and event.device_type == 1 + ): # CUDA device type + is_cuda_kernel = True + elif "cuda" in str(type(event)).lower() or "kernel" in event_name.lower(): + is_cuda_kernel = True + + # Check if this is a gemm-related kernel + event_name_lower = event_name.lower() + if is_cuda_kernel and any( + keyword.lower() in event_name_lower for keyword in gemm_keywords + ): + # Extract kernel information + kernel_info = { + "name": event_name, + "duration_us": 0.0, + "count": 1, + } + + # Try to get CUDA time (in microseconds) + if hasattr(event, "cuda_time_total"): + kernel_info["duration_us"] = event.cuda_time_total / 1000.0 + elif hasattr(event, "cuda_time"): + kernel_info["duration_us"] = event.cuda_time / 1000.0 + elif hasattr(event, "self_cuda_time_total"): + kernel_info["duration_us"] = event.self_cuda_time_total / 1000.0 + elif hasattr(event, "self_cuda_time"): + kernel_info["duration_us"] = event.self_cuda_time / 1000.0 + + # Try to get count + if hasattr(event, "count"): + kernel_info["count"] = event.count + + # Try to get input shapes if available + if hasattr(event, "input_shapes") and event.input_shapes: + kernel_info["input_shapes"] = event.input_shapes + elif hasattr(event, "shapes") and event.shapes: + kernel_info["input_shapes"] = event.shapes + + gemm_events.append(kernel_info) + + # Also check key_averages for aggregated view + try: + key_avgs = profiler.key_averages() + for event in key_avgs: + event_name = None + if hasattr(event, "key"): + event_name = event.key + elif hasattr(event, "name"): + event_name = event.name + else: + continue + + event_name_lower = event_name.lower() + # Check if this is a gemm-related operation (may be at higher level) + if any(keyword.lower() in event_name_lower for keyword in gemm_keywords): + # Check if we already have this in gemm_events + if not any(e["name"] == event_name for e in gemm_events): + kernel_info = { + "name": event_name, + "duration_us": 0.0, + "count": 1, + } + + if hasattr(event, "cuda_time_total"): + kernel_info["duration_us"] = event.cuda_time_total / 1000.0 + elif hasattr(event, "self_cuda_time_total"): + kernel_info["duration_us"] = event.self_cuda_time_total / 1000.0 + + if hasattr(event, "count"): + kernel_info["count"] = event.count + + if hasattr(event, "input_shapes") and event.input_shapes: + kernel_info["input_shapes"] = event.input_shapes + + gemm_events.append(kernel_info) + except Exception: + pass + + # Group by kernel name + kernel_stats = defaultdict( + 
lambda: {"count": 0, "total_time_us": 0.0, "input_shapes": []} + ) + + for event in gemm_events: + name = event["name"] + kernel_stats[name]["count"] += event["count"] + kernel_stats[name]["total_time_us"] += event["duration_us"] + if "input_shapes" in event and event["input_shapes"]: + kernel_stats[name]["input_shapes"].extend(event["input_shapes"]) + + # Calculate averages + result = { + "phase": phase, + "total_gemm_kernels": len(gemm_events), + "unique_kernel_names": len(kernel_stats), + "kernels": {}, + } + + for name, stats in kernel_stats.items(): + result["kernels"][name] = { + "count": stats["count"], + "total_time_us": stats["total_time_us"], + "avg_time_us": stats["total_time_us"] / stats["count"] + if stats["count"] > 0 + else 0, + "input_shapes": list(set(str(s) for s in stats["input_shapes"][:5])) + if stats["input_shapes"] + else [], + } + + return result + + +def print_gemm_profile(profile_result: dict): + """Print gemm profiling results in a readable format.""" + logger.info("=" * 80) + logger.info(f"GEMM Kernel Profile - {profile_result['phase'].upper()}") + logger.info("=" * 80) + logger.info(f"Total GEMM kernels found: {profile_result['total_gemm_kernels']}") + logger.info(f"Unique kernel names: {profile_result['unique_kernel_names']}") + logger.info("") + + if not profile_result["kernels"]: + logger.info("No GEMM kernels found in this phase.") + return + + # Sort by total time + sorted_kernels = sorted( + profile_result["kernels"].items(), + key=lambda x: x[1]["total_time_us"], + reverse=True, + ) + + logger.info("GEMM Kernels (sorted by total time):") + logger.info("-" * 80) + for i, (name, stats) in enumerate(sorted_kernels, 1): + logger.info(f"{i}. {name}") + logger.info(f" Count: {stats['count']}") + logger.info( + f" Total time: {stats['total_time_us']:.2f} us ({stats['total_time_us'] / 1000:.2f} ms)" + ) + logger.info(f" Avg time: {stats['avg_time_us']:.2f} us") + if stats["input_shapes"]: + logger.info(f" Sample shapes: {', '.join(stats['input_shapes'])}") + logger.info("") + + total_time = sum(s["total_time_us"] for s in profile_result["kernels"].values()) + logger.info(f"Total GEMM time: {total_time:.2f} us ({total_time / 1000:.2f} ms)") + logger.info("=" * 80) + + +def create_engine( + model_path: str, + fp8_enabled: bool = False, + fp8_param: bool = False, + port: int = 7777, +) -> MegatronEngine: + """Create and initialize a MegatronEngine.""" + os.environ.update( + { + "WORLD_SIZE": "1", + "RANK": "0", + "LOCAL_RANK": "0", + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(port), + } + ) + + megatron_config = MegatronEngineConfig() + if fp8_enabled: + megatron_config.fp8 = "e4m3" + megatron_config.fp8_param = fp8_param + megatron_config.fp8_recipe = "blockwise" + megatron_config.ddp.fp8_param_gather = True + + config = TrainEngineConfig( + experiment_name="test", + trial_name="test", + path=model_path, + optimizer=OptimizerConfig(), + megatron=megatron_config, + ) + alloc_mode = AllocationMode.from_str("d1p1t1") + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=128, train_batch_size=8) + engine = MegatronEngine(config) + engine.create_process_group(alloc_mode.train) + engine.initialize(addr=None, ft_spec=ft_spec) + return engine + + +@torch.no_grad() +def forward_with_logits_and_logprobs( + engine: MegatronEngine, input_: dict[str, Any], profile_gemm: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass that returns both logits and logprobs. 
+ + Args: + engine: MegatronEngine instance + input_: Input dictionary + profile_gemm: If True, profile GEMM kernels during forward pass + + Returns: + tuple: (logits, logprobs) both with shape [batch, seq_len, ...] + """ + engine.eval() + engine._ensure_ready() + + # Prepare input similar to forward_batch method + cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] + output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() + + # Prepare micro-batches + mb_list = engine._prepare_mb_list(input_).to(engine.device) + + # Collect logits and logprobs from forward pass + logits_list: list[torch.Tensor] = [] + logprobs_list: list[torch.Tensor] = [] + + def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: + """Process output to extract logits and logprobs.""" + # output is already unpad_logits'd by forward_backward_batch + labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) + logprobs = gather_logprobs( + output, + labels, + temperature=engine.config.temperature, + tp_group=mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None, + ) + logits_list.append(output) + logprobs_list.append(logprobs) + return None + + # Profile GEMM kernels if requested + if profile_gemm: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + torch.profiler.ProfilerActivity.CPU, + ], + record_shapes=True, + with_stack=False, + profile_memory=False, + ) as prof: + engine.forward_backward_batch(mb_list, process_output, forward_only=True) + torch.cuda.synchronize() + + # Extract and print GEMM kernels + gemm_profile = extract_gemm_kernels(prof, phase="forward") + print_gemm_profile(gemm_profile) + else: + engine.forward_backward_batch(mb_list, process_output, forward_only=True) + + # Aggregate, reorder, and pad outputs + logits = None + logprobs = None + if mpu.is_pipeline_last_stage(): + if logits_list: + logits = reorder_and_pad_outputs( + logits_list, output_seqlens, mb_list, aggregate_fn=torch.cat + ) + logprobs = reorder_and_pad_outputs( + logprobs_list, output_seqlens, mb_list, aggregate_fn=torch.cat + ) + + # Broadcast results + logits = broadcast_tensor( + logits, + src_rank=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) + logprobs = broadcast_tensor( + logprobs, + src_rank=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) + + return logits, logprobs + + +def decode_with_megatron_forward( + engine: MegatronEngine, + prompt: str, + max_new_tokens: int = 50, + temperature: float = 1.0, + top_k: int | None = None, + top_p: float | None = None, +) -> str: + """Decode using Megatron forward pass for autoregressive generation. + + Args: + engine: MegatronEngine instance + prompt: Input prompt text + max_new_tokens: Maximum number of tokens to generate + temperature: Sampling temperature + top_k: Top-k sampling (None for no limit) + top_p: Top-p (nucleus) sampling (None for no limit) + + Returns: + Generated text (prompt + generated tokens) + """ + engine.eval() + if engine.is_offload: + engine.onload() + + assert engine.model is not None, "Model is not initialized." + assert engine.tokenizer is not None, "Tokenizer is not initialized." 
+ + # Encode prompt + encoded = engine.tokenizer(prompt, return_tensors="pt") + input_ids = encoded["input_ids"].to(engine.device) + generated_ids = input_ids.clone() + + # Generate tokens autoregressively + for step in range(max_new_tokens): + # Prepare input dict + batch_size = generated_ids.shape[0] + seq_len = generated_ids.shape[1] + attention_mask = torch.ones( + (batch_size, seq_len), dtype=torch.bool, device=engine.device + ) + + input_dict = { + "input_ids": generated_ids, + "attention_mask": attention_mask, + } + + # Forward pass to get logits + logits, _ = forward_with_logits_and_logprobs(engine, input_dict) + + # Get logits for the last token position + # logits shape: [batch, seq_len, vocab_size] + next_token_logits = logits[:, -1, :] # [batch, vocab_size] + + # Apply temperature + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + + # Apply top-k filtering + if top_k is not None and top_k > 0: + indices_to_remove = ( + next_token_logits + < torch.topk(next_token_logits, top_k)[0][..., -1, None] + ) + next_token_logits[indices_to_remove] = float("-inf") + + # Apply top-p (nucleus) filtering + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort( + next_token_logits, descending=True + ) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + next_token_logits[indices_to_remove] = float("-inf") + + # Sample next token + probs = F.softmax(next_token_logits, dim=-1) + next_token_id = torch.multinomial(probs, num_samples=1) # [batch, 1] + + # Append to generated sequence + generated_ids = torch.cat([generated_ids, next_token_id], dim=1) + + # Check for EOS token + eos_token_id = getattr(engine.tokenizer, "eos_token_id", None) + if eos_token_id is not None and next_token_id[0, 0].item() == eos_token_id: + logger.info("EOS token generated, stopping.") + break + + # Decode full sequence + generated_text = engine.tokenizer.decode( + generated_ids[0], skip_special_tokens=False + ) + + return generated_text diff --git a/areal/tests/fp8/model_hooks.py b/areal/tests/fp8/model_hooks.py new file mode 100644 index 000000000..36a56c595 --- /dev/null +++ b/areal/tests/fp8/model_hooks.py @@ -0,0 +1,593 @@ +"""Model manipulation utilities for FP8/BF16 comparison tests. + +This module contains functions for extracting layers, reducing models, +and collecting activations/gradients using hooks. +""" + +from typing import Any + +import torch +from megatron.core import parallel_state as mpu + +from areal.engine.core.train_engine import compute_total_loss_weight +from areal.engine.megatron_engine import MegatronEngine +from areal.tests.fp8.engine_utils import ( + extract_gemm_kernels, + print_gemm_profile, +) +from areal.utils import logging +from areal.utils.megatron import get_named_parameters + +logger = logging.getLogger("FP8 BF16 Model Utils") + + +def get_model_from_engine(engine: MegatronEngine): + """Get the actual model module from engine, unwrapping DDP and Float16Module.""" + assert engine.model is not None, "Model is not initialized." 
+ model = engine.model[0] + if hasattr(model, "module"): + model = model.module + # Handle Float16Module wrapper + if hasattr(model, "module"): + model = model.module + return model + + +def reduce_model_to_layers(engine: MegatronEngine, layer_indices: list[int] | int): + """Reduce the model to specified transformer layers while keeping full structure. + + This function modifies the model in-place by replacing decoder.layers (ModuleList) + with a new ModuleList containing only the specified layers. This allows the model + to maintain its full structure (embedding, rotary_pos_emb, final_layernorm, output_layer) + so that forward pass and loss computation work correctly. + + Args: + engine: MegatronEngine instance + layer_indices: Index or list of indices of layers to keep (0-based). + If int, keeps only that layer. If list, keeps multiple layers. + + Returns: + The original number of layers (for potential restoration) + """ + model = get_model_from_engine(engine) + + # Get decoder + decoder = None + if hasattr(model, "decoder"): + decoder = model.decoder + elif hasattr(model, "module") and hasattr(model.module, "decoder"): + decoder = model.module.decoder + + if decoder is None or not hasattr(decoder, "layers"): + raise ValueError("Cannot find decoder.layers") + + original_layers = decoder.layers + original_num_layers = len(original_layers) + + # Convert single int to list + if isinstance(layer_indices, int): + layer_indices = [layer_indices] + + # Validate layer indices + for layer_idx in layer_indices: + if layer_idx >= original_num_layers: + raise ValueError( + f"Layer index {layer_idx} out of range. Model has {original_num_layers} layers." + ) + + # Remove duplicates and sort to maintain order + layer_indices = sorted(list(set(layer_indices))) + + # Create new ModuleList with only the specified layers + selected_layers = [original_layers[idx] for idx in layer_indices] + new_layers = torch.nn.ModuleList(selected_layers) + + # Replace the layers ModuleList + decoder.layers = new_layers + + if len(layer_indices) == 1: + logger.info( + f"Reduced model from {original_num_layers} layers to 1 layer (keeping layer {layer_indices[0]})" + ) + else: + logger.info( + f"Reduced model from {original_num_layers} layers to {len(layer_indices)} layers (keeping layers {layer_indices})" + ) + + return original_num_layers + + +def collect_gradients_after_train_batch( + engine: MegatronEngine, input_: dict[str, Any], profile_gemm: bool = False +) -> dict[str, torch.Tensor]: + """Execute train_batch but collect gradients before optimizer.step(). + + This function replicates the train_batch logic but stops before optimizer.step() + to collect gradients for comparison. + + Args: + engine: MegatronEngine instance + input_: Input dictionary + profile_gemm: If True, profile GEMM kernels during forward and backward pass + + Returns: + Dictionary mapping parameter names to their gradients. + """ + engine._ensure_ready() + assert engine.optimizer is not None, "Optimizer is not initialized." 
+ engine.optimizer.zero_grad() + for model in engine.model: + model.zero_grad_buffer() + + # Step 1: Prepare micro-batches + mb_list = engine._prepare_mb_list(input_).to(engine.device) + + # Step 2: Define loss functions + def sft_loss_fn(logprobs, entropy, input_): + """SFT loss function based on compute_packed_sft_loss.""" + del entropy # SFT does not use entropy + + # Get loss_mask from input + loss_mask = input_["loss_mask"].bool() + + # Shift loss_mask to align with next-token prediction + loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1) + + # Apply loss_mask to logprobs + logprobs = torch.where(loss_mask, logprobs, 0) + + # Compute loss: negative log likelihood averaged over valid tokens + device = logprobs.device + num_valid = loss_mask.count_nonzero() + if num_valid == 0: + return torch.tensor(0.0, device=device, requires_grad=True) + + loss = -logprobs.sum() / num_valid + return loss + + def loss_weight_fn(mb): + """Loss weight function based on number of valid tokens.""" + return mb["loss_mask"].count_nonzero() + + # Step 3: Compute total loss weight + total_loss_weight = compute_total_loss_weight( + mb_list, loss_weight_fn, mpu.get_data_parallel_group() + ) + + # Step 4: Forward-backward using Megatron's pipeline function + loss_multiplier = ( + mpu.get_data_parallel_world_size() * engine.optimizer.get_loss_scale().item() + ) + + def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> torch.Tensor: + return engine._compute_logprobs_and_loss( + output, + inputs, + loss_fn=sft_loss_fn, + loss_weight_fn=loss_weight_fn, + total_loss_weight=total_loss_weight, + loss_multiplier=loss_multiplier, + ) + + # Profile GEMM kernels if requested + if profile_gemm: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + torch.profiler.ProfilerActivity.CPU, + ], + record_shapes=True, + with_stack=False, + profile_memory=False, + ) as prof: + engine.forward_backward_batch(mb_list, process_output, forward_only=False) + torch.cuda.synchronize() + + # Extract and print GEMM kernels + gemm_profile = extract_gemm_kernels(prof, phase="backward") + print_gemm_profile(gemm_profile) + else: + engine.forward_backward_batch(mb_list, process_output, forward_only=False) + + # Step 5: Collect gradients before optimizer.step() + gradients = {} + for name, param in get_named_parameters(engine.model, num_experts=None): + if param.requires_grad: + # Try to get gradient from param.grad or param.main_grad + grad = None + if hasattr(param, "main_grad") and param.main_grad is not None: + grad = param.main_grad.clone() + elif hasattr(param, "grad") and param.grad is not None: + grad = param.grad.clone() + else: + raise ValueError(f"No gradient found for {name}") + + if grad is not None: + if mpu.get_tensor_model_parallel_world_size() > 1: + raise NotImplementedError("TP gradients are not supported yet") + gradients[name] = grad + + return gradients + + +def categorize_op_name(name: str) -> str: + """Categorize operation name into op type. 
+ + Args: + name: Parameter or activation name + + Returns: + Op type category: 'attention', 'mlp', 'layernorm', 'embedding', 'other' + """ + name_lower = name.lower() + if "attn" in name_lower or "attention" in name_lower: + if ( + "qkv" in name_lower + or "q_proj" in name_lower + or "k_proj" in name_lower + or "v_proj" in name_lower + ): + return "attention_proj" + elif ( + "linear_proj" in name_lower + or "o_proj" in name_lower + or "out_proj" in name_lower + ): + return "attention_out" + elif "core_attention" in name_lower: + return "attention_core" + else: + return "attention" + elif "mlp" in name_lower or "feedforward" in name_lower or "ffn" in name_lower: + if "activation" in name_lower: + return "mlp_activation" + elif "fc1" in name_lower or "gate" in name_lower or "up" in name_lower: + return "mlp_gate_up" + elif "fc2" in name_lower or "down" in name_lower: + return "mlp_down" + else: + return "mlp" + elif "rotary" in name_lower or "rope" in name_lower: + return "rope" + elif "layernorm" in name_lower or "norm" in name_lower: + # Distinguish Q/K layernorms from regular layernorms + if "q_layernorm" in name_lower or "k_layernorm" in name_lower: + return "qk_layernorm" + return "layernorm" + elif "embedding" in name_lower or "embed" in name_lower: + return "embedding" + else: + return "other" + + +def forward_backward_model_with_hooks( + engine: MegatronEngine, + input_: dict[str, Any], + layer_indices: list[int] | int = 0, +) -> tuple[ + torch.Tensor, + dict[str, torch.Tensor], + dict[str, torch.Tensor], + dict[str, torch.Tensor], +]: + """Perform forward and backward pass on model with specified layers and activation hooks. + + This function reduces the model to specified layers, then performs forward and backward + using the full model structure (embedding -> layers -> final_layernorm -> output_layer), + allowing for real loss computation. + + Args: + engine: MegatronEngine instance + input_: Input dictionary with 'input_ids', 'attention_mask', 'loss_mask' + layer_indices: Index or list of indices of layers to keep (0-based). + If int, keeps only that layer. If list, keeps multiple layers. 
+ + Returns: + tuple: (logits, activations_dict, gradients_dict, output_gradients_dict) + - logits: Output logits from the model + - activations_dict: Dictionary mapping op names to their output activations + - gradients_dict: Dictionary mapping parameter names to their gradients + - output_gradients_dict: Dictionary mapping op names to their output gradients + """ + # Convert single int to list for consistency + if isinstance(layer_indices, int): + layer_indices = [layer_indices] + + # Reduce model to specified layers + _ = reduce_model_to_layers(engine, layer_indices) + + activations = {} + gradients = {} + output_gradients = {} # Gradients flowing back to module outputs + hooks = [] + + def make_activation_hook(name): + def hook(module, input, output): + try: + if isinstance(output, tuple): + activations[name] = ( + output[0].clone().detach() if len(output) > 0 else None + ) + else: + activations[name] = output.clone().detach() + logger.info( + f"Captured activation for {name}: {activations[name].dtype}" + ) + except Exception as e: + logger.warning(f"Failed to capture activation for {name}: {e}") + + return hook + + def make_input_pre_hook(activation_key: str, log_name: str): + """Create a pre-hook to capture module input.""" + + def input_hook(module, input): + try: + if isinstance(input, tuple): + activations[activation_key] = ( + input[0].clone().detach() if len(input) > 0 else None + ) + else: + activations[activation_key] = input.clone().detach() + logger.info( + f"Captured {log_name} input: {activations[activation_key].shape}" + ) + except Exception as e: + logger.warning(f"Failed to capture {log_name} input: {e}") + + return input_hook + + def make_output_grad_hook(grad_key: str, log_name: str): + """Create a backward hook to capture module output gradient.""" + + def backward_hook(module, grad_input, grad_output): + try: + if grad_output is not None and len(grad_output) > 0: + if grad_output[0] is not None: + output_gradients[grad_key] = grad_output[0].clone().detach() + logger.info( + f"Captured {log_name} output grad: {output_gradients[grad_key].shape}" + ) + except Exception as e: + logger.warning(f"Failed to capture {log_name} output grad: {e}") + + return backward_hook + + def register_layernorm_hooks( + module, layer_prefix: str, layernorm_name: str + ) -> list: + """Register input pre-hook and backward hook for a layernorm module.""" + registered_hooks = [] + activation_key = f"{layer_prefix}.self_attention.{layernorm_name}.input" + grad_key = f"{layer_prefix}.self_attention.{layernorm_name}.output_grad" + + pre_hook = module.register_forward_pre_hook( + make_input_pre_hook(activation_key, layernorm_name) + ) + registered_hooks.append(pre_hook) + + backward_hook = module.register_full_backward_hook( + make_output_grad_hook(grad_key, layernorm_name) + ) + registered_hooks.append(backward_hook) + + return registered_hooks + + # Get model and register hooks + model = get_model_from_engine(engine) + + # Register hooks for components + hook_names = [] + + # Embedding + if hasattr(model, "embedding"): + hook_names.append(("embedding", model.embedding)) + if hasattr(model.embedding, "word_embeddings"): + hook_names.append( + ("embedding.word_embeddings", model.embedding.word_embeddings) + ) + + # Rotary position embedding + if hasattr(model, "rotary_pos_emb"): + hook_names.append(("rotary_pos_emb", model.rotary_pos_emb)) + + # Decoder and layers + if hasattr(model, "decoder"): + decoder = model.decoder + hook_names.append(("decoder", decoder)) + + # Selected layers (after 
reduction) + if hasattr(decoder, "layers") and len(decoder.layers) > 0: + # Register hooks for each layer + for layer_idx_in_reduced, layer in enumerate(decoder.layers): + layer_prefix = f"layer_{layer_idx_in_reduced}" + + hook_names.append((f"{layer_prefix}", layer)) + + # Input layernorm + if hasattr(layer, "input_layernorm"): + hook_names.append( + (f"{layer_prefix}.input_layernorm", layer.input_layernorm) + ) + + # Self attention + if hasattr(layer, "self_attention"): + hook_names.append( + (f"{layer_prefix}.self_attention", layer.self_attention) + ) + if hasattr(layer.self_attention, "linear_qkv"): + hook_names.append( + ( + f"{layer_prefix}.self_attention.linear_qkv", + layer.self_attention.linear_qkv, + ) + ) + if hasattr(layer.self_attention, "linear_proj"): + hook_names.append( + ( + f"{layer_prefix}.self_attention.linear_proj", + layer.self_attention.linear_proj, + ) + ) + if hasattr(layer.self_attention, "core_attention"): + hook_names.append( + ( + f"{layer_prefix}.self_attention.core_attention", + layer.self_attention.core_attention, + ) + ) + # Register hooks for q_layernorm and k_layernorm + for layernorm_name in ["q_layernorm", "k_layernorm"]: + if hasattr(layer.self_attention, layernorm_name): + layernorm_module = getattr( + layer.self_attention, layernorm_name + ) + hook_names.append( + ( + f"{layer_prefix}.self_attention.{layernorm_name}", + layernorm_module, + ) + ) + hooks.extend( + register_layernorm_hooks( + layernorm_module, layer_prefix, layernorm_name + ) + ) + + # Post attention layernorm + if hasattr(layer, "post_attention_layernorm"): + hook_names.append( + ( + f"{layer_prefix}.post_attention_layernorm", + layer.post_attention_layernorm, + ) + ) + elif hasattr(layer, "pre_mlp_layernorm"): + hook_names.append( + (f"{layer_prefix}.pre_mlp_layernorm", layer.pre_mlp_layernorm) + ) + + # MLP + if hasattr(layer, "mlp"): + hook_names.append((f"{layer_prefix}.mlp", layer.mlp)) + if hasattr(layer.mlp, "linear_fc1"): + hook_names.append( + (f"{layer_prefix}.mlp.linear_fc1", layer.mlp.linear_fc1) + ) + if hasattr(layer.mlp, "linear_fc2"): + hook_names.append( + (f"{layer_prefix}.mlp.linear_fc2", layer.mlp.linear_fc2) + ) + + # Add pre-hook to capture activation output + if hasattr(layer.mlp, "linear_fc2"): + activation_key = f"{layer_prefix}.mlp.activation_output" + activation_hook = ( + layer.mlp.linear_fc2.register_forward_pre_hook( + make_input_pre_hook(activation_key, "MLP activation") + ) + ) + hooks.append(activation_hook) + + # Final layernorm + if hasattr(decoder, "final_layernorm"): + hook_names.append(("decoder.final_layernorm", decoder.final_layernorm)) + + # Output layer + if hasattr(model, "output_layer"): + hook_names.append(("output_layer", model.output_layer)) + + # Register forward hooks and backward hooks for all modules + for name, module in hook_names: + try: + # Register forward hook + hook = module.register_forward_hook(make_activation_hook(name)) + hooks.append(hook) + + # Register backward hook to capture output gradients + def make_backward_hook(hook_name): + def backward_hook(module, grad_input, grad_output): + try: + if grad_output is not None and len(grad_output) > 0: + if grad_output[0] is not None: + output_gradients[f"{hook_name}.output_grad"] = ( + grad_output[0].clone().detach() + ) + logger.debug( + f"Captured output grad for {hook_name}: {output_gradients[f'{hook_name}.output_grad'].shape}" + ) + except Exception as e: + logger.warning( + f"Failed to capture output grad for {hook_name}: {e}" + ) + + return backward_hook + + backward_hook 
= module.register_full_backward_hook(make_backward_hook(name)) + hooks.append(backward_hook) + except Exception as e: + logger.warning(f"Failed to register hook for {name}: {e}") + + # Forward and backward using engine's train_batch method + engine.train() + + # Prepare loss function + def sft_loss_fn(logprobs, entropy, input_): + del entropy + loss_mask = input_["loss_mask"].bool() + loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1) + logprobs = torch.where(loss_mask, logprobs, 0) + device = logprobs.device + num_valid = loss_mask.count_nonzero() + if num_valid == 0: + return torch.tensor(0.0, device=device, requires_grad=True) + loss = -logprobs.sum() / num_valid + return loss + + def loss_weight_fn(mb): + return mb["loss_mask"].count_nonzero() + + # Use engine's train_batch but collect gradients before optimizer step + engine.optimizer.zero_grad() + for model_chunk in engine.model: + model_chunk.zero_grad_buffer() + + # Forward and backward + engine.train_batch(input_, sft_loss_fn, loss_weight_fn) + + # Collect gradients from all components (focusing on the selected layers) + model = get_model_from_engine(engine) + + # Collect gradients from all selected layers + if ( + hasattr(model, "decoder") + and hasattr(model.decoder, "layers") + and len(model.decoder.layers) > 0 + ): + for layer_idx_in_reduced, layer in enumerate(model.decoder.layers): + layer_prefix = f"layer_{layer_idx_in_reduced}" + for name, param in layer.named_parameters(): + if param.requires_grad: + grad = None + if hasattr(param, "main_grad") and param.main_grad is not None: + grad = param.main_grad.clone().detach() + elif hasattr(param, "grad") and param.grad is not None: + grad = param.grad.clone().detach() + else: + raise ValueError(f"No gradient found for {layer_prefix}.{name}") + + if grad is not None: + # Use layer_X. prefix to match activation naming + gradients[f"{layer_prefix}.{name}"] = grad + else: + logger.warning(f"No gradient found for {layer_prefix}.{name}") + + # Get logits by doing a forward pass + engine.eval() + logits = engine.forward(input_) + + # Remove hooks + for hook in hooks: + hook.remove() + + return logits, activations, gradients, output_gradients diff --git a/areal/tests/test_fp8_bf16_comparison.py b/areal/tests/test_fp8_bf16_comparison.py new file mode 100644 index 000000000..477da979a --- /dev/null +++ b/areal/tests/test_fp8_bf16_comparison.py @@ -0,0 +1,706 @@ +"""Test comparison between FP8 and BF16 models using Megatron Engine. + +This test verifies: +1. Load FP8 model with fp8_param enabled and BF16 model using Megatron Engine +2. Compare logprobs from forward pass +3. 
Compare logits from forward pass +""" + +import re +from datetime import datetime +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.distributed as dist +import torch.nn.functional as F +from transformers import AutoTokenizer + +from areal.platforms import current_platform +from areal.tests.fp8.comparison_utils import ( + compare_logits, + compare_tensors_dict, + log_problematic_operations, +) +from areal.tests.fp8.engine_utils import ( + create_engine, + decode_with_megatron_forward, + forward_with_logits_and_logprobs, +) +from areal.tests.fp8.model_hooks import ( + collect_gradients_after_train_batch, + forward_backward_model_with_hooks, +) +from areal.tests.utils import get_model_path +from areal.utils import logging + +MODEL_PATH_BF16 = get_model_path( + "/storage/openpsi/models/Qwen__Qwen3-0.6B/", "Qwen/Qwen3-0.6B" +) +MODEL_PATH_FP8 = get_model_path( + "/storage/openpsi/models/Qwen__Qwen3-0.6B-FP8/", "Qwen/Qwen3-0.6B-FP8" +) + +logger = logging.getLogger("FP8 BF16 Comparison Test") + + +@pytest.fixture(scope="module") +def fixed_input( + questions: list[str] | None = None, + answers: list[str] | None = None, + model_path: str = MODEL_PATH_BF16, + device=current_platform.device_type, +) -> dict[str, Any]: + """Create fixed input data for SFT training with question and answer. + + Args: + questions: List of question strings. If None, uses default questions. + answers: List of answer strings. If None, uses default answers. + model_path: Path to the model for loading tokenizer. + device: Device to place tensors on. + + Returns: + Dictionary with 'input_ids', 'attention_mask', and 'loss_mask' tensors. + loss_mask: 0 for prompt tokens, 1 for answer tokens (including EOS). + """ + if questions is None: + questions = [ + "What is 2+2?", + "Count from 1 to 5:", + ] + if answers is None: + answers = [ + " 2+2 equals 4.", + " 1, 2, 3, 4, 5", + ] + + assert len(questions) == len(answers), ( + "Questions and answers must have the same length" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + eos_token = tokenizer.eos_token if tokenizer.eos_token else "" + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + + input_ids_list = [] + loss_mask_list = [] + + for question, answer in zip(questions, answers): + # Encode full sequence: question + answer + eos_token + full_text = question + answer + eos_token + seq_token = tokenizer.encode(full_text, add_special_tokens=False) + + # Encode prompt (question) to determine where loss_mask should start + prompt_token = tokenizer.encode(question, add_special_tokens=False) + + # Create loss_mask: 0 for prompt, 1 for answer (including EOS) + loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token)) + + input_ids_list.append(torch.tensor(seq_token, dtype=torch.long, device=device)) + loss_mask_list.append(torch.tensor(loss_mask, dtype=torch.long, device=device)) + + # Pad to same length + max_length = max(ids.shape[0] for ids in input_ids_list) + + padded_input_ids = [] + padded_loss_masks = [] + attention_masks = [] + + for input_ids, loss_mask in zip(input_ids_list, loss_mask_list): + seq_len = input_ids.shape[0] + padding_length = max_length - seq_len + + padded_ids = F.pad(input_ids, (0, padding_length), value=pad_token_id) + padded_input_ids.append(padded_ids) + + padded_loss_mask = F.pad(loss_mask, (0, padding_length), value=0) + padded_loss_masks.append(padded_loss_mask) + + attention_mask = F.pad( + torch.ones(seq_len, 
dtype=torch.bool, device=device), + (0, padding_length), + value=0, + ) + attention_masks.append(attention_mask) + + # Stack into batch + input_ids = torch.stack(padded_input_ids).to(device) + loss_mask = torch.stack(padded_loss_masks).to(device) + attention_mask = torch.stack(attention_masks).to(device) + + logger.info("Using fixed input:") + for i, (q, a) in enumerate(zip(questions, answers)): + logger.info(f" Sample {i}:") + logger.info(f" Question: {q}") + logger.info(f" Answer: {a}") + logger.info(f" Input IDs shape: {input_ids[i].shape}") + logger.info(f" Loss mask sum: {loss_mask[i].sum().item()}") + + return dict( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + ) + + +@pytest.fixture(scope="module") +def engine_bf16(): + engine = create_engine( + MODEL_PATH_BF16, fp8_enabled=False, fp8_param=False, port=7777 + ) + try: + yield engine + finally: + engine.destroy() + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.fixture(scope="module") +def engine_fp8(): + engine = create_engine(MODEL_PATH_FP8, fp8_enabled=True, fp8_param=True, port=7778) + try: + yield engine + finally: + engine.destroy() + if dist.is_initialized(): + dist.destroy_process_group() + + +def test_megatron_decode_output(engine_bf16, engine_fp8): + """Test decode using Megatron forward pass and print model output.""" + # Test prompts + test_prompts = [ + "What is 2+2?", + "The capital of France is", + "Once upon a time", + ] + + top_k = None + temperature = 0.7 + max_new_tokens = 100 + + logger.info("=" * 80) + logger.info("Testing decode with BF16 model") + logger.info("=" * 80) + + for prompt in test_prompts: + logger.info(f"{'=' * 80}") + logger.info(f"Prompt: {prompt}") + logger.info(f"{'=' * 80}") + generated = decode_with_megatron_forward( + engine_bf16, + prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + ) + logger.info(f"BF16 Final output: {generated}\n") + + logger.info("=" * 80) + logger.info("Testing decode with FP8 model") + logger.info("=" * 80) + + for prompt in test_prompts: + logger.info(f"{'=' * 80}") + logger.info(f"Prompt: {prompt}") + logger.info(f"{'=' * 80}") + generated = decode_with_megatron_forward( + engine_fp8, + prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + ) + logger.info(f"FP8 Final output: {generated}\n") + + +def test_fp8_bf16_logits_logprobs_comparison(fixed_input, engine_bf16, engine_fp8): + """Compare both logits and logprobs between FP8 and BF16 models.""" + logits_bf16, logprobs_bf16 = forward_with_logits_and_logprobs( + engine_bf16, fixed_input + ) + + logits_fp8, logprobs_fp8 = forward_with_logits_and_logprobs(engine_fp8, fixed_input) + + # Get attention mask to filter out padding positions + attention_mask = fixed_input["attention_mask"] # [batch, seq_len] + + # Compare logprobs first + assert logprobs_bf16.shape == logprobs_fp8.shape, "Logprob shapes don't match" + # Only compute differences for non-padding positions + valid_logprobs_mask = attention_mask # [batch, seq_len] + logprob_diff = (logprobs_bf16 - logprobs_fp8).abs() + logprob_max_diff = (logprob_diff * valid_logprobs_mask).max().item() + logprob_mean_diff = ( + logprob_diff * valid_logprobs_mask + ).sum().item() / valid_logprobs_mask.sum().item() + logger.info( + f"Logprob comparison (non-padding only): max_diff={logprob_max_diff:.6f}, mean_diff={logprob_mean_diff:.6f}" + ) + + # Compare logits + assert logits_bf16.shape == logits_fp8.shape, "Logits shapes don't match" + # Only compute 
differences for non-padding positions + valid_logits_mask = attention_mask.unsqueeze(-1) # [batch, seq_len, 1] + logits_diff = (logits_bf16 - logits_fp8).abs() + logits_max_diff = (logits_diff * valid_logits_mask).max().item() + logits_mean_diff = ( + logits_diff * valid_logits_mask + ).sum().item() / valid_logits_mask.sum().item() + logger.info( + f"Logits comparison (non-padding only): max_diff={logits_max_diff:.6f}, mean_diff={logits_mean_diff:.6f}" + ) + + # Sequence-level and token-level cosine similarity (only use valid tokens) + batch_size, seq_len = attention_mask.shape + + # Collect data for both sequence-level and token-level cosine similarity + cos_sim_logprobs_seq_list = [] + cos_sim_logits_seq_list = [] + logprobs_bf16_valid_list = [] + logprobs_fp8_valid_list = [] + logits_bf16_valid_list = [] + logits_fp8_valid_list = [] + + for i in range(batch_size): + valid_mask_i = attention_mask[i] # [seq_len] + valid_indices = valid_mask_i.nonzero(as_tuple=False).squeeze(-1) # [num_valid] + + # Extract valid token positions: [num_valid, vocab_size] + logprobs_bf16_valid = logprobs_bf16[i, valid_indices] # [num_valid, vocab_size] + logprobs_fp8_valid = logprobs_fp8[i, valid_indices] # [num_valid, vocab_size] + logits_bf16_valid = logits_bf16[i, valid_indices] # [num_valid, vocab_size] + logits_fp8_valid = logits_fp8[i, valid_indices] # [num_valid, vocab_size] + + # For sequence-level: flatten valid tokens: [num_valid * vocab_size] + logprobs_bf16_flat = logprobs_bf16_valid.flatten() + logprobs_fp8_flat = logprobs_fp8_valid.flatten() + logits_bf16_flat = logits_bf16_valid.flatten() + logits_fp8_flat = logits_fp8_valid.flatten() + + # Compute sequence-level cosine similarity for this sample + cos_sim_logprobs_i = F.cosine_similarity( + logprobs_bf16_flat.unsqueeze(0), logprobs_fp8_flat.unsqueeze(0), dim=1 + ).item() + cos_sim_logits_i = F.cosine_similarity( + logits_bf16_flat.unsqueeze(0), logits_fp8_flat.unsqueeze(0), dim=1 + ).item() + + cos_sim_logprobs_seq_list.append(cos_sim_logprobs_i) + cos_sim_logits_seq_list.append(cos_sim_logits_i) + + # For token-level: collect individual valid tokens + logprobs_bf16_valid_list.append(logprobs_bf16_valid) # [num_valid, vocab_size] + logprobs_fp8_valid_list.append(logprobs_fp8_valid) # [num_valid, vocab_size] + logits_bf16_valid_list.append(logits_bf16_valid) # [num_valid, vocab_size] + logits_fp8_valid_list.append(logits_fp8_valid) # [num_valid, vocab_size] + + # Sequence-level statistics + cos_sim_logprobs_seq_mean = sum(cos_sim_logprobs_seq_list) / len( + cos_sim_logprobs_seq_list + ) + cos_sim_logits_seq_mean = sum(cos_sim_logits_seq_list) / len( + cos_sim_logits_seq_list + ) + + logger.info( + f"Seq cosine similarity of logprobs (valid tokens only): {cos_sim_logprobs_seq_mean}" + ) + logger.info( + f"Seq cosine similarity of logits (valid tokens only): {cos_sim_logits_seq_mean}" + ) + + # Stack token-level tensors: [num_valid_tokens, vocab_size] + logprobs_bf16_valid = torch.cat(logprobs_bf16_valid_list, dim=0) + logprobs_fp8_valid = torch.cat(logprobs_fp8_valid_list, dim=0) + logits_bf16_valid = torch.cat(logits_bf16_valid_list, dim=0) + logits_fp8_valid = torch.cat(logits_fp8_valid_list, dim=0) + + # Compute cosine similarity only for valid tokens + cos_sim_logprobs_valid = F.cosine_similarity( + logprobs_bf16_valid, logprobs_fp8_valid, dim=-1 + ) # [num_valid_tokens] + cos_sim_logits_valid = F.cosine_similarity( + logits_bf16_valid, logits_fp8_valid, dim=-1 + ) # [num_valid_tokens] + + cos_sim_logprobs_mean = 
cos_sim_logprobs_valid.mean().item() + cos_sim_logprobs_min = cos_sim_logprobs_valid.min().item() + cos_sim_logprobs_max = cos_sim_logprobs_valid.max().item() + + cos_sim_logits_mean = cos_sim_logits_valid.mean().item() + cos_sim_logits_min = cos_sim_logits_valid.min().item() + cos_sim_logits_max = cos_sim_logits_valid.max().item() + + logger.info( + f"Token cosine similarity of logprobs (valid tokens only): {cos_sim_logprobs_mean}" + ) + logger.info( + f"Token cosine similarity of logits (valid tokens only): {cos_sim_logits_mean}" + ) + logger.info( + f"Token cosine similarity of logprobs - min: {cos_sim_logprobs_min:.6f}, max: {cos_sim_logprobs_max:.6f}" + ) + logger.info( + f"Token cosine similarity of logits - min: {cos_sim_logits_min:.6f}, max: {cos_sim_logits_max:.6f}" + ) + + if cos_sim_logprobs_mean < 0.99: + raise AssertionError( + f"Token cosine similarity of logprobs is less than 0.99: {cos_sim_logprobs_mean}" + ) + if cos_sim_logits_mean < 0.99: + raise AssertionError( + f"Token cosine similarity of logits is less than 0.99: {cos_sim_logits_mean}" + ) + + +def test_fp8_bf16_gradient_comparison(fixed_input, engine_bf16, engine_fp8): + """Compare gradients between FP8 and BF16 models after train_batch. + + This test verifies that gradients computed from FP8 and BF16 models + are consistent across all layers. + """ + engine_bf16.train() + gradients_bf16 = collect_gradients_after_train_batch(engine_bf16, fixed_input) + logger.info(f"BF16 model: collected {len(gradients_bf16)} parameter gradients") + + engine_fp8.train() + gradients_fp8 = collect_gradients_after_train_batch(engine_fp8, fixed_input) + logger.info(f"FP8 model: collected {len(gradients_fp8)} parameter gradients") + + # Compare gradients + assert len(gradients_bf16) == len(gradients_fp8), ( + f"Number of parameters with gradients don't match: " + f"BF16={len(gradients_bf16)}, FP8={len(gradients_fp8)}" + ) + + # Get common parameter names + common_names = set(gradients_bf16.keys()) & set(gradients_fp8.keys()) + logger.info(f"Comparing {len(common_names)} common parameters") + + # Statistics for all layers + all_max_diffs = [] + all_mean_diffs = [] + all_cos_sims = [] + layer_stats = [] + + for name in sorted(common_names): + grad_bf16 = gradients_bf16[name] + grad_fp8 = gradients_fp8[name] + + # Check shapes match + assert grad_bf16.shape == grad_fp8.shape, ( + f"Gradient shapes don't match for {name}: " + f"BF16={grad_bf16.shape}, FP8={grad_fp8.shape}" + ) + + # Compute differences + grad_diff = (grad_bf16 - grad_fp8).abs() + max_diff = grad_diff.max().item() + mean_diff = grad_diff.mean().item() + + # Compute cosine similarity + grad_bf16_flat = grad_bf16.flatten() + grad_fp8_flat = grad_fp8.flatten() + cos_sim = F.cosine_similarity( + grad_bf16_flat.unsqueeze(0), grad_fp8_flat.unsqueeze(0), dim=1 + ).item() + + all_max_diffs.append(max_diff) + all_mean_diffs.append(mean_diff) + all_cos_sims.append(cos_sim) + + # Extract layer index from parameter name for grouping + layer_match = re.search(r"layers\.(\d+)", name) + layer_idx = int(layer_match.group(1)) if layer_match else -1 + + layer_stats.append( + { + "name": name, + "layer_idx": layer_idx, + "max_diff": max_diff, + "mean_diff": mean_diff, + "cos_sim": cos_sim, + "shape": grad_bf16.shape, + } + ) + + # Log statistics by layer + layer_indices = sorted( + set(s["layer_idx"] for s in layer_stats if s["layer_idx"] >= 0) + ) + for layer_idx in layer_indices: + layer_grads = [s for s in layer_stats if s["layer_idx"] == layer_idx] + layer_max_diffs = [s["max_diff"] for s 
in layer_grads]
+        layer_mean_diffs = [s["mean_diff"] for s in layer_grads]
+        layer_cos_sims = [s["cos_sim"] for s in layer_grads]
+
+        logger.info(
+            f"Layer {layer_idx}: "
+            f"max_diff={max(layer_max_diffs):.6f}, "
+            f"mean_diff={sum(layer_mean_diffs) / len(layer_mean_diffs):.6f}, "
+            f"cos_sim={sum(layer_cos_sims) / len(layer_cos_sims):.6f}, "
+            f"n_params={len(layer_grads)}, "
+            f"names={','.join([s['name'] for s in layer_grads])}"
+        )
+    # Log parameters that could not be mapped to a layer index (layer_idx < 0)
+    layer_stats_less_than_0 = [s for s in layer_stats if s["layer_idx"] < 0]
+    logger.info(f"Parameters without a layer index: {len(layer_stats_less_than_0)}")
+    for stat in layer_stats_less_than_0:
+        name_str = f"Param {stat['name']}"
+        logger.info(
+            f"{name_str:<70} "
+            f"max_diff={stat['max_diff']:>12.6f}, "
+            f"mean_diff={stat['mean_diff']:>12.6f}, "
+            f"cos_sim={stat['cos_sim']:>10.6f}"
+        )
+
+    # Overall statistics
+    overall_max_diff = max(all_max_diffs)
+    overall_mean_diff = sum(all_mean_diffs) / len(all_mean_diffs)
+    overall_cos_sim = sum(all_cos_sims) / len(all_cos_sims)
+    overall_min_cos_sim = min(all_cos_sims)
+
+    logger.info("=" * 80)
+    logger.info("Overall gradient comparison statistics:")
+    logger.info(f"  Max difference: {overall_max_diff:.6f}")
+    logger.info(f"  Mean difference: {overall_mean_diff:.6f}")
+    logger.info(f"  Mean cosine similarity: {overall_cos_sim:.6f}")
+    logger.info(f"  Min cosine similarity: {overall_min_cos_sim:.6f}")
+    logger.info("=" * 80)
+
+    # Log parameters with lowest cosine similarity
+    layer_stats_sorted = sorted(layer_stats, key=lambda x: x["cos_sim"], reverse=False)
+    logger.info("Top 10 parameters with lowest gradient cosine similarity:")
+    for i, stat in enumerate(layer_stats_sorted[:10]):
+        logger.info(
+            f"  {i + 1}. {stat['name']}: "
+            f"cos_sim={stat['cos_sim']:.6f}, "
+            f"max_diff={stat['max_diff']:.6f}, "
+            f"mean_diff={stat['mean_diff']:.6f}"
+        )
+
+    # Assertions - allow some tolerance for FP8 quantization
+    assert overall_cos_sim > 0.94, (
+        f"Overall cosine similarity too low: {overall_cos_sim:.6f}. "
+        f"This suggests gradients are not consistent between BF16 and FP8 models."
+    )
+    assert overall_min_cos_sim > 0.60, (
+        f"Minimum cosine similarity too low: {overall_min_cos_sim:.6f}. "
+        f"Some parameters have very different gradients."
+    )
+
+
+@pytest.mark.skip(reason="This test is only for debugging")
+def test_profile_gemm_kernels(fixed_input, engine_bf16, engine_fp8):
+    """Profile and print GEMM kernels used in forward and backward pass.
+
+    This test profiles the GEMM kernels (matrix multiplication operations) used
+    during forward and backward passes to understand which operators are being used.
+
+    """
+    logger.info("=" * 80)
+    logger.info("Profiling GEMM kernels - BF16 Model")
+    logger.info("=" * 80)
+
+    # Profile forward pass
+    logger.info("\n>>> Profiling FORWARD pass...")
+    logits_bf16, logprobs_bf16 = forward_with_logits_and_logprobs(
+        engine_bf16, fixed_input, profile_gemm=True
+    )
+
+    # Profile backward pass
+    logger.info("\n>>> Profiling BACKWARD pass...")
+    engine_bf16.train()
+    gradients_bf16 = collect_gradients_after_train_batch(
+        engine_bf16, fixed_input, profile_gemm=True
+    )
+    logger.info(f"Collected {len(gradients_bf16)} parameter gradients")
+
+    logger.info("\n" + "=" * 80)
+    logger.info("Profiling GEMM kernels - FP8 Model")
+    logger.info("=" * 80)
+
+    # Profile forward pass
+    logger.info("\n>>> Profiling FORWARD pass...")
+    logits_fp8, logprobs_fp8 = forward_with_logits_and_logprobs(
+        engine_fp8, fixed_input, profile_gemm=True
+    )
+
+    # Profile backward pass
+    logger.info("\n>>> Profiling BACKWARD pass...")
+    engine_fp8.train()
+    gradients_fp8 = collect_gradients_after_train_batch(
+        engine_fp8, fixed_input, profile_gemm=True
+    )
+    logger.info(f"Collected {len(gradients_fp8)} parameter gradients")
+
+
+def test_fp8_bf16_partial_layers_comparison(
+    fixed_input, engine_bf16, engine_fp8, save_data: bool = False
+):
+    """Compare FP8 and BF16 on a model reduced to specified layers.
+
+    This test reduces the model to specified transformer layers while keeping the full
+    structure (embedding, rotary_pos_emb, final_layernorm, output_layer), performs
+    forward and backward with real loss computation, and compares activations and
+    gradients between FP8 and BF16 models to identify which operations have precision issues.
+    """
+    # Test specific layers - can be a single layer index or a list of indices
+    layer_indices = list(
+        range(2)
+    )  # Test the first two layers; use [0] to test only the first layer
+
+    logger.info("=" * 80)
+    logger.info(f"Testing model with layers {layer_indices} - BF16 Model")
+    logger.info("=" * 80)
+
+    # Forward and backward on model with specified layers
+    logits_bf16, activations_bf16, gradients_bf16, output_gradients_bf16 = (
+        forward_backward_model_with_hooks(
+            engine_bf16,
+            fixed_input,
+            layer_indices=layer_indices,
+        )
+    )
+
+    logger.info(f"BF16 - Logits shape: {logits_bf16.shape}")
+    logger.info(f"BF16 - Collected {len(activations_bf16)} activations")
+    logger.info(f"BF16 - Collected {len(gradients_bf16)} gradients")
+    logger.info(f"BF16 - Collected {len(output_gradients_bf16)} output gradients")
+
+    logger.info("\n" + "=" * 80)
+    logger.info(f"Testing model with layers {layer_indices} - FP8 Model")
+    logger.info("=" * 80)
+
+    # Forward and backward on model with specified layers
+    logits_fp8, activations_fp8, gradients_fp8, output_gradients_fp8 = (
+        forward_backward_model_with_hooks(
+            engine_fp8,
+            fixed_input,
+            layer_indices=layer_indices,
+        )
+    )
+
+    logger.info(f"FP8 - Logits shape: {logits_fp8.shape}")
+    logger.info(f"FP8 - Collected {len(activations_fp8)} activations")
+    logger.info(f"FP8 - Collected {len(gradients_fp8)} gradients")
+    logger.info(f"FP8 - Collected {len(output_gradients_fp8)} output gradients")
+
+    # Compare logits
+    compare_logits(logits_bf16, logits_fp8)
+
+    # Compare activations by op type
+    activation_comparison = compare_tensors_dict(
+        activations_bf16,
+        activations_fp8,
+        title="Activation Comparison",
+        check_nan_inf=False,
+        check_zero_norm=False,
+        group_by_op_type=True,
+        name_width=50,
+    )
+
+    # Compare gradients by op type
+    gradient_comparison = compare_tensors_dict(
+        gradients_bf16,
+        
gradients_fp8, + title="Gradient Comparison", + check_nan_inf=True, + check_zero_norm=True, + group_by_op_type=True, + name_width=80, + ) + + # Compare output gradients by op type + output_gradient_comparison = compare_tensors_dict( + output_gradients_bf16, + output_gradients_fp8, + title="Output Gradient Comparison", + check_nan_inf=False, + check_zero_norm=False, + group_by_op_type=True, + name_width=80, + ) + + # Log problematic operations + log_problematic_operations( + activation_comparison["stats_by_type"], + threshold=0.95, + title="Problematic Activations", + ) + log_problematic_operations( + gradient_comparison["stats_by_type"], + threshold=0.95, + title="Problematic Gradients", + ) + log_problematic_operations( + output_gradient_comparison["stats_by_type"], + threshold=0.95, + title="Problematic Output Gradients", + ) + + if save_data: + # Save q_layernorm and k_layernorm inputs and output gradients for separate testing + layernorm_inputs_bf16 = {} + layernorm_inputs_fp8 = {} + layernorm_output_grads_bf16 = {} + layernorm_output_grads_fp8 = {} + for name in activations_bf16.keys(): + if name.endswith(".q_layernorm.input") or name.endswith( + ".k_layernorm.input" + ): + layernorm_inputs_bf16[name] = activations_bf16[name] + for name in activations_fp8.keys(): + if name.endswith(".q_layernorm.input") or name.endswith( + ".k_layernorm.input" + ): + layernorm_inputs_fp8[name] = activations_fp8[name] + for name in output_gradients_bf16.keys(): + if name.endswith(".q_layernorm.output_grad") or name.endswith( + ".k_layernorm.output_grad" + ): + layernorm_output_grads_bf16[name] = output_gradients_bf16[name] + for name in output_gradients_fp8.keys(): + if name.endswith(".q_layernorm.output_grad") or name.endswith( + ".k_layernorm.output_grad" + ): + layernorm_output_grads_fp8[name] = output_gradients_fp8[name] + + if layernorm_inputs_bf16 or layernorm_inputs_fp8: + logger.info("\n" + "=" * 80) + logger.info("Found layernorm inputs for separate testing") + logger.info(f"BF16 layernorm inputs: {list(layernorm_inputs_bf16.keys())}") + logger.info(f"FP8 layernorm inputs: {list(layernorm_inputs_fp8.keys())}") + logger.info( + f"BF16 layernorm output grads: {list(layernorm_output_grads_bf16.keys())}" + ) + logger.info( + f"FP8 layernorm output grads: {list(layernorm_output_grads_fp8.keys())}" + ) + logger.info("=" * 80) + + # Save activation inputs to files + save_dir = Path("activation_inputs") + save_dir.mkdir(exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Also save a combined file with metadata + combined_data = { + "bf16_inputs": layernorm_inputs_bf16, + "fp8_inputs": layernorm_inputs_fp8, + "bf16_output_grads": layernorm_output_grads_bf16, + "fp8_output_grads": layernorm_output_grads_fp8, + "timestamp": timestamp, + "layer_indices": layer_indices, + } + combined_save_path = save_dir / f"layernorm_inputs_combined_{timestamp}.pt" + torch.save(combined_data, combined_save_path) + logger.info(f"Saved combined layernorm inputs to: {combined_save_path}") + logger.info( + f" Total size: {combined_save_path.stat().st_size / 1024 / 1024:.2f} MB" + ) + + logger.info("=" * 80) diff --git a/areal/tests/test_fp8_conversion.py b/areal/tests/test_fp8_conversion.py new file mode 100644 index 000000000..9537e2be1 --- /dev/null +++ b/areal/tests/test_fp8_conversion.py @@ -0,0 +1,298 @@ +"""Test FP8 conversion and matrix multiplication correctness. + +This test verifies: +1. BF16 matrix multiplication baseline +2. 
BF16 -> TE Blockwise FP8 -> FP8 GEMM -> BF16 comparison +""" + +import pytest +import torch + +from areal.platforms import current_platform +from areal.utils import logging + +logger = logging.getLogger("Test FP8 Conversion") + +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common import recipe + from transformer_engine.pytorch.cpp_extensions import general_gemm + from transformer_engine.pytorch.tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, + ) +except ImportError as e: + logger.warning( + f"transformer_engine not available: {e}. " + "Skipping all FP8 conversion tests. " + "To run FP8 tests, please install transformer_engine.", + ) + pytestmark = pytest.mark.skip( + reason="transformer_engine is required for FP8 tests. " + "Please install transformer_engine to run these tests." + ) + # Set dummy values to avoid NameError + te = None + tex = None + recipe = None + general_gemm = None + Float8BlockQuantizer = None + Float8BlockwiseQTensor = None + + +def high_precision_to_te_blockwise_fp8( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + *, + rowwise: bool = True, + columnwise: bool = False, + block_scaling_dim: int = 2, + amax_epsilon: float = 0.0, + force_pow_2_scales: bool = True, +) -> Float8BlockwiseQTensor: + """ + Quantize high precision tensor to TE Blockwise FP8 tensor. + + Args: + tensor: High precision tensor (float32, float16, bfloat16, etc.) + fp8_dtype: TE FP8 format + rowwise: Whether to use rowwise data layout + columnwise: Whether to use columnwise data layout + block_scaling_dim: Block scaling dimension (1 or 2) + amax_epsilon: Epsilon for amax computation + force_pow_2_scales: Whether to force power-of-2 scales + + Returns: + Float8BlockwiseQTensor: TE Blockwise FP8 tensor + """ + # Create Float8BlockQuantizer + # Note: Always set both rowwise and columnwise to True to allow GEMM to choose the best layout + # This matches the test pattern in TransformerEngine tests + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, # Always enable rowwise + columnwise=True, # Always enable columnwise for flexibility + amax_epsilon=amax_epsilon, + force_pow_2_scales=force_pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + + # Check if tensor can be quantized (needs to satisfy block size requirements) + if not quantizer.is_quantizable(tensor): + raise ValueError( + f"Tensor shape {tensor.shape} cannot be quantized with block size {quantizer.block_len}. " + f"Both dimensions must be multiples of {quantizer.block_len}." + ) + + # Quantize tensor + te_blockwise_fp8_tensor = quantizer(tensor) + + return te_blockwise_fp8_tensor + + +def _log_tensor_comparison( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + name: str, + max_threshold: float | None = None, + mean_threshold: float | None = None, +) -> tuple[float, float]: + """Compare two tensors and log the differences. 
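+    Optionally asserts that the max/mean difference stays below the given thresholds.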
+ + Args: + tensor1: First tensor + tensor2: Second tensor + name: Name for logging + max_threshold: Optional threshold for max difference + mean_threshold: Optional threshold for mean difference + + Returns: + Tuple of (max_diff, mean_diff) + """ + max_diff = (tensor1 - tensor2).abs().max().item() + mean_diff = (tensor1 - tensor2).abs().mean().item() + logger.info(f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") + + if max_threshold is not None: + assert max_diff < max_threshold, f"{name} max difference too large: {max_diff}" + + if mean_threshold is not None: + assert mean_diff < mean_threshold, ( + f"{name} mean difference too large: {mean_diff}" + ) + + return max_diff, mean_diff + + +def _create_test_tensors( + device: torch.device, M: int = 256, K: int = 512, N: int = 128 +): + """Create test tensors for matrix multiplication. + + Args: + device: Device to create tensors on + M: First dimension of A + K: Shared dimension + N: Second dimension of B + + Returns: + Tuple of (a_bf16, b_bf16) where A is [M, K] and B is [K, N] + """ + torch.manual_seed(42) + a_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + b_bf16 = torch.randn(K, N, device=device, dtype=torch.bfloat16) + return a_bf16, b_bf16 + + +def _perform_fp8_gemm( + weight_fp8: Float8BlockwiseQTensor, + input_fp8: Float8BlockwiseQTensor, + output_shape: tuple[int, ...], + device: torch.device, + layout: str = "TN", +) -> torch.Tensor: + """Perform FP8 GEMM using general_gemm. + + Args: + weight_fp8: Weight tensor in FP8 format [N, K] or [K, N] + input_fp8: Input tensor in FP8 format [M, K] or [K, M] + output_shape: Output shape [M, N] + device: Device to perform computation on + layout: GEMM layout ("TN", "NN", etc.) + + Returns: + Result tensor in BF16 format [M, N] + """ + result = torch.empty(output_shape, device=device, dtype=torch.bfloat16) + workspace = torch.empty(32 * 1024 * 1024 + 1024, dtype=torch.uint8, device=device) + + result, *_ = general_gemm( + weight_fp8, + input_fp8, + workspace, + out_dtype=torch.bfloat16, + layout=layout, + out=result, + use_split_accumulator=False, + ) + + return result + + +@pytest.fixture +def device(): + return torch.device(current_platform.device_type) + + +@pytest.fixture +def test_tensors(device): + """Fixture for test tensors.""" + return _create_test_tensors(device) + + +def test_te_fp8_gemm_vs_bf16(test_tensors, device): + """Test BF16 -> TE Blockwise FP8 -> FP8 GEMM -> BF16 comparison.""" + a_bf16, b_bf16 = test_tensors + M, _, N = a_bf16.shape[0], a_bf16.shape[1], b_bf16.shape[1] + + # BF16 baseline + result_bf16 = torch.matmul(a_bf16, b_bf16) + + # Convert A to TE Blockwise FP8 with 1D scaling (input pattern) + a_te_fp8 = high_precision_to_te_blockwise_fp8( + a_bf16, + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + block_scaling_dim=1, # 1D scaling for input + ) + + # Transpose B to match Linear layer weight format [N, K] + b_bf16_t = b_bf16.t().contiguous() + b_te_fp8 = high_precision_to_te_blockwise_fp8( + b_bf16_t, + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + block_scaling_dim=2, # 2D scaling for weight + ) + + # Perform FP8 GEMM + result_fp8 = _perform_fp8_gemm(b_te_fp8, a_te_fp8, (M, N), device, layout="TN") + + # Compare with baseline (allowing for quantization error) + _log_tensor_comparison( + result_bf16, + result_fp8, + "TE FP8 GEMM vs BF16 baseline", + max_threshold=10.0, + mean_threshold=1.0, + ) + + +def test_te_linear_autocast_vs_bf16(test_tensors): + """Test TransformerEngine Linear with autocast FP8 vs BF16.""" + a_bf16, 
b_bf16 = test_tensors + K, N = a_bf16.shape[1], b_bf16.shape[1] + + # Transpose B to match Linear layer weight format [N, K] + b_bf16_t = b_bf16.t().contiguous() + + # Create Linear layer + my_linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16) + my_linear.weight.data.copy_(b_bf16_t) + + # BF16 forward + out_bf16 = my_linear(a_bf16) + + # FP8 autocast forward + fp8_recipe = recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3) + with te.autocast(enabled=True, recipe=fp8_recipe): + auto_out_bf16 = my_linear(a_bf16) + + # Compare autocast FP8 vs BF16 + _log_tensor_comparison( + out_bf16, + auto_out_bf16, + "TE Linear autocast FP8 vs BF16", + ) + + +def test_te_linear_autocast_vs_gemm(test_tensors, device): + """Test TransformerEngine Linear autocast FP8 vs manual FP8 GEMM.""" + a_bf16, b_bf16 = test_tensors + K, N = a_bf16.shape[1], b_bf16.shape[1] + M = a_bf16.shape[0] + + # Transpose B to match Linear layer weight format [N, K] + b_bf16_t = b_bf16.t().contiguous() + + # Create Linear layer + my_linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16) + my_linear.weight.data.copy_(b_bf16_t) + + # FP8 autocast forward + fp8_recipe = recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3) + with te.autocast(enabled=True, recipe=fp8_recipe): + auto_out_bf16 = my_linear(a_bf16) + + # Manual FP8 GEMM + a_te_fp8 = high_precision_to_te_blockwise_fp8( + a_bf16, + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + block_scaling_dim=1, + ) + b_te_fp8 = high_precision_to_te_blockwise_fp8( + b_bf16_t, + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + block_scaling_dim=2, + ) + result_gemm = _perform_fp8_gemm(b_te_fp8, a_te_fp8, (M, N), device, layout="TN") + + # Compare + _log_tensor_comparison( + auto_out_bf16, + result_gemm, + "TE Linear autocast vs manual GEMM", + ) diff --git a/areal/tests/test_fp8_rmsnorm.py b/areal/tests/test_fp8_rmsnorm.py new file mode 100644 index 000000000..d0fd8008e --- /dev/null +++ b/areal/tests/test_fp8_rmsnorm.py @@ -0,0 +1,797 @@ +"""RMSNorm testing utilities for FP8/BF16 comparison tests. + +This module contains RMSNorm-related classes, functions, and tests. 
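+
+The reference RMSNorm exercised here computes, in float32,
+
+    y = weight * x * rsqrt(mean(x ** 2, dim=-1) + eps)
+
+which is what Qwen3RMSNormFunction below implements together with a hand-written
+backward pass. A minimal usage sketch of the standalone module (shapes and dtypes
+are illustrative only):
+
+    norm = Qwen3RMSNorm(128, eps=1e-6).to(torch.bfloat16)
+    x = torch.randn(2, 8, 128, dtype=torch.bfloat16, requires_grad=True)
+    y = norm(x)
+    y.sum().backward()  # exercises Qwen3RMSNormFunction.backward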
+""" + +import re +from datetime import datetime +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.distributed as dist +import torch.nn.functional as F +from megatron.core.fp8_utils import get_fp8_context, is_float8tensor +from megatron.core.utils import get_model_config +from torch import nn +from torch.autograd import Function +from transformers import PretrainedConfig + +from areal.engine.megatron_engine import MegatronEngine +from areal.tests.fp8.engine_utils import create_engine +from areal.tests.fp8.model_hooks import get_model_from_engine +from areal.tests.utils import get_model_path +from areal.utils import logging + +logger = logging.getLogger("FP8 BF16 RMSNorm Test") + +MODEL_PATH_BF16 = get_model_path( + "/storage/openpsi/models/Qwen__Qwen3-0.6B/", "Qwen/Qwen3-0.6B" +) +MODEL_PATH_FP8 = get_model_path( + "/storage/openpsi/models/Qwen__Qwen3-0.6B-FP8/", "Qwen/Qwen3-0.6B-FP8" +) + + +@pytest.fixture(scope="module") +def engine_bf16(): + engine = create_engine( + MODEL_PATH_BF16, fp8_enabled=False, fp8_param=False, port=7777 + ) + try: + yield engine + finally: + engine.destroy() + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.fixture(scope="module") +def engine_fp8(): + engine = create_engine(MODEL_PATH_FP8, fp8_enabled=True, fp8_param=True, port=7778) + try: + yield engine + finally: + engine.destroy() + if dist.is_initialized(): + dist.destroy_process_group() + + +def dequantize_fp8_param(tensor: torch.Tensor) -> torch.Tensor: + """Dequantize FP8 tensor to bfloat16.""" + if is_float8tensor(tensor): + return tensor.dequantize(dtype=torch.bfloat16) + else: + logger.info("Not a quantized tensor, converting to bfloat16") + return tensor.to(torch.bfloat16) + + +class Qwen3RMSNormFunction(Function): + """Custom autograd Function for Qwen3RMSNorm backward.""" + + @staticmethod + def forward(ctx, hidden_states, weight, variance_epsilon): + """ + Forward pass for RMSNorm. + + Args: + hidden_states: Input tensor of shape [..., hidden_size] + weight: Weight parameter of shape [hidden_size] + variance_epsilon: Epsilon value for numerical stability + + Returns: + Normalized and weighted output tensor + """ + input_dtype = hidden_states.dtype + hidden_states_fp32 = hidden_states.to(torch.float32) + + # Compute variance: mean(x^2) along last dimension + variance = hidden_states_fp32.pow(2).mean(-1, keepdim=True) + + # Compute normalized: x / sqrt(variance + eps) + inv_std = torch.rsqrt(variance + variance_epsilon) + normalized = hidden_states_fp32 * inv_std + + # Apply weight and convert back to input dtype + output = (weight * normalized).to(input_dtype) + + # Save tensors for backward + ctx.save_for_backward(hidden_states_fp32, weight, inv_std, normalized) + ctx.variance_epsilon = variance_epsilon + ctx.input_dtype = input_dtype + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for RMSNorm. + + Args: + grad_output: Gradient w.r.t. output, shape [..., hidden_size] + + Returns: + grad_input: Gradient w.r.t. input + grad_weight: Gradient w.r.t. weight + grad_eps: None (variance_epsilon is not a tensor) + """ + hidden_states, weight, inv_std, normalized = ctx.saved_tensors + input_dtype = ctx.input_dtype + + # Convert grad_output to float32 for computation + grad_output_fp32 = grad_output.to(torch.float32) + + # Gradient w.r.t. 
weight: sum over all dimensions except last + grad_weight = (grad_output_fp32 * normalized).sum( + dim=tuple(range(grad_output_fp32.dim() - 1)) + ) + + # Gradient w.r.t. normalized: weight * grad_output + grad_normalized = grad_output_fp32 * weight.unsqueeze(0) + + # Gradient w.r.t. variance + inv_std_pow3 = inv_std.pow(3) + grad_variance = (grad_normalized * hidden_states * -0.5 * inv_std_pow3).sum( + -1, keepdim=True + ) + + # Gradient w.r.t. hidden_states + hidden_size = hidden_states.shape[-1] + grad_input_from_variance = grad_variance * 2.0 * hidden_states / hidden_size + + # d(normalized)/d(hidden_states) = inv_std (direct contribution) + grad_input_from_normalized = grad_normalized * inv_std + + # Total gradient w.r.t. input + grad_input = grad_input_from_normalized + grad_input_from_variance + + # Convert back to input dtype + grad_input = grad_input.to(input_dtype) + grad_weight = grad_weight.to(input_dtype) + + return grad_input, grad_weight, None + + +class Qwen3RMSNorm(nn.Module): + """Qwen3RMSNorm is equivalent to T5LayerNorm.""" + + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return Qwen3RMSNormFunction.apply( + hidden_states, self.weight, self.variance_epsilon + ) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def forward_backward_rmsnorm_module( + layernorm_module: torch.nn.Module, + input_activation: torch.Tensor, + dtype: torch.dtype = torch.bfloat16, + name: str = "rmsnorm", + collect_gradients: bool = True, + output_grad: torch.Tensor | None = None, +) -> dict[str, Any]: + """Forward and backward a single RMSNorm module with given input activation. + + This function tests a RMSNorm module in isolation by: + 1. Setting the module to train mode (for gradients) + 2. Converting input to the specified dtype + 3. Running forward pass + 4. Running backward pass with a dummy loss + 5. 
Collecting output statistics and gradients + + Args: + layernorm_module: The RMSNorm module to test + input_activation: Input activation tensor + dtype: Data type to use (torch.bfloat16 or torch.float16) + name: Name identifier for logging + collect_gradients: Whether to collect gradients (requires backward pass) + output_grad: Optional gradient from downstream layers for backward pass + + Returns: + Dictionary with output tensor, statistics, and gradients + """ + layernorm_module.train() # Set to train mode for gradients + + # Convert input to specified dtype and ensure it requires grad + input_activation = input_activation.to(dtype=dtype) + if collect_gradients: + input_activation = input_activation.clone().detach().requires_grad_(True) + + # Forward pass + output = layernorm_module(input_activation) + + # Calculate statistics + output_norm = output.norm().item() + output_max = output.abs().max().item() + output_mean = output.mean().item() + output_std = output.std().item() + + gradients = {} + if collect_gradients: + # Zero gradients first + layernorm_module.zero_grad() + if input_activation.grad is not None: + input_activation.grad.zero_() + + # Use provided output gradient if available, otherwise use dummy loss + if output_grad is not None: + # Use the real gradient from downstream layers + output_grad = output_grad.to(dtype=dtype, device=output.device) + output.backward(output_grad) + else: + # Create a dummy loss (sum of output) + loss = output.sum() + # Backward pass + loss.backward() + + # Collect gradients from module parameters + for param_name, param in layernorm_module.named_parameters(): + if param.requires_grad: + grad = None + # Check different gradient storage locations + if hasattr(param, "main_grad") and param.main_grad is not None: + grad = param.main_grad.clone().detach() + elif hasattr(param, "grad") and param.grad is not None: + grad = param.grad.clone().detach() + else: + raise ValueError(f"No gradient found for {param_name}") + if grad is not None: + gradients[param_name + "_grad"] = grad + logger.debug( + f"{name} gradient {param_name}: " + f"shape={grad.shape}, norm={grad.norm().item():.6f}, " + f"min={grad.min().item():.6f}, max={grad.max().item():.6f}" + ) + + gradients["input"] = input_activation.clone().detach() + gradients["output"] = output.clone().detach() + + if output_grad is not None: + gradients["output_grad"] = output_grad.clone().detach() + + logger.info( + f"{name} ({dtype}): " + f"input_shape={input_activation.shape}, output_shape={output.shape}, " + f"output_norm={output_norm:.6f}, output_max={output_max:.6f}, " + f"output_mean={output_mean:.6f}, output_std={output_std:.6f}, " + f"n_gradients={len(gradients)}" + ) + + return { + "output": output, + "output_norm": output_norm, + "output_max": output_max, + "output_mean": output_mean, + "output_std": output_std, + "input_shape": input_activation.shape, + "output_shape": output.shape, + "gradients": gradients, + } + + +def load_layernorm_inputs_from_file(file_path: str | Path) -> dict[str, Any]: + """Load layernorm activation inputs from saved file. 
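+    Accepts either the combined file written by test_fp8_bf16_partial_layers_comparison
+    (with save_data=True) or an individual BF16/FP8 dump, detected from the filename.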
+
+    Args:
+        file_path: Path to the saved .pt file (can be combined file or individual file)
+
+    Returns:
+        Dictionary with 'bf16_inputs', 'fp8_inputs', 'timestamp', 'layer_indices'
+    """
+    file_path = Path(file_path)
+    if not file_path.exists():
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    data = torch.load(file_path, map_location="cpu")
+
+    # Check if it's a combined file or individual file
+    if isinstance(data, dict) and "bf16_inputs" in data and "fp8_inputs" in data:
+        # Combined file
+        return data
+    elif isinstance(data, dict):
+        # Individual file - determine if BF16 or FP8 based on the filename
+        if "bf16" in file_path.name.lower():
+            return {
+                "bf16_inputs": data,
+                "fp8_inputs": {},
+                "timestamp": file_path.stem.split("_")[-1]
+                if "_" in file_path.stem
+                else "",
+                "layer_indices": [],
+            }
+        elif "fp8" in file_path.name.lower():
+            return {
+                "bf16_inputs": {},
+                "fp8_inputs": data,
+                "timestamp": file_path.stem.split("_")[-1]
+                if "_" in file_path.stem
+                else "",
+                "layer_indices": [],
+            }
+        else:
+            # Assume BF16 if the type cannot be determined
+            return {
+                "bf16_inputs": data,
+                "fp8_inputs": {},
+                "timestamp": file_path.stem.split("_")[-1]
+                if "_" in file_path.stem
+                else "",
+                "layer_indices": [],
+            }
+    else:
+        raise ValueError(f"Unexpected file format in {file_path}")
+
+
+def get_custom_rmsnorm(
+    layernorm_module: torch.nn.Module,
+    hf_config: PretrainedConfig,
+    device: torch.device,
+    dtype: torch.dtype = torch.bfloat16,
+    weight: torch.Tensor | None = None,
+) -> torch.nn.Module:
+    """Create a custom RMSNorm module with dequantized FP8 params."""
+    # Extract weight parameter
+    if hasattr(layernorm_module, "weight"):
+        weight_param = layernorm_module.weight
+    else:
+        # Try to find weight in named_parameters
+        weight_param = None
+        for name, param in layernorm_module.named_parameters():
+            if "weight" in name.lower():
+                weight_param = param
+                break
+
+    if weight_param is None:
+        raise ValueError(f"Cannot find weight parameter in {layernorm_module}")
+
+    # Dequantize if FP8, or convert to bfloat16
+    dequantized_weight_data = dequantize_fp8_param(weight_param.data)
+
+    # Get hidden size (head_dim for Q/K layernorm) and epsilon from the HF config
+    hidden_size = hf_config.head_dim
+    eps = hf_config.rms_norm_eps
+
+    # Create custom RMSNorm module
+    custom_rmsnorm = Qwen3RMSNorm(hidden_size, eps=eps)
+    if weight is not None:
+        custom_rmsnorm.weight.data = (
+            weight.clone().detach().to(device=device, dtype=dtype)
+        )
+    else:
+        custom_rmsnorm.weight.data = dequantized_weight_data.clone().detach()
+    custom_rmsnorm = custom_rmsnorm.to(device=device, dtype=dtype)
+
+    logger.info(
+        f"Using custom Qwen3RMSNorm to replace {layernorm_module} with dtype {dtype}"
+    )
+
+    return custom_rmsnorm
+
+
+def compare_rmsnorm_bf16_fp8(
+    engine_bf16: MegatronEngine,
+    engine_fp8: MegatronEngine,
+    q_layernorm_input_bf16: torch.Tensor,
+    q_layernorm_input_fp8: torch.Tensor,
+    layer_path: str,
+    output_grad_bf16: torch.Tensor | None = None,
+    output_grad_fp8: torch.Tensor | None = None,
+    use_custom_rmsnorm: bool = False,
+    save_data: bool = False,
+) -> dict[str, Any]:
+    """Compare RMSNorm module outputs between BF16 and FP8 engines.
+
+    This function extracts the q_layernorm module from both engines and compares
+    their outputs when given the respective input activations.
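+    When use_custom_rmsnorm is True, both modules are swapped for the pure-PyTorch
+    Qwen3RMSNorm defined above (FP8 weights are dequantized to bfloat16 first), so any
+    remaining difference comes from the inputs and weights rather than the kernel.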
+ + Args: + engine_bf16: BF16 MegatronEngine + engine_fp8: FP8 MegatronEngine + q_layernorm_input_bf16: Input activation from BF16 model + q_layernorm_input_fp8: Input activation from FP8 model + layer_path: Path to identify the layer (e.g., "layer_0.self_attention.q_layernorm") + output_grad_bf16: Optional output gradient for BF16 + output_grad_fp8: Optional output gradient for FP8 + use_custom_rmsnorm: Whether to use custom RMSNorm + save_data: Whether to save data to file + + Returns: + Dictionary with comparison results + """ + logger.info("=" * 80) + logger.info(f"Testing RMSNorm module: {layer_path}") + logger.info("=" * 80) + + # Extract q_layernorm module from both engines + model_bf16 = get_model_from_engine(engine_bf16) + model_fp8 = get_model_from_engine(engine_fp8) + + # Parse layer path + matches = re.match( + r"layer_(\d+)\.self_attention\.(q_layernorm|k_layernorm)", layer_path + ) + if not matches: + raise ValueError( + f"Invalid layer path: {layer_path}. Expected format: layer_X.self_attention.(q_layernorm|k_layernorm)" + ) + layer_idx = int(matches.group(1)) + layernorm_type = matches.group(2) + + fp8_context = get_fp8_context(get_model_config(model_fp8), layer_no=layer_idx) + + # Get decoder and layer + decoder_bf16 = model_bf16.decoder if hasattr(model_bf16, "decoder") else None + decoder_fp8 = model_fp8.decoder if hasattr(model_fp8, "decoder") else None + + if decoder_bf16 is None or decoder_fp8 is None: + raise ValueError("Cannot find decoder in model") + + if layer_idx >= len(decoder_bf16.layers) or layer_idx >= len(decoder_fp8.layers): + raise ValueError(f"Layer index {layer_idx} out of range") + + layer_bf16 = decoder_bf16.layers[layer_idx] + layer_fp8 = decoder_fp8.layers[layer_idx] + + if not hasattr(layer_bf16.self_attention, layernorm_type) or not hasattr( + layer_fp8.self_attention, layernorm_type + ): + raise ValueError(f"Layer {layer_idx} does not have {layernorm_type}") + + layernorm_bf16 = getattr(layer_bf16.self_attention, layernorm_type) + layernorm_fp8 = getattr(layer_fp8.self_attention, layernorm_type) + + # Test BF16 + logger.info("Testing BF16 RMSNorm...") + if use_custom_rmsnorm: + layernorm_bf16 = get_custom_rmsnorm( + layernorm_bf16, engine_bf16.hf_config, engine_bf16.device, torch.bfloat16 + ) + result_bf16 = forward_backward_rmsnorm_module( + layernorm_bf16, + q_layernorm_input_bf16, + output_grad=output_grad_bf16, + dtype=torch.bfloat16, + name=f"{layer_path} (BF16)", + collect_gradients=True, + ) + + # Test FP8 + logger.info("Testing FP8 RMSNorm...") + if use_custom_rmsnorm: + # For custom RMSNorm, we dequantize params first, so no need for FP8 context + layernorm_fp8 = get_custom_rmsnorm( + layernorm_fp8, engine_fp8.hf_config, engine_fp8.device, torch.bfloat16 + ) + result_fp8 = forward_backward_rmsnorm_module( + layernorm_fp8, + q_layernorm_input_fp8, + output_grad=output_grad_fp8, + dtype=torch.bfloat16, # Will use dequantized params + name=f"{layer_path} (FP8, dequantized)", + collect_gradients=True, + ) + else: + # Use original FP8 module with FP8 context + with fp8_context: + result_fp8 = forward_backward_rmsnorm_module( + layernorm_fp8, + q_layernorm_input_fp8, + output_grad=output_grad_fp8, + dtype=torch.bfloat16, # Input will be converted, but module may use FP8 internally + name=f"{layer_path} (FP8)", + collect_gradients=True, + ) + + if save_data: + # save input, weight, output_grad for both BF16 and FP8 + save_dir = Path("layernorm_inputs") + save_dir.mkdir(exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + 
save_path = save_dir / f"layernorm_inputs_{layer_path}_{timestamp}.pt" + torch.save( + { + "bf16": { + "input": q_layernorm_input_bf16, + "weight": layernorm_bf16.weight.data.clone().detach(), + "output_grad": output_grad_bf16.clone().detach() + if output_grad_bf16 is not None + else None, + }, + "fp8": { + "input": q_layernorm_input_fp8, + "weight": layernorm_fp8.weight.data.clone().detach(), + "output_grad": output_grad_fp8.clone().detach() + if output_grad_fp8 is not None + else None, + }, + }, + save_path, + ) + logger.info(f"Saved layernorm inputs to: {save_path}") + + # Compare outputs + output_bf16 = result_bf16["output"] + output_fp8 = result_fp8["output"] + + if output_bf16.shape != output_fp8.shape: + logger.warning( + f"Output shapes don't match: BF16={output_bf16.shape}, FP8={output_fp8.shape}" + ) + return { + "layer_path": layer_path, + "shape_mismatch": True, + "bf16_shape": output_bf16.shape, + "fp8_shape": output_fp8.shape, + } + + # Calculate differences + output_diff = (output_bf16 - output_fp8).abs() + max_diff = output_diff.max().item() + mean_diff = output_diff.mean().item() + + # Cosine similarity + output_bf16_flat = output_bf16.flatten() + output_fp8_flat = output_fp8.flatten() + cos_sim = F.cosine_similarity( + output_bf16_flat.unsqueeze(0), output_fp8_flat.unsqueeze(0), dim=1 + ).item() + + logger.info("=" * 80) + logger.info(f"RMSNorm Comparison Results for {layer_path}") + logger.info("=" * 80) + logger.info( + f"Output - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, cos_sim={cos_sim:.6f}" + ) + logger.info( + f"BF16 output_norm={result_bf16['output_norm']:.6f}, FP8 output_norm={result_fp8['output_norm']:.6f}" + ) + logger.info( + f"BF16 output_max={result_bf16['output_max']:.6f}, FP8 output_max={result_fp8['output_max']:.6f}" + ) + + # Compare gradients + gradients_bf16 = result_bf16.get("gradients", {}) + gradients_fp8 = result_fp8.get("gradients", {}) + + gradient_comparison = {} + common_gradient_names = set(gradients_bf16.keys()) & set(gradients_fp8.keys()) + + if common_gradient_names: + logger.info("\n" + "-" * 80) + logger.info("Gradient Comparison") + logger.info("-" * 80) + + for grad_name in sorted(common_gradient_names): + grad_bf16 = gradients_bf16[grad_name] + grad_fp8 = gradients_fp8[grad_name] + + if grad_bf16.shape != grad_fp8.shape: + logger.warning( + f"Gradient {grad_name} shapes don't match: " + f"BF16={grad_bf16.shape}, FP8={grad_fp8.shape}" + ) + continue + + # Calculate differences + grad_diff = (grad_bf16 - grad_fp8).abs() + grad_max_diff = grad_diff.max().item() + grad_mean_diff = grad_diff.mean().item() + + # Cosine similarity + grad_bf16_flat = grad_bf16.flatten() + grad_fp8_flat = grad_fp8.flatten() + grad_cos_sim = F.cosine_similarity( + grad_bf16_flat.unsqueeze(0), grad_fp8_flat.unsqueeze(0), dim=1 + ).item() + + # Norms + grad_bf16_norm = grad_bf16.norm().item() + grad_fp8_norm = grad_fp8.norm().item() + + gradient_comparison[grad_name] = { + "max_diff": grad_max_diff, + "mean_diff": grad_mean_diff, + "cos_sim": grad_cos_sim, + "bf16_norm": grad_bf16_norm, + "fp8_norm": grad_fp8_norm, + } + + logger.info( + f"{layer_path + '.' 
+ grad_name:<80} " + f"max_diff={grad_max_diff:>12.6f}, " + f"mean_diff={grad_mean_diff:>12.6f}, " + f"cos_sim={grad_cos_sim:>10.6f}, " + f"BF16_norm={grad_bf16_norm:>12.6f}, FP8_norm={grad_fp8_norm:>12.6f}" + ) + + # Summary + if gradient_comparison: + avg_cos_sim = sum(g["cos_sim"] for g in gradient_comparison.values()) / len( + gradient_comparison + ) + max_grad_diff = max(g["max_diff"] for g in gradient_comparison.values()) + logger.info("-" * 80) + logger.info( + f"Gradient Summary: " + f"avg_cos_sim={avg_cos_sim:.6f}, " + f"max_diff={max_grad_diff:.6f}, " + f"n_gradients={len(gradient_comparison)}" + ) + else: + logger.warning("No common gradients found for comparison") + + logger.info("=" * 80) + + return { + "layer_path": layer_path, + "output_max_diff": max_diff, + "output_mean_diff": mean_diff, + "output_cos_sim": cos_sim, + "bf16_output_norm": result_bf16["output_norm"], + "fp8_output_norm": result_fp8["output_norm"], + "bf16_output_max": result_bf16["output_max"], + "fp8_output_max": result_fp8["output_max"], + "output_bf16": output_bf16, + "output_fp8": output_fp8, + "gradient_comparison": gradient_comparison, + } + + +@pytest.mark.skip(reason="This test is only for debugging") +@pytest.mark.parametrize("use_custom_rmsnorm", [True, False]) +@pytest.mark.parametrize( + "activation_inputs_file", + [ + "activation_inputs/layernorm_inputs_combined_20251216_170822.pt", + ], +) +def test_rmsnorm_from_file( + use_custom_rmsnorm: bool, + activation_inputs_file: str | Path | None, + engine_bf16, + engine_fp8, + layer_path: str | None = None, + save_data: bool = False, +): + """Test RMSNorm modules using activation inputs loaded from file. + + This test loads previously saved activation inputs from file and tests + RMSNorm modules (q_layernorm and k_layernorm) in isolation. + + Args: + activation_inputs_file: Path to the saved activation inputs file. + layer_path: Specific layer path to test (e.g., "layer_0.self_attention.q_layernorm"). + If None, will test all available layers. + use_custom_rmsnorm: If True, use custom Qwen3RMSNorm with dequantized FP8 params. + For FP8, params will be dequantized to bfloat16 before RMSNorm. 
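+        engine_bf16: BF16 MegatronEngine providing the reference modules.
+        engine_fp8: FP8 MegatronEngine providing the modules under test.
+        save_data: If True, save the inputs, weights, and output gradients used for
+            each comparison under layernorm_inputs/ for offline debugging.
+
+    Note:
+        Only q_layernorm inputs are compared at the moment; k_layernorm keys are
+        filtered out before the per-layer loop.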
+ """ + + # Load activation inputs + logger.info("=" * 80) + logger.info(f"Loading activation inputs from: {activation_inputs_file}") + logger.info("=" * 80) + + data = load_layernorm_inputs_from_file(activation_inputs_file) + bf16_inputs = data.get("bf16_inputs", {}) + fp8_inputs = data.get("fp8_inputs", {}) + bf16_output_grads = data.get("bf16_output_grads", {}) + fp8_output_grads = data.get("fp8_output_grads", {}) + + logger.info(f"Loaded BF16 inputs: {list(bf16_inputs.keys())}") + logger.info(f"Loaded FP8 inputs: {list(fp8_inputs.keys())}") + + # Find matching layer paths + common_keys = set(bf16_inputs.keys()) & set(fp8_inputs.keys()) + if not common_keys: + logger.warning("No common layer paths found between BF16 and FP8 inputs") + return + + # Filter by layer_path if specified + if layer_path: + # Convert layer_path to input key format + if layer_path.endswith(".q_layernorm"): + input_key = layer_path.replace(".q_layernorm", ".q_layernorm.input") + elif layer_path.endswith(".k_layernorm"): + input_key = layer_path.replace(".k_layernorm", ".k_layernorm.input") + else: + input_key = f"{layer_path}.input" + + if input_key not in common_keys: + logger.warning(f"Layer path {layer_path} not found in loaded inputs") + logger.info(f"Available keys: {sorted(common_keys)}") + return + + common_keys = {input_key} + + # Only test q_layernorm + common_keys = {k for k in common_keys if k.endswith(".q_layernorm.input")} + + # Test each matching layer + results = [] + for input_key in sorted(common_keys): + # Extract layer path from input key + if input_key.endswith(".q_layernorm.input"): + test_layer_path = input_key.replace(".input", "") + layernorm_type = "q_layernorm" + elif input_key.endswith(".k_layernorm.input"): + test_layer_path = input_key.replace(".input", "") + layernorm_type = "k_layernorm" + else: + logger.warning(f"Unexpected input key format: {input_key}") + continue + + logger.info("\n" + "=" * 80) + logger.info(f"Testing {layernorm_type} for {test_layer_path}") + logger.info("=" * 80) + + # Get input activations + q_layernorm_input_bf16 = bf16_inputs[input_key] + q_layernorm_input_fp8 = fp8_inputs[input_key] + + # Get output gradients (from downstream layers) + output_grad_key = input_key.replace(".input", ".output_grad") + output_grad_bf16 = bf16_output_grads.get(output_grad_key, None) + output_grad_fp8 = fp8_output_grads.get(output_grad_key, None) + + q_layernorm_input_bf16 = q_layernorm_input_bf16.to(engine_bf16.device) + q_layernorm_input_fp8 = q_layernorm_input_fp8.to(engine_fp8.device) + if output_grad_bf16 is not None: + output_grad_bf16 = output_grad_bf16.to(engine_bf16.device) + if output_grad_fp8 is not None: + output_grad_fp8 = output_grad_fp8.to(engine_fp8.device) + + # Compare RMSNorm + result = compare_rmsnorm_bf16_fp8( + engine_bf16, + engine_fp8, + q_layernorm_input_bf16, + q_layernorm_input_fp8, + test_layer_path, + output_grad_bf16=output_grad_bf16, + output_grad_fp8=output_grad_fp8, + use_custom_rmsnorm=use_custom_rmsnorm, + save_data=save_data, + ) + results.append(result) + + # Summary + logger.info("\n" + "=" * 80) + logger.info("RMSNorm Test Summary") + logger.info("=" * 80) + for result in results: + if "shape_mismatch" in result and result["shape_mismatch"]: + logger.warning( + f"{result['layer_path']}: Shape mismatch - " + f"BF16={result['bf16_shape']}, FP8={result['fp8_shape']}" + ) + else: + logger.info( + f"{result['layer_path']}: " + f"output_max_diff={result['output_max_diff']:.6f}, " + f"output_mean_diff={result['output_mean_diff']:.6f}, " + 
f"output_cos_sim={result['output_cos_sim']:.6f}" + ) + + # Gradient summary + if "gradient_comparison" in result and result["gradient_comparison"]: + grad_comp = result["gradient_comparison"] + avg_grad_cos_sim = sum(g["cos_sim"] for g in grad_comp.values()) / len( + grad_comp + ) + max_grad_diff = max(g["max_diff"] for g in grad_comp.values()) + logger.info( + f" Gradients: " + f"avg_cos_sim={avg_grad_cos_sim:.6f}, " + f"max_diff={max_grad_diff:.6f}, " + f"n_gradients={len(grad_comp)}" + ) + + logger.info("=" * 80) diff --git a/areal/utils/fp8_kernels.py b/areal/utils/fp8_kernels.py new file mode 100644 index 000000000..3b7d93290 --- /dev/null +++ b/areal/utils/fp8_kernels.py @@ -0,0 +1,145 @@ +# Adapted from slime +import torch +import triton +import triton.language as tl + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +@triton.jit +def _blockwise_cast_to_fp8_triton( + X, + Y, + S, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + M, + N, + eps, + fp8_min, + fp8_max, + BLOCK_M: tl.constexpr = 32, + BLOCK_N: tl.constexpr = 128, +): + pid_m = tl.cast(tl.program_id(axis=0), tl.int64) + pid_n = tl.cast(tl.program_id(axis=1), tl.int64) + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = off_m < M + mask_n = off_n < N + mask = mask_m[:, None] & mask_n[None, :] + + x = tl.load( + X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn, + mask=mask, + other=0.0, + ).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(x)), eps) + x_s = _absmax / fp8_max + s_inv = 1.0 / x_s + y_q = tl.clamp(x * s_inv, fp8_min, fp8_max).to(Y.dtype.element_ty) + + tl.store( + Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn, y_q, mask=mask + ) + tl.store(S + pid_m * stride_sm + pid_n * stride_sn, x_s) + + +def blockwise_cast_to_fp8_triton( + x: torch.Tensor, block_size=None +) -> tuple[torch.Tensor, torch.Tensor]: + BLOCK_M, BLOCK_N = 128, 128 + if block_size: + BLOCK_M, BLOCK_N = block_size[0], block_size[1] + M, N = x.shape + fp8_dtype = torch.float8_e4m3fn + fp8_max = torch.finfo(fp8_dtype).max + fp8_min = -fp8_max + y = torch.empty(M, N, device=x.device, dtype=fp8_dtype) + s = torch.empty( + ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N), dtype=torch.float32, device=x.device + ) + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"])) + + if x.is_contiguous(): + kwargs = { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "num_warps": 8, + "num_stages": 2, + } + else: + kwargs = { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "num_warps": 1, + "num_stages": 4, + } + _blockwise_cast_to_fp8_triton[grid]( + x, + y, + s, + *x.stride(), + *y.stride(), + *s.stride(), + M, + N, + 1e-10, + fp8_min, + fp8_max, + **kwargs, + ) + return y, s + + +# Adapted from https://github.com/alibaba/Pai-Megatron-Patch/blob/2b201af08336dea0403df7c6b497c964cf5a2e75/toolkits/model_checkpoints_convertor/deepseek/fp8_cast_bf16.py +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def 
weight_dequant( + x: torch.Tensor, + s: torch.Tensor, + block_size: int = 128, + dst_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize FP8 weights to the given dtype. + + Args: + x: FP8 weight tensor (2D) + s: Scale inverse tensor (2D, shape matches block structure) + block_size: Block size used for quantization + dst_dtype: Destination dtype to dequantize to + + Returns: + Dequantized weight tensor in the destination dtype + """ + assert x.is_contiguous() and s.is_contiguous() + assert x.dim() == 2 and s.dim() == 2 + M, N = x.size() + y = torch.empty_like(x, dtype=dst_dtype) + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"])) + + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y diff --git a/areal/utils/fp8_utils.py b/areal/utils/fp8_utils.py new file mode 100644 index 000000000..e52aa1aa3 --- /dev/null +++ b/areal/utils/fp8_utils.py @@ -0,0 +1,192 @@ +import re + +import torch + +try: + from sglang.srt.layers.quantization.fp8_utils import ( + quant_weight_ue8m0, + transform_scale_ue8m0, + ) + from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 +except ImportError: + should_deepgemm_weight_requant_ue8m0 = None + quant_weight_ue8m0 = None + transform_scale_ue8m0 = None + +from areal.utils.fp8_kernels import blockwise_cast_to_fp8_triton, weight_dequant + + +# Adapted from slime +def _quantize_param( + name: str, + weight: torch.Tensor, + weight_block_size: tuple[int, int] | list[int] | None = None, + scale_dtype: torch.dtype = torch.bfloat16, +) -> list[tuple[str, torch.Tensor]]: + """Quantize a single weight parameter to FP8 format. + + Args: + name: Parameter name (must end with ".weight") + weight: Weight tensor to quantize + weight_block_size: Optional block size for blockwise quantization [block_m, block_n] + + Returns: + List of (name, tensor) tuples: [(weight_name, quantized_weight), (scale_name, scale)] + """ + assert name.endswith(".weight"), f"Expected weight parameter, got {name}" + FP8_MIN = torch.finfo(torch.float8_e4m3fn).min + FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + + if weight_block_size is not None: + # Blockwise quantization + if ( + should_deepgemm_weight_requant_ue8m0 is not None + and should_deepgemm_weight_requant_ue8m0( + weight_block_size=weight_block_size + ) + ): + # Use sglang's quantization + qweight, scale = quant_weight_ue8m0( + weight, weight_block_size=weight_block_size + ) + scale = transform_scale_ue8m0(scale, mn=qweight.shape[-2]) + else: + # Use triton-based blockwise quantization + qweight, scale = blockwise_cast_to_fp8_triton(weight, weight_block_size) + scale_name = name.replace(".weight", ".weight_scale_inv") + else: + # Per-tensor quantization + scale = weight.abs().max().clamp(min=1e-12).to(torch.float32) / FP8_MAX + qweight = ( + (weight / scale).clamp(min=FP8_MIN, max=FP8_MAX).to(torch.float8_e4m3fn) + ) + scale = scale.view(1) + scale_name = name.replace(".weight", ".weight_scale") + + scale = scale.to(scale_dtype) + return [(name, qweight), (scale_name, scale)] + + +# Adapted from slime +def quantize_params( + megatron_name: str, + converted_named_params: list[tuple[str, torch.Tensor]], + quantization_config: dict[str, int | str | list[str]] | None, + scale_dtype: torch.dtype = torch.bfloat16, +) -> list[tuple[str, torch.Tensor]]: + """Apply FP8 quantization to converted HuggingFace parameters.""" + if quantization_config is None: + return converted_named_params + + assert quantization_config["quant_method"] == 
"fp8" + assert quantization_config["fmt"] == "e4m3" + assert quantization_config["activation_scheme"] == "dynamic" + weight_block_size = quantization_config.get("weight_block_size", None) + + # handle both with and without "module.module." prefix + if not megatron_name.startswith("module.module."): + # Add prefix if missing for pattern matching + if megatron_name.startswith("decoder."): + megatron_name = "module.module." + megatron_name + elif megatron_name.startswith("mtp."): + megatron_name = "module.module." + megatron_name + + decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" + match = re.match(decoder_layers_pattern, megatron_name) + + if not match: + # Check mtp layers + mtp_layer_pattern = r"module\.module\.mtp\.layers\.(\d+)\.(.+)" + match = re.match(mtp_layer_pattern, megatron_name) + if not match: + return converted_named_params + layer_idx, rest = match.groups() + rest = rest.replace("transformer_layer.", "") + else: + layer_idx, rest = match.groups() + + # Experts + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, expert_idx = match.groups() + if rest in ["linear_fc1", "linear_fc2"]: + quantize_named_params = [] + for converted_name, param in converted_named_params: + # Skip bf16 weight_scale and input_scale + # TODO: find a clearer way. + if converted_name.endswith("_scale"): + continue + quantize_named_params.extend( + _quantize_param( + converted_name, param, weight_block_size, scale_dtype + ) + ) + return quantize_named_params + + # Shared expert + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest in ["linear_fc1.weight", "linear_fc2.weight"]: + quantize_named_params = [] + for converted_name, param in converted_named_params: + quantize_named_params.extend( + _quantize_param( + converted_name, param, weight_block_size, scale_dtype + ) + ) + return quantize_named_params + + # Regular attention and MLP layers + if rest in [ + "self_attention.linear_proj.weight", + "self_attention.linear_qkv.weight", + "mlp.linear_fc1.weight", + "mlp.linear_fc2.weight", + # mla + "self_attention.linear_q_proj.weight", + "self_attention.linear_q_down_proj.weight", + "self_attention.linear_q_up_proj.weight", + "self_attention.linear_kv_down_proj.weight", + "self_attention.linear_kv_up_proj.weight", + ]: + quantize_named_params = [] + for converted_name, param in converted_named_params: + quantize_named_params.extend( + _quantize_param(converted_name, param, weight_block_size, scale_dtype) + ) + return quantize_named_params + + # For other parameters, return original converted_named_params + return converted_named_params + + +def dequantize_params( + weight: torch.Tensor, + scale_inv: torch.Tensor, + dst_dtype: torch.dtype = torch.bfloat16, + quantization_config: dict[str, int | str | list[str]] | None = None, +) -> torch.Tensor: + """Dequantize FP8 weights to the given dtype.""" + if not weight.is_contiguous(): + weight = weight.contiguous() + if not scale_inv.is_contiguous(): + scale_inv = scale_inv.contiguous() + + if quantization_config is None: + block_size = 128 + else: + weight_block_size = quantization_config.get("weight_block_size", None) + # TODO: consider (M, N) block size, now only support square block size + if weight_block_size is not None: + assert ( + len(weight_block_size) == 2 + and weight_block_size[0] == weight_block_size[1] + ) + block_size = weight_block_size[0] + else: + block_size = 128 + + 
return weight_dequant(weight, scale_inv, block_size, dst_dtype) diff --git a/areal/utils/mcore/pipeline_parallel.py b/areal/utils/mcore/pipeline_parallel.py index 232ca8da1..fd11cf49a 100644 --- a/areal/utils/mcore/pipeline_parallel.py +++ b/areal/utils/mcore/pipeline_parallel.py @@ -266,7 +266,6 @@ def mlp_params(intermediate: int | None) -> float: moe_layer_indices: set[int] = set() freq = getattr(tf_conf, "moe_layer_freq", 1) - assert freq > 0 if isinstance(freq, int): step = abs(freq) assert step >= 1 diff --git a/areal/utils/megatron.py b/areal/utils/megatron.py index 007ae0d99..870b8e593 100644 --- a/areal/utils/megatron.py +++ b/areal/utils/megatron.py @@ -8,6 +8,8 @@ from torch import Tensor from torch.nn.parameter import Parameter +from areal.utils.fp8_utils import quantize_params + # Adapted from slime def all_gather_param(name: str, param: Parameter | Tensor): @@ -60,7 +62,7 @@ def remove_padding(name: str, param: Parameter | Tensor, vocab_size: int): # Adapted from slime def convert_qwen3moe_to_hf( - tf_config: TransformerConfig, name: str, param: Parameter | Tensor + tf_config: TransformerConfig, name: str, param: Parameter | Tensor, **kwargs ): if name == "module.module.embedding.word_embeddings.weight": return [("model.embed_tokens.weight", param)] @@ -215,7 +217,7 @@ def convert_qwen3moe_to_hf( # Adapted from slime def convert_qwen2_to_hf( - tf_config: TransformerConfig, name: str, param: Parameter | Tensor + tf_config: TransformerConfig, name: str, param: Parameter | Tensor, **kwargs ): if name == "module.module.embedding.word_embeddings.weight": return [("model.embed_tokens.weight", param)] @@ -301,21 +303,212 @@ def convert_qwen2_to_hf( raise ValueError(f"Unknown parameter name: {name}") +# Adapted from slime +def convert_deepseekv3_to_hf( + tf_config: TransformerConfig, name: str, param: Parameter | Tensor, **kwargs +): + if name == "module.module.embedding.word_embeddings.weight": + return [("model.embed_tokens.weight", param)] + if name == "module.module.output_layer.weight": + return [("lm_head.weight", param)] + if name == "module.module.decoder.final_layernorm.weight": + return [("model.norm.weight", param)] + + try: + head_dim = ( + tf_config.kv_channels + if tf_config.kv_channels is not None + else tf_config.hidden_size // tf_config.num_attention_heads + ) + except (AttributeError, TypeError): + head_dim = tf_config.hidden_size // tf_config.num_attention_heads + value_num_per_group = tf_config.num_attention_heads // tf_config.num_query_groups + + decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" + match = re.match(decoder_layers_pattern, name) + if match: + layer_idx, rest = match.groups() + + # experts + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, expert_idx = match.groups() + if rest == "linear_fc1": + gate_weight, up_weight = param.chunk(2, dim=0) + outputs = [ + ( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", + gate_weight, + ), + ( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", + up_weight, + ), + ] + return outputs + elif rest == "linear_fc2": + outputs = [ + ( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", + param, + ), + ] + if kwargs.get("inference_enable_ep_moe", False): + outputs += [ + ( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.input_scale", + torch.tensor(1.0, dtype=torch.float32, device=param.device), + ), + ( + 
f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale", + torch.tensor(1.0, dtype=torch.float32, device=param.device), + ), + ] + return outputs + else: + raise ValueError(f"Unknown expert parameter name: {name}") + + # shared expert + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest == "linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + ( + f"model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", + gate_weight, + ), + ( + f"model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", + up_weight, + ), + ] + elif rest == "linear_fc2.weight": + return [ + ( + f"model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight", + param, + ) + ] + else: + raise ValueError(f"Unknown shared expert parameter name: {name}") + + if rest == "self_attention.linear_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] + elif rest == "self_attention.linear_q_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_proj.weight", param)] + elif rest == "self_attention.linear_q_down_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_a_proj.weight", param)] + elif rest == "self_attention.linear_q_up_proj.layer_norm_weight": + return [(f"model.layers.{layer_idx}.self_attn.q_a_layernorm.weight", param)] + elif rest == "self_attention.linear_q_up_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_b_proj.weight", param)] + elif rest == "self_attention.linear_qkv.bias": + param = param.view(tf_config.num_query_groups, -1) + q_bias, k_bias, v_bias = torch.split( + param, + split_size_or_sections=[ + value_num_per_group * head_dim, + head_dim, + head_dim, + ], + dim=1, + ) + q_bias = q_bias.contiguous().flatten() + k_bias = k_bias.contiguous().flatten() + v_bias = v_bias.contiguous().flatten() + return [ + (f"model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias), + (f"model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias), + (f"model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias), + ] + elif rest == "mlp.linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.up_proj.weight", up_weight), + ] + elif rest == "mlp.linear_fc2.weight": + return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] + elif ( + rest == "self_attention.linear_qkv.layer_norm_weight" + or rest == "input_layernorm.weight" + ): + return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)] + elif rest == "mlp.linear_fc1.layer_norm_weight": + return [ + (f"model.layers.{layer_idx}.post_attention_layernorm.weight", param) + ] + elif rest == "self_attention.linear_kv_down_proj.weight": + return [ + (f"model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa.weight", param) + ] + elif rest == "self_attention.linear_kv_up_proj.layer_norm_weight": + return [ + (f"model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight", param) + ] + elif rest == "self_attention.linear_kv_up_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.kv_b_proj.weight", param)] + elif rest == "pre_mlp_layernorm.weight": + return [ + (f"model.layers.{layer_idx}.post_attention_layernorm.weight", param) + ] + elif rest == "mlp.router.weight": + return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)] + elif rest == "mlp.router.expert_bias": + return [ + 
(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param) + ] + + raise ValueError(f"Unknown parameter name: {name}") + + # Adapted from slime # A registry for conversion functions is more extensible. _CONVERSION_FN_REGISTRY = { "qwen3_moe": convert_qwen3moe_to_hf, "qwen2": convert_qwen2_to_hf, "qwen3": convert_qwen2_to_hf, + "deepseekv3": convert_deepseekv3_to_hf, } def convert_to_hf( - tf_config: TransformerConfig, model_name: str, name: str, param: Parameter | Tensor + tf_config: TransformerConfig, + model_name: str, + name: str, + param: Parameter | Tensor, + quantization_config: dict[str, int | str | list[str]] | None = None, + **kwargs, ): + """Convert Megatron parameter to HuggingFace format, optionally with FP8 quantization. + + Args: + tf_config: Transformer configuration + model_name: Model name (e.g., "qwen2", "qwen3_moe") + name: Parameter name in Megatron format + param: Parameter tensor + quantization_config: Optional quantization config dict with keys: + - quant_method: "fp8" + - fmt: "e4m3" + - activation_scheme: "dynamic" + - weight_block_size: Optional tuple/list of [block_m, block_n] for blockwise quantization + + Returns: + List of (name, tensor) tuples in HuggingFace format. For FP8 quantization, + returns both quantized weight and scale tensors. + """ for key, conversion_fn in _CONVERSION_FN_REGISTRY.items(): if key in model_name: - return conversion_fn(tf_config, name, param) + converted_named_tensors = conversion_fn(tf_config, name, param, **kwargs) + if quantization_config: + return quantize_params( + name, converted_named_tensors, quantization_config + ) + return converted_named_tensors raise ValueError(f"Unsupported model for HF conversion: {model_name}") diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 71d789179..d4fd9a63c 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -566,6 +566,7 @@ https://github.com/sgl-project/sglang for detailed documentation. | `decode_log_interval` | integer | `1` | - | | `enable_multithread_load` | boolean | `False` | - | | `enable_fast_load` | boolean | `False` | - | +| `quantization` | string \| None | `None` | - | (section-v-llm)= @@ -809,17 +810,18 @@ Configuration for Megatron's DistributedDataParallel. Refer to Megatron-LM documentation for details. 
-| Parameter | Type | Default | Description | -| --------------------------- | --------------- | ------- | ----------- | -| `grad_reduce_in_fp32` | boolean | `True` | - | -| `overlap_grad_reduce` | boolean | `False` | - | -| `overlap_param_gather` | boolean | `False` | - | -| `align_param_gather` | boolean | `False` | - | -| `use_distributed_optimizer` | boolean | `True` | - | -| `check_for_nan_in_grad` | boolean | `False` | - | -| `bucket_size` | integer \| None | `None` | - | -| `average_in_collective` | boolean | `False` | - | -| `fp8_param_gather` | boolean | `False` | - | +| Parameter | Type | Default | Description | +| --------------------------------- | --------------- | ------------ | ------------------------------------------------------------------------------------------------------ | +| `grad_reduce_in_fp32` | boolean | `True` | - | +| `overlap_grad_reduce` | boolean | `False` | - | +| `overlap_param_gather` | boolean | `False` | - | +| `align_param_gather` | boolean | `False` | - | +| `use_distributed_optimizer` | boolean | `True` | - | +| `check_for_nan_in_grad` | boolean | `False` | - | +| `bucket_size` | integer \| None | `None` | - | +| `average_in_collective` | boolean | `False` | - | +| `fp8_param_gather` | boolean | `False` | - | +| `data_parallel_sharding_strategy` | string | `"no_shard"` | Sharding strategy for FSDP. Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'. | (section-megatron-engine)= @@ -829,27 +831,45 @@ Configuration for Megatron-LM training framework. Refer to Megatron-LM documentation for implementation details. -| Parameter | Type | Default | Description | -| ------------------------------------------ | -------------------------------------------------------------------- | ------------ | ------------------------------------------------------------------------------------------------------------------- | -| `wrap_with_ddp` | boolean | `True` | - | -| `use_torch_fsdp2` | boolean | `False` | - | -| `use_custom_fsdp` | boolean | `False` | - | -| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - | -| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). 
| -| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - | -| `use_precision_aware_optimizer` | boolean | `False` | - | -| `main_grads_dtype` | string | `"float32"` | - | -| `main_params_dtype` | string | `"float32"` | - | -| `exp_avg_dtype` | string | `"float32"` | - | -| `exp_avg_sq_dtype` | string | `"float32"` | - | -| `async_save` | boolean | `False` | - | -| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - | -| `use_deterministic_algorithms` | boolean | `False` | - | -| `recompute_granularity` | string \| None | `"full"` | - | -| `recompute_method` | string \| None | `"uniform"` | - | -| `recompute_num_layers` | integer \| None | `1` | - | -| `distribute_saved_activations` | boolean \| None | `None` | - | -| `recompute_modules` | list of string \| None | `None` | - | +| Parameter | Type | Default | Description | +| ------------------------------------------ | -------------------------------------------------------------------- | --------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `wrap_with_ddp` | boolean | `True` | - | +| `use_torch_fsdp2` | boolean | `False` | - | +| `use_custom_fsdp` | boolean | `False` | - | +| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - | +| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). | +| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - | +| `use_precision_aware_optimizer` | boolean | `False` | - | +| `main_grads_dtype` | string | `"float32"` | - | +| `main_params_dtype` | string | `"float32"` | - | +| `exp_avg_dtype` | string | `"float32"` | - | +| `exp_avg_sq_dtype` | string | `"float32"` | - | +| `async_save` | boolean | `False` | - | +| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - | +| `use_deterministic_algorithms` | boolean | `False` | - | +| `recompute_granularity` | string \| None | `"full"` | - | +| `recompute_method` | string \| None | `"uniform"` | - | +| `recompute_num_layers` | integer \| None | `1` | - | +| `distribute_saved_activations` | boolean \| None | `None` | - | +| `recompute_modules` | list of string \| None | `None` | - | +| `moe_router_dtype` | string \| None | `None` | - | +| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared epxerts execute after the routed experts. | +| `moe_enable_deepep` | boolean | `False` | - | +| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'. | +| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | +| `fp8` | string \| None | `None` | Enable FP8 precision training. Options: 'e4m3' (uniform e4m3), 'hybrid' (e4m3 for activations/weights, e5m2 for output activation gradients). | +| `fp8_recipe` | string | `"delayed"` | FP8 scaling recipe. Options: 'tensorwise', 'delayed', 'mxfp8' (Blackwell only), 'blockwise'. | +| `fp8_param` | boolean | `False` | Keep parameters in FP8 precision to save memory. Must be used together with fp8 mode. Not all parameters will be converted to fp8; for example, biases will remain unchanged. | +| `fp8_margin` | integer | `0` | Margin for FP8 scaling factor computation. 
| +| `fp8_amax_history_len` | integer | `1` | Length of amax history window for scaling factor computation. | +| `fp8_amax_compute_algo` | string | `"most_recent"` | Algorithm for choosing amax value. Options: 'max' (largest in history window), 'most_recent'. | +| `fp8_wgrad` | boolean | `True` | When False, override FP8 config and compute weight gradients in higher precision. | +| `fp8_dot_product_attention` | boolean | `False` | Use FP8 implementation of Dot Product Attention. | +| `fp8_multi_head_attention` | boolean | `False` | Use FP8 implementation of Multi Head Attention. | +| `tp_only_amax_red` | boolean | `False` | Reduce FP8 AMAX only in TP or TP-CP domain. | +| `first_last_layers_bf16` | boolean | `False` | Retain first and last N TransformerBlocks in BF16 instead of FP8. | +| `num_layers_at_start_in_bf16` | integer | `1` | Number of layers at start to keep in BF16 when first_last_layers_bf16 is True. | +| `num_layers_at_end_in_bf16` | integer | `1` | Number of layers at end to keep in BF16 when first_last_layers_bf16 is True. | (section-perf-tracer)=