From 9bda954fb48381bc48e852bc177350e56697cad7 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 12 Jan 2026 18:51:33 -0800 Subject: [PATCH 01/11] support latent moe import and fix local experts sync Signed-off-by: jenchen13 --- .../torch/export/plugins/mcore_nemotron.py | 8 +++++++- modelopt/torch/quantization/model_calib.py | 18 ++++++++++++------ .../torch/quantization/plugins/megatron.py | 1 - 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 5fdb8ba1b..53fd0d232 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -81,8 +81,11 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP ), -} + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), +} nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), @@ -115,4 +118,7 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj." ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."), } diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3184f2a78..9774e71f2 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -95,13 +95,22 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group - # Step 1:Sync amax across data parallelism + # Step 1: Sync amax across local experts in a SequentialMLP + for name, module in model.named_modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + + # TODO just for testing + if "experts" in name and "weight_quantizer" in name: + assert child.amax is not None + + # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): sync_quantizer_amax_across_dp_ep(child, module.parallel_state) - # TP sync: + # Step 3: TP sync # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same # ColumnParallel: X @ [A_1, A_2] (weights split along Cout) @@ -182,10 +191,7 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) - # MOE Quantization - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - + # KV Cache Quantization if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 803c9747f..c44aa86b3 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -581,7 +581,6 @@ def sync_moe_local_experts_amax(self): This function is called to synchronize the amax values across 
local experts s.t. all localexperts will share the same amax. """ - torch.distributed.barrier() # Collect amax from all local experts amax_dict = {} for expert in self.local_experts: From 5da17f4237c0d75c07bdc4c0a2ab770de3f0513e Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 12 Jan 2026 21:06:15 -0800 Subject: [PATCH 02/11] patch TransformerLayer forward Signed-off-by: jenchen13 --- .../torch/export/plugins/mcore_nemotron.py | 2 ++ .../torch/quantization/plugins/megatron.py | 27 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 53fd0d232..a61fc367e 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -87,6 +87,8 @@ } +# TODO later support MTP import/export + nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), "final_norm": NameRemapping("backbone.norm_f."), diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index c44aa86b3..64b8c51b4 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,7 @@ import megatron.core.parallel_state as mcore_parallel import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp +import megatron.core.transformer.transformer_layer as megatron_transformer_layer import megatron.core.transformer.moe.experts as megatron_moe import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch @@ -40,6 +41,7 @@ register_modelopt_extra_state_callbacks, ) from modelopt.torch.utils.distributed import ParallelState +import torch.distributed as dist from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear @@ -593,12 +595,18 @@ def sync_moe_local_experts_amax(self): if stored_amax is None else torch.maximum(stored_amax, amax_tensor) ) + #if isinstance(module, TensorQuantizer) and module.amax is None: + # print(f"MISSING AMAX BEFORE SYNC in expert rank {dist.get_rank()}: {name}", flush=True) + + # Apply synchronized amax values back to all local experts for expert in self.local_experts: for name, module in expert.named_modules(): if isinstance(module, TensorQuantizer) and module.amax is not None: module.amax = amax_dict[name].detach().clone().to(module.amax.device) + #if isinstance(module, TensorQuantizer) and module.amax is None: + # print(f"MISSING AMAX AFTER SYNC in expert rank {dist.get_rank()}: {name}", flush=True) def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Override the default to enable singleton_local_shards. 
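The amax synchronization introduced in the two hunks above reduces to a two-pass max-then-broadcast over the local experts of a SequentialMLP. Below is a minimal, self-contained sketch of that pattern for reference; the names `DummyQuantizer` and `sync_local_experts_amax` are illustrative stand-ins and not part of the patch.

import torch
import torch.nn as nn

class DummyQuantizer(nn.Module):
    """Stand-in for TensorQuantizer; only the `amax` attribute matters for this sketch."""
    def __init__(self, amax=None):
        super().__init__()
        self.amax = None if amax is None else torch.tensor(float(amax))

def sync_local_experts_amax(local_experts, quantizer_cls=DummyQuantizer):
    """Give every local expert the same per-quantizer amax (elementwise max across experts)."""
    amax_dict = {}
    # Pass 1: collect the running elementwise max for each named quantizer.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if isinstance(module, quantizer_cls) and module.amax is not None:
                amax = module.amax.detach().clone()
                stored = amax_dict.get(name)
                amax_dict[name] = amax if stored is None else torch.maximum(stored, amax)
    # Pass 2: write the shared value back so all experts quantize with identical ranges.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if isinstance(module, quantizer_cls) and module.amax is not None:
                module.amax = amax_dict[name].detach().clone().to(module.amax.device)

# Tiny demo: two experts whose (fake) weight quantizers calibrated to different ranges
# both end up with the larger amax.
experts = [nn.Sequential(DummyQuantizer(1.0)), nn.Sequential(DummyQuantizer(3.0))]
sync_local_experts_amax(experts)
assert all(float(e[0].amax) == 3.0 for e in experts)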
@@ -758,3 +766,22 @@ def forward(self, hidden_states): super().forward(hidden_states) self.router.topk = original_top_k return super().forward(hidden_states) + +# TODO double check if MOE forward will be implemented in MoELayer or TransformerLayer +# We do not need both layers to be patched + +@QuantModuleRegistry.register({megatron_transformer_layer.TransformerLayer: "megatron_transformer_layer_TransformerLayer"}) +class _QuantTransformerLayer(QuantModule): + def _setup(self): + pass + + def _forward_mlp_moe_preprocess(self, hidden_states): + #print(f"FORWARD in TransformerLayer rank {dist.get_rank()}", flush=True) + if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()): + print(f"Forcing top_k to num_experts in TransformerLayer rank {dist.get_rank()}", flush=True) + original_top_k = self.mlp.router.topk + self.mlp.router.topk = self.mlp.router.num_experts + super()._forward_mlp_moe_preprocess(hidden_states) + self.mlp.router.topk = original_top_k + + return super()._forward_mlp_moe_preprocess(hidden_states) \ No newline at end of file From 4471c03ee61957bd12283a84af3070a05dc20cf3 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 12 Jan 2026 21:37:46 -0800 Subject: [PATCH 03/11] fix bug of duplicate forward Signed-off-by: jenchen13 --- modelopt/torch/quantization/plugins/megatron.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 64b8c51b4..62b542b66 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -763,8 +763,9 @@ def forward(self, hidden_states): if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): original_top_k = self.router.topk self.router.topk = self.router.num_experts - super().forward(hidden_states) + output = super().forward(hidden_states) self.router.topk = original_top_k + return output return super().forward(hidden_states) # TODO double check if MOE forward will be implemented in MoELayer or TransformerLayer @@ -776,12 +777,11 @@ def _setup(self): pass def _forward_mlp_moe_preprocess(self, hidden_states): - #print(f"FORWARD in TransformerLayer rank {dist.get_rank()}", flush=True) if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()): - print(f"Forcing top_k to num_experts in TransformerLayer rank {dist.get_rank()}", flush=True) original_top_k = self.mlp.router.topk self.mlp.router.topk = self.mlp.router.num_experts - super()._forward_mlp_moe_preprocess(hidden_states) + output = super()._forward_mlp_moe_preprocess(hidden_states) self.mlp.router.topk = original_top_k + return output return super()._forward_mlp_moe_preprocess(hidden_states) \ No newline at end of file From db1892f50aa4a5961d4b474c82e084b77d2e6243 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Tue, 20 Jan 2026 12:18:13 -0800 Subject: [PATCH 04/11] fix kv bmm export Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_custom.py | 9 +++ .../torch/export/plugins/mcore_nemotron.py | 12 ++- modelopt/torch/export/quant_utils.py | 31 +++----- .../torch/export/unified_export_megatron.py | 73 ++++++++----------- modelopt/torch/quantization/model_calib.py | 7 -- .../torch/quantization/plugins/megatron.py | 5 -- 6 files changed, 61 insertions(+), 76 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 23804b322..25a2cd0cb 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ 
b/modelopt/torch/export/plugins/mcore_custom.py @@ -126,6 +126,15 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) +class SelfAttentionScaling(CustomModuleMapping): + """A custom module mapping that scales self attention.""" + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that scales self attention.""" + super().__init__( + func_name="self_attention_scaling", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) class GatedMLPSlicing(CustomModuleMapping): """A custom module mapping that slices gate_proj and up_proj.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index a61fc367e..f857230ae 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -26,6 +26,7 @@ NameRemapping, QKVMerging, QKVSlicing, + SelfAttentionScaling, ) # Example on adding a new CausalLM. @@ -84,10 +85,18 @@ # Latent MoE "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), + # MTP + #"enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), + #"hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), + #"eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), + #"layer_norm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), + #"norm": NameRemapping("mtp.layers.{}.norm", REPLICATE) + # "transformer_layer": NameRemapping("mtp.layers.{}.mixer", REPLICATE), + } -# TODO later support MTP import/export +# TODO ADD MTP export nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), @@ -106,6 +115,7 @@ "input_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."), "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."), + "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index eee13dc51..dbf78db0f 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -332,7 +332,7 @@ def get_prequant_scaling_factor(module: nn.Module) -> torch.Tensor: if prequant_scaling_factor is not None: assert torch.all(prequant_scaling_factor > 0), ( f"prequant scaling factor {prequant_scaling_factor} not positive." - ) + ) return prequant_scaling_factor @@ -344,32 +344,22 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]: kv_bias.append(getattr(quantizer_module, "_bias_value", None)) return kv_bias - -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - if not hasattr(kv_module, "k_bmm_quantizer") or not hasattr(kv_module, "v_bmm_quantizer"): +def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> torch.Tensor: + """ + Returns the k and v BMM scaling factors if BMM quantizers are set in the self attention module. + Else returns None by default. 
+ """ + if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr(self_attention_module, "v_bmm_quantizer"): return [None, None] scaling_factors = [ - get_scaling_factor(getattr(kv_module, quantizer)) + get_scaling_factor(getattr(self_attention_module, quantizer)) for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer") ] - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - for i, factor in enumerate(scaling_factors): - if factor.item() > 0.5: - warn( - f"Warning: Large KV activation detected: {factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop." - ) - scaling_factors[i] = torch.max( - factor, torch.tensor([1.0], dtype=torch.float, device=factor.device) - ) - return scaling_factors + def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: """Returns the kv_cache dtype. @@ -390,8 +380,7 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: for module in modules: # Case where the module has both k_bmm_quantizer and v_bmm_quantizer - # Still check for output quantizer for the unified_megatron_export path - for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer", "output_quantizer"): + for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer"): quantizer_attr = getattr(module, quantizer, None) if quantizer_attr and quantizer_attr.is_enabled: num_bits_list.append(quantizer_attr.num_bits) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index f1bd67327..c17c2c0bb 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -52,6 +52,8 @@ from .quant_utils import ( get_activation_scaling_factor, get_kv_cache_dtype, + get_kv_cache_scaling_factor, + get_quant_config, get_quantization_format, get_scaling_factor, get_weight_block_size, @@ -86,33 +88,6 @@ ] -# This path uses output_quantizer for KV cache quantization. -# The function below is the old version of get_kv_cache_scaling_factor which is now refactored to handle bmm_quantizer. -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - scaling_factor = ( - get_scaling_factor(kv_module.output_quantizer) - if hasattr(kv_module, "output_quantizer") - else None - ) - - if not scaling_factor: - return None - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - if scaling_factor.item() > 0.5: - warn( - f"!!!!Large KV activations detected: {scaling_factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop.\n!!!!" - ) - scaling_factor = torch.max( - scaling_factor, - torch.tensor([1.0], dtype=torch.float, device=scaling_factor.device), - ) - return scaling_factor - - class GPTModelExporter: """Megatron Core GPTModel Exporter. 
@@ -283,6 +258,7 @@ def save_pretrained( kv_cache_quantization = None kv_cache_dtype = get_kv_cache_dtype(self.model) + print("kv_cache_dtype: ", kv_cache_dtype) if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM kv_cache_quantization = kv_cache_dtype @@ -320,7 +296,9 @@ def save_pretrained( pass if is_last_stage_main_rank and quantization is not None: - hf_quant_config = { + # TODO refactor to use mte.quant_utils.get_quant_config + # except layer names are different in MCore and HF + hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, @@ -328,9 +306,11 @@ def save_pretrained( "quantization": { "quant_algo": quantization, "kv_cache_quant_algo": kv_cache_quantization, - "exclude_modules": ["lm_head"], + "exclude_modules": ["lm_head"], # TODO update this dynamically }, } + if quantization == "NVFP4": + hf_quant_config["quantization"]["group_size"] = 16 with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) @@ -473,6 +453,7 @@ def _custom_mapping_to_lambda(mapping): method_map = { "name_remapping": self._name_remapping, "qkv_slicing": self._qkv_slicing, + "self_attention_scaling": self._self_attention_scaling, "gated_mlp_slicing": self._gated_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, @@ -541,12 +522,8 @@ def _get_quantized_state( # TODO (chenhany): support AWQ with pre_quant_scale if hasattr(module.input_quantizer, "_pre_quant_scale"): raise ValueError("Detect pre_quant_scale! SmoothQuant/AWQ are not yet supported!") - - if hasattr(module, "output_quantizer"): - output_scale = get_kv_cache_scaling_factor(module) - if output_scale is not None: - name_to_value["output_scale"] = output_scale - + + return name_to_value, qformat, block_size def _get_quantization_format(self, module: torch.nn.Module): @@ -674,9 +651,7 @@ def _qkv_slicing( q_proj_name="q_proj", k_proj_name="k_proj", v_proj_name="v_proj", - k_scale_name="k_scale", - v_scale_name="v_scale", - ): + ): name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) q_proj_prefix = prefix + q_proj_name + "." 
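The quantization metadata assembled in save_pretrained above is written as a small JSON file next to the checkpoint. A sketch of what hf_quant_config.json would look like for NVFP4 weights with an FP8 KV cache follows; the version string is a placeholder and the exact values are representative, not exhaustive.

import json

hf_quant_config = {
    "producer": {"name": "modelopt", "version": "x.y.z"},  # placeholder for __version__
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
        "group_size": 16,  # added only for NVFP4, per the patch
    },
}
print(json.dumps(hf_quant_config, indent=4))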
@@ -774,10 +749,7 @@ def _qkv_slicing( q_proj_key = q_proj_prefix + key k_proj_key = k_proj_prefix + key v_proj_key = v_proj_prefix + key - if key == "output_scale": - self._state_dict[prefix + k_scale_name] = val.detach().clone() - self._state_dict[prefix + v_scale_name] = val.detach().clone() - elif key == "bias": + if key == "bias": # Slice bias similar to weight bias = val.detach().clone() bias = bias.reshape([qkv_total_dim, head_size]) @@ -790,6 +762,17 @@ def _qkv_slicing( self._state_dict[k_proj_key] = val.detach().clone() self._state_dict[v_proj_key] = val.detach().clone() + def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"): + """KV cache scaling for self attention module.""" + k_scale_key = prefix + k_scale_name + v_scale_key = prefix + v_scale_name + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): + kv_scales = get_kv_cache_scaling_factor(module) + if all(s is not None for s in kv_scales): + self._state_dict[k_scale_key] = kv_scales[0] + self._state_dict[v_scale_key] = kv_scales[1] + + def _pack_name_remapping(self, module, prefix, layer_type=None): """Pack name remapping into one tensor.""" weight_list = [] @@ -1149,6 +1132,8 @@ def _get_state_dict(self): self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + if hasattr(layer.self_attention, "core_attention"): + self.rules["core_attention"](layer.self_attention.core_attention, layer_id) self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) if ( getattr(layer.self_attention.core_attention, "softmax_offset", None) @@ -1166,6 +1151,10 @@ def _get_state_dict(self): self.rules["router"]( layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) if ( hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9774e71f2..73293d4b1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -100,10 +100,6 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): if hasattr(module, "sync_moe_local_experts_amax"): module.sync_moe_local_experts_amax() - # TODO just for testing - if "experts" in name and "weight_quantizer" in name: - assert child.amax is not None - # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): @@ -165,7 +161,6 @@ def sync_quantizer_amax_across_tp( axes_for_sync=[None, -1], parallel_state=module.parallel_state, ) - sync_quantizer_amax_across_tp( module.weight_quantizer, name, @@ -284,8 +279,6 @@ def quant_func(x, amax, quantizer=module): # Step 4: Compute optimal amax and load it finish_stats_collection(model, method="mse") - # TODO: Sync amax across distributed processes - def enable_stats_collection(model: nn.Module): """Enable stats collection for all quantizers in the model.""" diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 62b542b66..4414ec5e2 
100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -595,8 +595,6 @@ def sync_moe_local_experts_amax(self): if stored_amax is None else torch.maximum(stored_amax, amax_tensor) ) - #if isinstance(module, TensorQuantizer) and module.amax is None: - # print(f"MISSING AMAX BEFORE SYNC in expert rank {dist.get_rank()}: {name}", flush=True) @@ -765,7 +763,6 @@ def forward(self, hidden_states): self.router.topk = self.router.num_experts output = super().forward(hidden_states) self.router.topk = original_top_k - return output return super().forward(hidden_states) # TODO double check if MOE forward will be implemented in MoELayer or TransformerLayer @@ -782,6 +779,4 @@ def _forward_mlp_moe_preprocess(self, hidden_states): self.mlp.router.topk = self.mlp.router.num_experts output = super()._forward_mlp_moe_preprocess(hidden_states) self.mlp.router.topk = original_top_k - return output - return super()._forward_mlp_moe_preprocess(hidden_states) \ No newline at end of file From f26bf3cc5e1267f15196a9c29def072d04668d6e Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Tue, 20 Jan 2026 12:26:17 -0800 Subject: [PATCH 05/11] small fixes Signed-off-by: jenchen13 --- modelopt/torch/export/unified_export_megatron.py | 6 +++--- modelopt/torch/quantization/model_calib.py | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index c17c2c0bb..feddd7488 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -298,7 +298,7 @@ def save_pretrained( if is_last_stage_main_rank and quantization is not None: # TODO refactor to use mte.quant_utils.get_quant_config # except layer names are different in MCore and HF - hf_quant_config = { + hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, @@ -309,7 +309,7 @@ def save_pretrained( "exclude_modules": ["lm_head"], # TODO update this dynamically }, } - if quantization == "NVFP4": + if quantization == "NVFP4": # update block size hf_quant_config["quantization"]["group_size"] = 16 with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) @@ -763,7 +763,7 @@ def _qkv_slicing( self._state_dict[v_proj_key] = val.detach().clone() def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"): - """KV cache scaling for self attention module.""" + """KV cache scaling for CoreAttention module.""" k_scale_key = prefix + k_scale_name v_scale_key = prefix + v_scale_name if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 73293d4b1..595034a8a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -81,6 +81,11 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis forward_loop(model) finish_stats_collection(model) + # Step 1: Sync amax across local experts in a SequentialMLP + for name, module in model.named_modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + if not distributed_sync: return @@ -95,11 +100,7 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create 
sync_bias_across_distributed_group - # Step 1: Sync amax across local experts in a SequentialMLP - for name, module in model.named_modules(): - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - + # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): From 2cc4acc26b74a5fa6dee0daef23e5019f7a07bf8 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 22 Jan 2026 09:39:40 -0800 Subject: [PATCH 06/11] mtp import fixes Signed-off-by: jenchen13 --- .../torch/export/plugins/mcore_nemotron.py | 10 +- .../torch/export/plugins/megatron_importer.py | 333 ++++++++++-------- modelopt/torch/export/quant_utils.py | 31 +- .../torch/export/unified_export_megatron.py | 17 +- 4 files changed, 217 insertions(+), 174 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index f857230ae..07f4656bf 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -86,12 +86,10 @@ "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), # MTP - #"enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), - #"hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), - #"eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), - #"layer_norm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), - #"norm": NameRemapping("mtp.layers.{}.norm", REPLICATE) - # "transformer_layer": NameRemapping("mtp.layers.{}.mixer", REPLICATE), + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 0af79eb36..a58612b82 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -150,7 +150,14 @@ def _name_remapping( mapping={}, parallel_config: ParallelConfig | None = None, dtype: torch.dtype | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + print(f"name_remapping: {prefix}, mapping: {mapping}") if dtype is None: dtype = self.dtype if isinstance(module, torch.Tensor): @@ -262,7 +269,13 @@ def _qkv_merging( k_proj_name="k_proj", v_proj_name="v_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") config = module.config hidden_size = config.hidden_size num_query_groups = config.num_query_groups @@ -469,9 +482,111 @@ def _unpack_name_remapping_gpt_oss( linear_module.load_state_dict(state_dict) + def _import_mamba_layer(self, layer, layer_id, layer_pbar): + layer_pbar.set_description("Importing Mamba layer") + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) + + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, 
layer_id) + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) + + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) + + attention = layer.self_attention + if not isinstance(attention, IdentityOp): + if "MLASelfAttention" in str(type(attention)): + if hasattr(attention, "linear_q_proj"): + layer_pbar.set_description("Importing MLA (without q LoRA)") + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing MLA (with q LoRA)") + self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing GQA/MHA") + if attention.q_layernorm is not None and not isinstance( + attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + if getattr(attention.core_attention, "softmax_offset", None) is not None: + self.rules["softmax_offset"]( + attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp + ) + + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) + + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + layer_pbar.set_description("Importing MoE") + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp + ) + if ( + hasattr(layer.mlp, "shared_experts") + and layer.mlp.shared_experts is not None + ): + layer_pbar.set_description("Importing MoE shared experts") + fc1 = layer.mlp.shared_experts.linear_fc1 + fc2 = layer.mlp.shared_experts.linear_fc2 + self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp) + self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp) + if not self.rules.get("use_packed_local_experts", False): + for local_expert_id, expert in tqdm( + enumerate(layer.mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + expert_id = layer.mlp.local_expert_indices[local_expert_id] + fc1 = expert.linear_fc1 + fc2 = expert.linear_fc2 + self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id, is_mtp=is_mtp) + self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id, is_mtp=is_mtp) + # We only support either EP or ETP for now + elif get_expert_tensor_parallel_world_size() > 1: + # ETP supports for packed MoE + # ETP is not supported for gpt-oss model + if self.arch in ["GptOssForCausalLM"]: + raise ValueError("ETP is not supported for 
gpt-oss model") + self.rules["local_experts.linear_fc1_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + # EP supports for packed MoE + self.rules["local_experts.linear_fc1_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + layer_pbar.set_description("Importing MLP") + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) + + def _import_state_dict(self): model = self.model - + print(model, flush=True) layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -481,108 +596,13 @@ def _import_state_dict(self): # Decoder layers for layer in layer_pbar: + print(f"Importing layer {layer.layer_number}", flush=True) layer_id = layer.layer_number - 1 if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) - + self._import_mamba_layer(layer, layer_id, layer_pbar) elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - attention = layer.self_attention - if not isinstance(attention, IdentityOp): - if "MLASelfAttention" in str(type(attention)): - if hasattr(attention, "linear_q_proj"): - layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) - else: - layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - else: - layer_pbar.set_description("Importing GQA/MHA") - if attention.q_layernorm is not None and not isinstance( - attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id) - self.rules["k_layernorm"](attention.k_layernorm, layer_id) - self.rules["linear_qkv"](attention.linear_qkv, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - if getattr(attention.core_attention, "softmax_offset", None) is not None: - self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description("Importing MoE") - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if ( - 
hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - layer_pbar.set_description("Importing MoE shared experts") - fc1 = layer.mlp.shared_experts.linear_fc1 - fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id) - self.rules["shared_experts.linear_fc2"](fc2, layer_id) - if not self.rules.get("use_packed_local_experts", False): - for local_expert_id, expert in tqdm( - enumerate(layer.mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - expert_id = layer.mlp.local_expert_indices[local_expert_id] - fc1 = expert.linear_fc1 - fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) - # We only support either EP or ETP for now - elif get_expert_tensor_parallel_world_size() > 1: - # ETP supports for packed MoE - # ETP is not supported for gpt-oss model - if self.arch in ["GptOssForCausalLM"]: - raise ValueError("ETP is not supported for gpt-oss model") - self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - # EP supports for packed MoE - self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + self._import_transformer_layer(layer, layer_id, layer_pbar) if self.verbose: print( @@ -591,71 +611,92 @@ def _import_state_dict(self): ), flush=True, ) + break # TODO: remove this # Final layernorm if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: self.rules["final_layernorm"](model.decoder.final_layernorm) - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: self.rules["final_norm"](model.decoder.final_norm) # Output layer if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) + # MTP if hasattr(model, "mtp"): + print("Importing MTP", flush=True) # MTP is the last layer in DeepSeek V3/R1 - layer_id += 1 - for mtp in model.mtp: - self.rules["mtp.fc"](mtp.fc, layer_id) + if len(model.mtp.layers) == 1: # Repeated MTP + layer_id = 0 # reset layer_id for repeated MTP + mtp = model.mtp.layers[0] + + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) self.rules["mtp.hnorm"](mtp.hnorm, layer_id) - self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) - if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): - self.rules["mtp.linear_q_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + + mtp_model_layers = mtp.mtp_model_layer.layers + for mtp_model_layer in mtp_model_layers: + if isinstance(mtp_model_layer, MambaLayer): + self._import_mamba_layer(mtp_model_layer, layer_id, layer_pbar) + elif isinstance(mtp_model_layer, TransformerLayer): + self._import_transformer_layer(mtp_model_layer, layer_id, layer_pbar, is_mtp=True) + else: + raise ValueError(f"Unsupported layer type during MTP import: {type(mtp_model_layer)}") + + layer_id += 1 + else: # non-repeated MTP + + for mtp in model.mtp.layers: + 
self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) + self.rules["mtp.enorm"](mtp.enorm, layer_id) + self.rules["mtp.hnorm"](mtp.hnorm, layer_id) + self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) + if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): + self.rules["mtp.linear_q_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + ) + else: + self.rules["mtp.linear_q_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + ) + self.rules["mtp.linear_q_layernorm"]( + mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + ) + self.rules["mtp.linear_q_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + ) + self.rules["mtp.linear_kv_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id ) - else: - self.rules["mtp.linear_q_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + self.rules["mtp.linear_kv_layernorm"]( + mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id ) - self.rules["mtp.linear_q_layernorm"]( - mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + self.rules["mtp.linear_kv_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id ) - self.rules["mtp.linear_q_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + self.rules["mtp.linear_proj"]( + mtp.decoder.layers[0].self_attention.linear_proj, layer_id ) - self.rules["mtp.linear_kv_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id - ) - self.rules["mtp.linear_kv_layernorm"]( - mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id - ) - self.rules["mtp.linear_kv_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id - ) - self.rules["mtp.linear_proj"]( - mtp.decoder.layers[0].self_attention.linear_proj, layer_id - ) - self.rules["mtp.pre_mlp_layernorm"]( - mtp.decoder.layers[0].pre_mlp_layernorm, layer_id - ) - self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) - self.rules["mtp.shared_experts.linear_fc1"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["mtp.shared_experts.linear_fc2"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in tqdm( - enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - self.rules["mtp.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id + self.rules["mtp.pre_mlp_layernorm"]( + mtp.decoder.layers[0].pre_mlp_layernorm, layer_id ) - self.rules["mtp.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id + self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) + self.rules["mtp.shared_experts.linear_fc1"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id ) + self.rules["mtp.shared_experts.linear_fc2"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id + ) + for expert_id, expert in tqdm( + enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + self.rules["mtp.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["mtp.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index dbf78db0f..8f43f052c 100755 --- 
a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -380,12 +380,29 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: for module in modules: # Case where the module has both k_bmm_quantizer and v_bmm_quantizer - for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer"): + # Still check for output quantizer for the unified_megatron_export path + for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer", "output_quantizer"): quantizer_attr = getattr(module, quantizer, None) if quantizer_attr and quantizer_attr.is_enabled: num_bits_list.append(quantizer_attr.num_bits) is_affine &= hasattr(quantizer_attr, "_bias_value") + return _compute_kv_cache_dtype(num_bits_list) + +def _compute_kv_cache_dtype(num_bits_list: list[int]) -> str | None: + """Returns the kv_cache dtype. + + If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8, + otherwise returns None. + + Args: + modules: The module or list of modules to inspect. + + Returns: + The kv_cache dtype. + """ + is_affine = True + if (4, 3) in num_bits_list: return KV_CACHE_FP8 elif 8 in num_bits_list: @@ -909,18 +926,6 @@ def postprocess_state_dict( value = value.float() / maxbound - # Warn if scale exceeds threshold - if quantization == KV_CACHE_FP8 and value.item() > 0.5: - logger.warning( - "Large KV activations detected. Quantized KV cache may lead to higher accuracy drop. " - "Setting KV cache scaling factor to at least 1." - ) - - # Ensure scale is at least 1 for KV_CACHE_FP8 - # We export real value for KV_CACHE_NVFP4 - if quantization == KV_CACHE_FP8: - value.clamp_(min=1.0) - post_state_dict[prefix + new_suffix] = value break diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index feddd7488..44cd3e955 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -51,8 +51,8 @@ from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, - get_kv_cache_dtype, get_kv_cache_scaling_factor, + get_kv_cache_dtype, get_quant_config, get_quantization_format, get_scaling_factor, @@ -256,12 +256,6 @@ def save_pretrained( elif quantization_format == QUANTIZATION_NVFP4: quantization = "NVFP4" - kv_cache_quantization = None - kv_cache_dtype = get_kv_cache_dtype(self.model) - print("kv_cache_dtype: ", kv_cache_dtype) - if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): - # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM - kv_cache_quantization = kv_cache_dtype # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. 
if is_last_stage_main_rank: @@ -305,12 +299,13 @@ def save_pretrained( }, "quantization": { "quant_algo": quantization, - "kv_cache_quant_algo": kv_cache_quantization, "exclude_modules": ["lm_head"], # TODO update this dynamically }, } if quantization == "NVFP4": # update block size hf_quant_config["quantization"]["group_size"] = 16 + if hasattr(self, "kv_cache_dtype"): + hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) @@ -731,7 +726,7 @@ def _qkv_slicing( quantized_weight = to_quantized_weight( weight, scale, - qformat, + qformat, weight_scale_2, block_size, ) @@ -772,6 +767,10 @@ def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scal self._state_dict[k_scale_key] = kv_scales[0] self._state_dict[v_scale_key] = kv_scales[1] + kv_cache_dtype = get_kv_cache_dtype(module) + if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): + # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM + self.kv_cache_dtype = kv_cache_dtype def _pack_name_remapping(self, module, prefix, layer_type=None): """Pack name remapping into one tensor.""" From 3d0a31a97b4e117fea8a52ee04dc980b5256e7fd Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 22 Jan 2026 13:29:08 -0800 Subject: [PATCH 07/11] enable TELinear quant Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_nemotron.py | 6 ++++++ modelopt/torch/quantization/plugins/megatron.py | 7 +++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 07f4656bf..d2cb2858a 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -131,4 +131,10 @@ # Latent MoE "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."), + # MTP + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm."), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."), + } diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 4414ec5e2..521b73ea1 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -50,6 +50,7 @@ try: from megatron.core.extensions.transformer_engine import ( + TELinear, TEColumnParallelGroupedLinear, TEColumnParallelLinear, TEDotProductAttention, @@ -603,8 +604,6 @@ def sync_moe_local_experts_amax(self): for name, module in expert.named_modules(): if isinstance(module, TensorQuantizer) and module.amax is not None: module.amax = amax_dict[name].detach().clone().to(module.amax.device) - #if isinstance(module, TensorQuantizer) and module.amax is None: - # print(f"MISSING AMAX AFTER SYNC in expert rank {dist.get_rank()}: {name}", flush=True) def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Override the default to enable singleton_local_shards. 
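The MTP mappings added to the export table above follow the same convention as the rest of the table: the "{}" placeholder in each NameRemapping prefix is filled with the layer index at export time, and the module's tensors are written under the resulting prefix. A tiny illustration of the expected key expansion; the helper below is illustrative only, not the exporter's actual code path.

def expand_hf_key(prefix_template: str, layer_id: int, tensor_name: str = "weight") -> str:
    """Format an export-mapping prefix such as 'mtp.layers.{}.eh_proj.' into a checkpoint key."""
    return prefix_template.format(layer_id) + tensor_name

# e.g. the "mtp.eh_proj" mapping for the first MTP layer:
print(expand_hf_key("mtp.layers.{}.eh_proj.", 0))             # mtp.layers.0.eh_proj.weight
# and a decoder-layer mapping for layer 3:
print(expand_hf_key("backbone.layers.{}.mixer.o_proj.", 3))   # backbone.layers.3.mixer.o_proj.weight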
@@ -624,6 +623,10 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): if HAS_TE: + @QuantModuleRegistry.register({TELinear: "te_mcore_Linear"}) + class _QuantTEMCoreLinear(_QuantTELinear): + pass + @QuantModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"}) class _QuantTEMCoreRowParallelLinear(_QuantTELinear, _MegatronRowParallelLinear): pass From 334b0b9417792f74591dd97118fb90e4d026fa90 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 23 Jan 2026 15:09:05 -0800 Subject: [PATCH 08/11] import grouped mlp in mtp Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/mcore_custom.py | 9 ++ .../torch/export/plugins/mcore_nemotron.py | 15 +- .../torch/export/plugins/megatron_importer.py | 149 ++++++++++++------ .../torch/export/unified_export_megatron.py | 2 +- modelopt/torch/quantization/model_calib.py | 3 +- 5 files changed, 125 insertions(+), 53 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 25a2cd0cb..77a3208bb 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -102,7 +102,16 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) +class GroupedMLPMerging(CustomModuleMapping): + """A custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + super().__init__( + func_name="grouped_mlp_merging", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) class GatedMLPMerging(CustomModuleMapping): """A custom module mapping that merges gate_proj and up_proj.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index d2cb2858a..385418a37 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -26,6 +26,7 @@ NameRemapping, QKVMerging, QKVSlicing, + GroupedMLPMerging, SelfAttentionScaling, ) @@ -85,12 +86,14 @@ # Latent MoE "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), - # MTP - "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE), - "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE), - "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE), - "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE), - + # Repeated MTP module + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), + # Grouped local experts in MTP + "experts.linear_fc1": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), + "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index a58612b82..924001db9 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ 
b/modelopt/torch/export/plugins/megatron_importer.py @@ -19,7 +19,7 @@ from pathlib import Path import torch -import torch.distributed +import torch.distributed as dist from huggingface_hub import snapshot_download from tqdm import tqdm @@ -40,6 +40,7 @@ with import_plugin("megatron"): from megatron.core.parallel_state import ( get_expert_tensor_parallel_world_size, + get_expert_tensor_parallel_rank, get_tensor_model_parallel_world_size, ) from megatron.core.ssm.mamba_layer import MambaLayer @@ -94,12 +95,12 @@ def __init__( if workspace_dir is None: workspace_dir = tempfile.gettempdir() pretrained_model_path = workspace_dir + "/" + pretrained_model_name_or_path - if torch.distributed.get_rank() == 0: + if dist.get_rank() == 0: snapshot_download( repo_id=pretrained_model_name_or_path, local_dir=pretrained_model_path, ) - torch.distributed.barrier() + dist.barrier() self.arch = self._hf_config.architectures[0] self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] @@ -108,7 +109,7 @@ def __init__( self.dtype = dtype self.dequantize = dequantize self.verbose = verbose - self.disable_tqdm = torch.distributed.get_rank() > 0 or verbose + self.disable_tqdm = dist.get_rank() > 0 or verbose def _populate_rule_book(self): """The rule book maps each state_dict key to a Callable.""" @@ -119,6 +120,7 @@ def _custom_mapping_to_lambda(mapping): "name_remapping": self._name_remapping, "qkv_merging": self._qkv_merging, "gated_mlp_merging": self._gated_mlp_merging, + "grouped_mlp_merging": self._grouped_mlp_merging, "unpack_name_remapping": self._unpack_name_remapping, "unpack_name_remapping_gpt_oss": self._unpack_name_remapping_gpt_oss, } @@ -157,7 +159,6 @@ def _name_remapping( prefix = prefix.replace("backbone", "mtp") else: prefix = prefix.replace("model", "mtp") - print(f"name_remapping: {prefix}, mapping: {mapping}") if dtype is None: dtype = self.dtype if isinstance(module, torch.Tensor): @@ -261,6 +262,37 @@ def _gated_mlp_merging( module.load_state_dict(state_dict) + def _grouped_mlp_merging( + self, + module, + prefix, + parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, + init_expert_id: int = 0, + num_local_experts: int = 1, + ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + + state_dict = module.state_dict() + weight = state_dict.get("weight", None) + print(f"mcore weight.shape: {weight.shape}") + weight_scale = state_dict.get("weight_quantizer._scale", None) + + all_experts = [] + for expert_id in range(init_expert_id, init_expert_id + num_local_experts): + tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") + print(f"HF weight.shape: {tensor.shape}") + all_experts.append(tensor) + all_experts = torch.cat(all_experts, dim=0) + print(f"all_experts.shape: {all_experts.shape}") + state_dict["weight"] = all_experts + + module.load_state_dict(state_dict) + def _qkv_merging( self, module, @@ -302,8 +334,9 @@ def _qkv_merging( state_dict = {} - weight = module.state_dict().get("weight", None) - weight_scale = module.state_dict().get("weight_quantizer._scale", None) + module_state_dict = module.state_dict() + weight = module_state_dict.get("weight", None) + weight_scale = module_state_dict.get("weight_quantizer._scale", None) if weight is None: raise ValueError(f"{module!s} does not contain weight!") @@ -357,7 +390,7 @@ def _qkv_merging( state_dict["weight"] = tensor.reshape(-1, hidden_size) # Handle bias merging - bias = 
module.state_dict().get("bias", None) + bias = module_state_dict.get("bias", None) if bias is not None: q_bias = self._get_safetensor( prefix + q_proj_name + ".bias", parallel_config=parallel_config @@ -384,6 +417,11 @@ def _qkv_merging( state_dict["bias"] = bias_tensor.reshape(-1) + layer_norm_weight = module_state_dict.get("layer_norm_weight", None) + if layer_norm_weight is not None: + state_dict["layer_norm_weight"] = layer_norm_weight + state_dict["_extra_state"] = None # for TE modules require _extra_state key + module.load_state_dict(state_dict) def _unpack_name_remapping( @@ -495,47 +533,47 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar): self.rules["in_proj"](layer.mixer.in_proj, layer_id) self.rules["out_proj"](layer.mixer.out_proj, layer_id) - def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) + self.rules["input_layernorm"](layer.input_layernorm, layer_id) attention = layer.self_attention if not isinstance(attention, IdentityOp): if "MLASelfAttention" in str(type(attention)): if hasattr(attention, "linear_q_proj"): layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) else: layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp) - self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) + self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) + self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) + self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) + self.rules["linear_proj"](attention.linear_proj, layer_id) else: layer_pbar.set_description("Importing GQA/MHA") if attention.q_layernorm is not None and not isinstance( attention.q_layernorm, (IdentityOp, L2Norm) ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) - self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp) + self.rules["q_layernorm"](attention.q_layernorm, layer_id) + self.rules["k_layernorm"](attention.k_layernorm, layer_id) self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) if getattr(attention.core_attention, "softmax_offset", None) is not None: self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp + attention.core_attention.softmax_offset, layer_id ) if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - 
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): layer_pbar.set_description("Importing MoE") self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp + layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) if ( hasattr(layer.mlp, "shared_experts") @@ -544,20 +582,41 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): layer_pbar.set_description("Importing MoE shared experts") fc1 = layer.mlp.shared_experts.linear_fc1 fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp) - self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp) - if not self.rules.get("use_packed_local_experts", False): - for local_expert_id, expert in tqdm( - enumerate(layer.mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - expert_id = layer.mlp.local_expert_indices[local_expert_id] - fc1 = expert.linear_fc1 - fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id, is_mtp=is_mtp) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id, is_mtp=is_mtp) + self.rules["shared_experts.linear_fc1"](fc1, layer_id) + self.rules["shared_experts.linear_fc2"](fc2, layer_id) + if not self.rules.get("use_packed_local_experts", False): # Import local experts + experts = layer.mlp.experts + if hasattr(experts, "local_experts"): + for local_expert_id, expert in tqdm( + enumerate(layer.mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + expert_id = layer.mlp.local_expert_indices[local_expert_id] + fc1 = expert.linear_fc1 + fc2 = expert.linear_fc2 + self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) + self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) + else: # Slice TEGroupedMLP + layer_pbar.set_description("Importing MoE grouped local experts") + num_local_experts = experts.num_local_experts + num_global_experts = experts.config.num_moe_experts + print(f"num_local_experts: {num_local_experts}") + print(f"num_global_experts: {num_global_experts}") + + if parallel_config is not None: + etp_size = get_expert_tensor_parallel_world_size() + # etp_rank = get_expert_tensor_parallel_rank() # this gives group rank + etp_rank = dist.get_rank() + print(f"etp_size: {etp_size}") + print(f"etp_rank: {etp_rank}") + assert num_local_experts * etp_size == num_global_experts + init_index = etp_rank * num_local_experts + + self.rules["experts.linear_fc1"](experts.linear_fc1, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) + self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts ) + # We only support either EP or ETP for now elif get_expert_tensor_parallel_world_size() > 1: # ETP supports for packed MoE @@ -565,28 +624,28 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp=False): if self.arch in ["GptOssForCausalLM"]: raise ValueError("ETP is not supported for gpt-oss model") self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id, 
is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) else: # EP supports for packed MoE self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + layer.mlp.experts.local_experts, layer_id ) else: layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) def _import_state_dict(self): model = self.model - print(model, flush=True) + # print(model, flush=True) layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -607,7 +666,7 @@ def _import_state_dict(self): if self.verbose: print( "{:3}/{:3} completes importing layer {:3}.".format( - torch.distributed.get_rank(), torch.distributed.get_world_size(), layer_id + dist.get_rank(), dist.get_world_size(), layer_id ), flush=True, ) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 44cd3e955..6af06cc8a 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -726,7 +726,7 @@ def _qkv_slicing( quantized_weight = to_quantized_weight( weight, scale, - qformat, + qformat, weight_scale_2, block_size, ) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 595034a8a..015a9a2de 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -187,7 +187,6 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) - # KV Cache Quantization if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) @@ -280,6 +279,8 @@ def quant_func(x, amax, quantizer=module): # Step 4: Compute optimal amax and load it finish_stats_collection(model, method="mse") + # TODO: Sync amax across distributed processes + def enable_stats_collection(model: nn.Module): """Enable stats collection for all quantizers in the model.""" From d4bfbfe23f9afa7125ef9697d3791cbaf9b2e635 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 26 Jan 2026 13:17:54 -0800 Subject: [PATCH 09/11] fix grouped mlp import Signed-off-by: jenchen13 --- .../torch/export/plugins/mcore_nemotron.py | 4 ++-- .../torch/export/plugins/megatron_importer.py | 24 ++++++++----------- .../torch/export/unified_export_megatron.py | 6 ++--- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 385418a37..12408480b 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -92,8 +92,8 @@ "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), # Grouped local experts in MTP - "experts.linear_fc1": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), - "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), + "experts.linear_fc1": 
GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), + "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 924001db9..62e611305 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -277,19 +277,14 @@ def _grouped_mlp_merging( else: prefix = prefix.replace("model", "mtp") - state_dict = module.state_dict() - weight = state_dict.get("weight", None) - print(f"mcore weight.shape: {weight.shape}") - weight_scale = state_dict.get("weight_quantizer._scale", None) + state_dict = module.state_dict() + # TODO handle weight_scale + #weight_scale = state_dict.get("weight_quantizer._scale", None) - all_experts = [] + assert module.num_gemms == num_local_experts, "num_gemms must be equal to num_local_experts in TEGroupedMLP" for expert_id in range(init_expert_id, init_expert_id + num_local_experts): tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") - print(f"HF weight.shape: {tensor.shape}") - all_experts.append(tensor) - all_experts = torch.cat(all_experts, dim=0) - print(f"all_experts.shape: {all_experts.shape}") - state_dict["weight"] = all_experts + state_dict[f"weight{expert_id}"] = tensor module.load_state_dict(state_dict) @@ -602,9 +597,9 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = layer_pbar.set_description("Importing MoE grouped local experts") num_local_experts = experts.num_local_experts num_global_experts = experts.config.num_moe_experts - print(f"num_local_experts: {num_local_experts}") - print(f"num_global_experts: {num_global_experts}") + assert num_local_experts == num_global_experts, "num_local_experts must be equal to num_global_experts during MoE import" + ''' if parallel_config is not None: etp_size = get_expert_tensor_parallel_world_size() # etp_rank = get_expert_tensor_parallel_rank() # this gives group rank @@ -613,9 +608,11 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = print(f"etp_rank: {etp_rank}") assert num_local_experts * etp_size == num_global_experts init_index = etp_rank * num_local_experts + ''' + init_index = 0 self.rules["experts.linear_fc1"](experts.linear_fc1, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) - self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts ) + self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) # We only support either EP or ETP for now elif get_expert_tensor_parallel_world_size() > 1: @@ -670,7 +667,6 @@ def _import_state_dict(self): ), flush=True, ) - break # TODO: remove this # Final layernorm if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 6af06cc8a..dd39f6cdf 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -487,16 +487,16 @@ def _get_quantized_state( qformat: str = self._get_quantization_format(module) block_size = get_weight_block_size(module) - if hasattr(module, "weight") and module.weight is not None: + if hasattr(module, "weight") and module.weight is not None and 
module.weight.numel() > 0: weight = module.weight.to(dtype).cpu() name_to_value["weight"] = weight else: return name_to_value, qformat, block_size - if hasattr(module, "bias") and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None and module.bias.numel() > 0: name_to_value["bias"] = module.bias.to(dtype).cpu() - if hasattr(module, "expert_bias") and module.expert_bias is not None: + if hasattr(module, "expert_bias") and module.expert_bias is not None and module.expert_bias.numel() > 0: name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu() if qformat == QUANTIZATION_NONE: From fd1b9f4d083086785b61ee3985cfda2e681ef3b9 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 26 Jan 2026 19:16:43 -0800 Subject: [PATCH 10/11] fix config.json Signed-off-by: jenchen13 --- .../torch/export/plugins/megatron_importer.py | 6 ++++ .../torch/export/unified_export_megatron.py | 33 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 62e611305..07e8dd533 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -567,9 +567,15 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): layer_pbar.set_description("Importing MoE") + print(f"moe_router_dtype: {self.moe_router_dtype}") self.rules["router"]( layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) + if ( hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index dd39f6cdf..6ac5ae85a 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -47,7 +47,7 @@ QUANTIZATION_NVFP4, ) from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, save_safetensors +from .plugins.mcore_custom import CustomModuleMapping, save_safetensors, get_safetensor from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, @@ -129,6 +129,7 @@ def __init__( self.moe_router_dtype = torch.float32 elif moe_router_dtype == "fp64": self.moe_router_dtype = torch.float64 + print(f"moe_router_dtype: {self.moe_router_dtype}") # If multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) @@ -309,6 +310,15 @@ def save_pretrained( with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) + # Newer versions of VLLM expect config.json with hf_quant_config + config_file = save_directory + "/config.json" + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + config["quantization"] = hf_quant_config["quantization"] + with open(config_file, "w") as f: + json.dump(config, f, indent=4) + if ( is_first_stage_main_rank and self.is_multimodal @@ -1185,6 +1195,27 @@ def _get_state_dict(self): 
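# Quick post-export check (illustrative usage, not part of the exporter): after
# save_pretrained, config.json should carry the same "quantization" section that is
# written to hf_quant_config.json, which is what the comment above says newer vLLM
# versions expect. Paths and the helper name are assumptions for this sketch.
import json

def check_quantization_in_config(export_dir: str) -> None:
    with open(f"{export_dir}/config.json") as f:
        config = json.load(f)
    with open(f"{export_dir}/hf_quant_config.json") as f:
        hf_quant_config = json.load(f)
    assert config.get("quantization") == hf_quant_config["quantization"]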
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) else: raise ValueError("Only TransformerLayer or MambaLayer are supported.") + + # MTP module + # Hacky version for now: copy MTP weights from pretrained model + if os.path.isdir(self._hf_pretrained_model_name): + safetensors_index_file = Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + else: + safetensors_index_file = hf_hub_download( + repo_id=self._hf_pretrained_model_name, + filename="model.safetensors.index.json") + + print(f"safetensors_index_file: {safetensors_index_file}") + if safetensors_index_file and os.path.exists(safetensors_index_file): + with open(safetensors_index_file, "r") as f: + safetensors_index = json.load(f) + model_dir = Path(safetensors_index_file).parent + for key in safetensors_index["weight_map"]: + if "mtp" in key: + self._state_dict[key] = get_safetensor(model_dir, key) + + # TODO implement actual MTP export + def export_mcore_gpt_to_hf( From 348055dc4492a1130308437c601314759c45b605 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Tue, 27 Jan 2026 07:43:39 -0800 Subject: [PATCH 11/11] fix import router dtype bug Signed-off-by: jenchen13 --- modelopt/torch/export/plugins/megatron_importer.py | 5 ++--- modelopt/torch/export/unified_export_megatron.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 07e8dd533..d2a35bde0 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -191,7 +191,7 @@ def _name_remapping( tensor = expanded_tensor state_dict["weight"] = tensor.view(dtype=weight.dtype).to(device=weight.device) else: - state_dict["weight"] = tensor.to(dtype=self.dtype).to(device=weight.device) + state_dict["weight"] = tensor.to(dtype=dtype).to(device=weight.device) # Handle the rest of the state_dict. for key, val in module.state_dict().items(): @@ -566,8 +566,7 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description("Importing MoE") - print(f"moe_router_dtype: {self.moe_router_dtype}") + layer_pbar.set_description(f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}") self.rules["router"]( layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 6ac5ae85a..0c0cdb20e 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -129,7 +129,6 @@ def __init__( self.moe_router_dtype = torch.float32 elif moe_router_dtype == "fp64": self.moe_router_dtype = torch.float64 - print(f"moe_router_dtype: {self.moe_router_dtype}") # If multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) @@ -1224,7 +1223,7 @@ def export_mcore_gpt_to_hf( export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), - moe_router_dtype: torch.dtype | None = None, + moe_router_dtype: str | None = None, ): """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.