diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 23804b322..77a3208bb 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -102,7 +102,16 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) +class GroupedMLPMerging(CustomModuleMapping): + """A custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + super().__init__( + func_name="grouped_mlp_merging", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) class GatedMLPMerging(CustomModuleMapping): """A custom module mapping that merges gate_proj and up_proj.""" @@ -126,6 +135,15 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] func_kwargs=func_kwargs, ) +class SelfAttentionScaling(CustomModuleMapping): + """A custom module mapping that scales self attention.""" + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that scales self attention.""" + super().__init__( + func_name="self_attention_scaling", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) class GatedMLPSlicing(CustomModuleMapping): """A custom module mapping that slices gate_proj and up_proj.""" diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 5fdb8ba1b..12408480b 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -26,6 +26,8 @@ NameRemapping, QKVMerging, QKVSlicing, + GroupedMLPMerging, + SelfAttentionScaling, ) # Example on adding a new CausalLM. 
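Note on the new mapping classes: `GroupedMLPMerging` and `SelfAttentionScaling` only tag a rule with a `func_name`; the importer/exporter rule books bind that name to a handler via `_custom_mapping_to_lambda` (see the hunks further down). A minimal sketch of that binding, with a simplified stand-in for `CustomModuleMapping` and an illustrative `method_map` (not the exact lambda used in the importer):

```python
from typing import Any


class CustomModuleMapping:
    """Simplified stand-in for the mapping class defined in mcore_custom.py."""

    def __init__(
        self, func_name: str, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] | None = None
    ):
        self.func_name = func_name
        self.target_name_or_prefix = target_name_or_prefix
        self.func_kwargs = func_kwargs or {}


def custom_mapping_to_rule(mapping: CustomModuleMapping, method_map: dict):
    """Bind a mapping (e.g. GroupedMLPMerging, SelfAttentionScaling) to its handler."""
    func = method_map[mapping.func_name]  # e.g. "grouped_mlp_merging" -> _grouped_mlp_merging
    prefix = mapping.target_name_or_prefix  # e.g. "mtp.layers.{}.mixer.experts.{{}}.up_proj"
    base_kwargs = mapping.func_kwargs  # e.g. COL_ETP | {"is_mtp": True}

    def rule(module, *layer_ids, **extra_kwargs):
        # Layer (and optionally expert) indices fill the prefix template before the handler runs.
        return func(module, prefix.format(*layer_ids), **base_kwargs, **extra_kwargs)

    return rule
```

The pre-filled `func_kwargs` is how flags such as `{"is_mtp": True}` and the TP/ETP sharding hints reach handlers such as `_grouped_mlp_merging`.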
@@ -81,8 +83,21 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), + # Repeated MTP module + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), + # Grouped local experts in MTP + "experts.linear_fc1": GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}), + "experts.linear_fc2": GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}), + } +# TODO ADD MTP export nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), @@ -101,6 +116,7 @@ "input_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."), "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."), + "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), @@ -115,4 +131,13 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj." ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."), + # MTP + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm."), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."), + } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 0af79eb36..d2a35bde0 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -19,7 +19,7 @@ from pathlib import Path import torch -import torch.distributed +import torch.distributed as dist from huggingface_hub import snapshot_download from tqdm import tqdm @@ -40,6 +40,7 @@ with import_plugin("megatron"): from megatron.core.parallel_state import ( get_expert_tensor_parallel_world_size, + get_expert_tensor_parallel_rank, get_tensor_model_parallel_world_size, ) from megatron.core.ssm.mamba_layer import MambaLayer @@ -94,12 +95,12 @@ def __init__( if workspace_dir is None: workspace_dir = tempfile.gettempdir() pretrained_model_path = workspace_dir + "/" + pretrained_model_name_or_path - if torch.distributed.get_rank() == 0: + if dist.get_rank() == 0: snapshot_download( repo_id=pretrained_model_name_or_path, local_dir=pretrained_model_path, ) - torch.distributed.barrier() + dist.barrier() self.arch = self._hf_config.architectures[0] self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] @@ -108,7 +109,7 @@ def __init__( self.dtype = dtype self.dequantize = dequantize self.verbose = verbose - self.disable_tqdm = torch.distributed.get_rank() > 0 or verbose + self.disable_tqdm = dist.get_rank() > 0 or verbose def _populate_rule_book(self): """The rule book maps each state_dict key to 
a Callable.""" @@ -119,6 +120,7 @@ def _custom_mapping_to_lambda(mapping): "name_remapping": self._name_remapping, "qkv_merging": self._qkv_merging, "gated_mlp_merging": self._gated_mlp_merging, + "grouped_mlp_merging": self._grouped_mlp_merging, "unpack_name_remapping": self._unpack_name_remapping, "unpack_name_remapping_gpt_oss": self._unpack_name_remapping_gpt_oss, } @@ -150,7 +152,13 @@ def _name_remapping( mapping={}, parallel_config: ParallelConfig | None = None, dtype: torch.dtype | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") if dtype is None: dtype = self.dtype if isinstance(module, torch.Tensor): @@ -183,7 +191,7 @@ def _name_remapping( tensor = expanded_tensor state_dict["weight"] = tensor.view(dtype=weight.dtype).to(device=weight.device) else: - state_dict["weight"] = tensor.to(dtype=self.dtype).to(device=weight.device) + state_dict["weight"] = tensor.to(dtype=dtype).to(device=weight.device) # Handle the rest of the state_dict. for key, val in module.state_dict().items(): @@ -254,6 +262,32 @@ def _gated_mlp_merging( module.load_state_dict(state_dict) + def _grouped_mlp_merging( + self, + module, + prefix, + parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, + init_expert_id: int = 0, + num_local_experts: int = 1, + ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + + state_dict = module.state_dict() + # TODO handle weight_scale + #weight_scale = state_dict.get("weight_quantizer._scale", None) + + assert module.num_gemms == num_local_experts, "num_gemms must be equal to num_local_experts in TEGroupedMLP" + for expert_id in range(init_expert_id, init_expert_id + num_local_experts): + tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") + state_dict[f"weight{expert_id}"] = tensor + + module.load_state_dict(state_dict) + def _qkv_merging( self, module, @@ -262,7 +296,13 @@ def _qkv_merging( k_proj_name="k_proj", v_proj_name="v_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") config = module.config hidden_size = config.hidden_size num_query_groups = config.num_query_groups @@ -289,8 +329,9 @@ def _qkv_merging( state_dict = {} - weight = module.state_dict().get("weight", None) - weight_scale = module.state_dict().get("weight_quantizer._scale", None) + module_state_dict = module.state_dict() + weight = module_state_dict.get("weight", None) + weight_scale = module_state_dict.get("weight_quantizer._scale", None) if weight is None: raise ValueError(f"{module!s} does not contain weight!") @@ -344,7 +385,7 @@ def _qkv_merging( state_dict["weight"] = tensor.reshape(-1, hidden_size) # Handle bias merging - bias = module.state_dict().get("bias", None) + bias = module_state_dict.get("bias", None) if bias is not None: q_bias = self._get_safetensor( prefix + q_proj_name + ".bias", parallel_config=parallel_config @@ -371,6 +412,11 @@ def _qkv_merging( state_dict["bias"] = bias_tensor.reshape(-1) + layer_norm_weight = module_state_dict.get("layer_norm_weight", None) + if layer_norm_weight is not None: + state_dict["layer_norm_weight"] = layer_norm_weight + state_dict["_extra_state"] = None # for TE modules require _extra_state key + module.load_state_dict(state_dict) def 
_unpack_name_remapping( @@ -469,9 +515,139 @@ def _unpack_name_remapping_gpt_oss( linear_module.load_state_dict(state_dict) + def _import_mamba_layer(self, layer, layer_id, layer_pbar): + layer_pbar.set_description("Importing Mamba layer") + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) + + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) + + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id) + + attention = layer.self_attention + if not isinstance(attention, IdentityOp): + if "MLASelfAttention" in str(type(attention)): + if hasattr(attention, "linear_q_proj"): + layer_pbar.set_description("Importing MLA (without q LoRA)") + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) + else: + layer_pbar.set_description("Importing MLA (with q LoRA)") + self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) + self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) + self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) + self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) + self.rules["linear_proj"](attention.linear_proj, layer_id) + else: + layer_pbar.set_description("Importing GQA/MHA") + if attention.q_layernorm is not None and not isinstance( + attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](attention.q_layernorm, layer_id) + self.rules["k_layernorm"](attention.k_layernorm, layer_id) + self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + if getattr(attention.core_attention, "softmax_offset", None) is not None: + self.rules["softmax_offset"]( + attention.core_attention.softmax_offset, layer_id + ) + + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) + + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + layer_pbar.set_description(f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}") + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype + ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) + + if ( + hasattr(layer.mlp, "shared_experts") + and layer.mlp.shared_experts is not None + ): + layer_pbar.set_description("Importing MoE shared experts") + fc1 = layer.mlp.shared_experts.linear_fc1 + fc2 = layer.mlp.shared_experts.linear_fc2 + self.rules["shared_experts.linear_fc1"](fc1, layer_id) + self.rules["shared_experts.linear_fc2"](fc2, layer_id) + if not self.rules.get("use_packed_local_experts", False): # Import local experts + experts = layer.mlp.experts + 
if hasattr(experts, "local_experts"): + for local_expert_id, expert in tqdm( + enumerate(layer.mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + expert_id = layer.mlp.local_expert_indices[local_expert_id] + fc1 = expert.linear_fc1 + fc2 = expert.linear_fc2 + self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) + self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) + else: # Slice TEGroupedMLP + layer_pbar.set_description("Importing MoE grouped local experts") + num_local_experts = experts.num_local_experts + num_global_experts = experts.config.num_moe_experts + assert num_local_experts == num_global_experts, "num_local_experts must be equal to num_global_experts during MoE import" + + ''' + if parallel_config is not None: + etp_size = get_expert_tensor_parallel_world_size() + # etp_rank = get_expert_tensor_parallel_rank() # this gives group rank + etp_rank = dist.get_rank() + print(f"etp_size: {etp_size}") + print(f"etp_rank: {etp_rank}") + assert num_local_experts * etp_size == num_global_experts + init_index = etp_rank * num_local_experts + ''' + init_index = 0 + + self.rules["experts.linear_fc1"](experts.linear_fc1, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) + self.rules["experts.linear_fc2"](experts.linear_fc2, layer_id, init_expert_id=init_index, num_local_experts=num_local_experts) + + # We only support either EP or ETP for now + elif get_expert_tensor_parallel_world_size() > 1: + # ETP supports for packed MoE + # ETP is not supported for gpt-oss model + if self.arch in ["GptOssForCausalLM"]: + raise ValueError("ETP is not supported for gpt-oss model") + self.rules["local_experts.linear_fc1_etp"]( + layer.mlp.experts.local_experts, layer_id + ) + self.rules["local_experts.linear_fc2_etp"]( + layer.mlp.experts.local_experts, layer_id + ) + else: + # EP supports for packed MoE + self.rules["local_experts.linear_fc1_ep"]( + layer.mlp.experts.local_experts, layer_id + ) + self.rules["local_experts.linear_fc2_ep"]( + layer.mlp.experts.local_experts, layer_id + ) + else: + layer_pbar.set_description("Importing MLP") + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + + def _import_state_dict(self): model = self.model - + # print(model, flush=True) layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -481,113 +657,18 @@ def _import_state_dict(self): # Decoder layers for layer in layer_pbar: + print(f"Importing layer {layer.layer_number}", flush=True) layer_id = layer.layer_number - 1 if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) - + self._import_mamba_layer(layer, layer_id, layer_pbar) elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - attention = layer.self_attention - if not isinstance(attention, IdentityOp): - if "MLASelfAttention" in str(type(attention)): - if hasattr(attention, "linear_q_proj"): - 
layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) - else: - layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - else: - layer_pbar.set_description("Importing GQA/MHA") - if attention.q_layernorm is not None and not isinstance( - attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id) - self.rules["k_layernorm"](attention.k_layernorm, layer_id) - self.rules["linear_qkv"](attention.linear_qkv, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - if getattr(attention.core_attention, "softmax_offset", None) is not None: - self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description("Importing MoE") - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - layer_pbar.set_description("Importing MoE shared experts") - fc1 = layer.mlp.shared_experts.linear_fc1 - fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id) - self.rules["shared_experts.linear_fc2"](fc2, layer_id) - if not self.rules.get("use_packed_local_experts", False): - for local_expert_id, expert in tqdm( - enumerate(layer.mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - expert_id = layer.mlp.local_expert_indices[local_expert_id] - fc1 = expert.linear_fc1 - fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) - # We only support either EP or ETP for now - elif get_expert_tensor_parallel_world_size() > 1: - # ETP supports for packed MoE - # ETP is not supported for gpt-oss model - if self.arch in ["GptOssForCausalLM"]: - raise ValueError("ETP is not supported for gpt-oss model") - self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - # EP supports for packed MoE - self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + self._import_transformer_layer(layer, layer_id, layer_pbar) if self.verbose: print( "{:3}/{:3} completes importing layer {:3}.".format( - torch.distributed.get_rank(), torch.distributed.get_world_size(), layer_id + dist.get_rank(), 
dist.get_world_size(), layer_id ), flush=True, ) @@ -595,67 +676,87 @@ def _import_state_dict(self): # Final layernorm if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: self.rules["final_layernorm"](model.decoder.final_layernorm) - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: self.rules["final_norm"](model.decoder.final_norm) # Output layer if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) + # MTP if hasattr(model, "mtp"): + print("Importing MTP", flush=True) # MTP is the last layer in DeepSeek V3/R1 - layer_id += 1 - for mtp in model.mtp: - self.rules["mtp.fc"](mtp.fc, layer_id) + if len(model.mtp.layers) == 1: # Repeated MTP + layer_id = 0 # reset layer_id for repeated MTP + mtp = model.mtp.layers[0] + + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) self.rules["mtp.hnorm"](mtp.hnorm, layer_id) - self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) - if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): - self.rules["mtp.linear_q_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + + mtp_model_layers = mtp.mtp_model_layer.layers + for mtp_model_layer in mtp_model_layers: + if isinstance(mtp_model_layer, MambaLayer): + self._import_mamba_layer(mtp_model_layer, layer_id, layer_pbar) + elif isinstance(mtp_model_layer, TransformerLayer): + self._import_transformer_layer(mtp_model_layer, layer_id, layer_pbar, is_mtp=True) + else: + raise ValueError(f"Unsupported layer type during MTP import: {type(mtp_model_layer)}") + + layer_id += 1 + else: # non-repeated MTP + + for mtp in model.mtp.layers: + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) + self.rules["mtp.enorm"](mtp.enorm, layer_id) + self.rules["mtp.hnorm"](mtp.hnorm, layer_id) + self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) + if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): + self.rules["mtp.linear_q_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + ) + else: + self.rules["mtp.linear_q_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + ) + self.rules["mtp.linear_q_layernorm"]( + mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + ) + self.rules["mtp.linear_q_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + ) + self.rules["mtp.linear_kv_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id ) - else: - self.rules["mtp.linear_q_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + self.rules["mtp.linear_kv_layernorm"]( + mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id ) - self.rules["mtp.linear_q_layernorm"]( - mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + self.rules["mtp.linear_kv_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id ) - self.rules["mtp.linear_q_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + self.rules["mtp.linear_proj"]( + mtp.decoder.layers[0].self_attention.linear_proj, layer_id ) - self.rules["mtp.linear_kv_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id - ) - self.rules["mtp.linear_kv_layernorm"]( - mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id - ) - self.rules["mtp.linear_kv_up_proj"]( - 
mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id - ) - self.rules["mtp.linear_proj"]( - mtp.decoder.layers[0].self_attention.linear_proj, layer_id - ) - self.rules["mtp.pre_mlp_layernorm"]( - mtp.decoder.layers[0].pre_mlp_layernorm, layer_id - ) - self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) - self.rules["mtp.shared_experts.linear_fc1"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["mtp.shared_experts.linear_fc2"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in tqdm( - enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - self.rules["mtp.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id + self.rules["mtp.pre_mlp_layernorm"]( + mtp.decoder.layers[0].pre_mlp_layernorm, layer_id + ) + self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) + self.rules["mtp.shared_experts.linear_fc1"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id ) - self.rules["mtp.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id + self.rules["mtp.shared_experts.linear_fc2"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id ) + for expert_id, expert in tqdm( + enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + self.rules["mtp.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["mtp.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index eee13dc51..8f43f052c 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -332,7 +332,7 @@ def get_prequant_scaling_factor(module: nn.Module) -> torch.Tensor: if prequant_scaling_factor is not None: assert torch.all(prequant_scaling_factor > 0), ( f"prequant scaling factor {prequant_scaling_factor} not positive." - ) + ) return prequant_scaling_factor @@ -344,32 +344,22 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]: kv_bias.append(getattr(quantizer_module, "_bias_value", None)) return kv_bias - -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - if not hasattr(kv_module, "k_bmm_quantizer") or not hasattr(kv_module, "v_bmm_quantizer"): +def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> torch.Tensor: + """ + Returns the k and v BMM scaling factors if BMM quantizers are set in the self attention module. + Else returns None by default. + """ + if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr(self_attention_module, "v_bmm_quantizer"): return [None, None] scaling_factors = [ - get_scaling_factor(getattr(kv_module, quantizer)) + get_scaling_factor(getattr(self_attention_module, quantizer)) for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer") ] - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - for i, factor in enumerate(scaling_factors): - if factor.item() > 0.5: - warn( - f"Warning: Large KV activation detected: {factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop." 
- ) - scaling_factors[i] = torch.max( - factor, torch.tensor([1.0], dtype=torch.float, device=factor.device) - ) - return scaling_factors + def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: """Returns the kv_cache dtype. @@ -397,6 +387,22 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: num_bits_list.append(quantizer_attr.num_bits) is_affine &= hasattr(quantizer_attr, "_bias_value") + return _compute_kv_cache_dtype(num_bits_list) + +def _compute_kv_cache_dtype(num_bits_list: list[int]) -> str | None: + """Returns the kv_cache dtype. + + If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8, + otherwise returns None. + + Args: + modules: The module or list of modules to inspect. + + Returns: + The kv_cache dtype. + """ + is_affine = True + if (4, 3) in num_bits_list: return KV_CACHE_FP8 elif 8 in num_bits_list: @@ -920,18 +926,6 @@ def postprocess_state_dict( value = value.float() / maxbound - # Warn if scale exceeds threshold - if quantization == KV_CACHE_FP8 and value.item() > 0.5: - logger.warning( - "Large KV activations detected. Quantized KV cache may lead to higher accuracy drop. " - "Setting KV cache scaling factor to at least 1." - ) - - # Ensure scale is at least 1 for KV_CACHE_FP8 - # We export real value for KV_CACHE_NVFP4 - if quantization == KV_CACHE_FP8: - value.clamp_(min=1.0) - post_state_dict[prefix + new_suffix] = value break diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index f1bd67327..0c0cdb20e 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -47,11 +47,13 @@ QUANTIZATION_NVFP4, ) from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, save_safetensors +from .plugins.mcore_custom import CustomModuleMapping, save_safetensors, get_safetensor from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, + get_kv_cache_scaling_factor, get_kv_cache_dtype, + get_quant_config, get_quantization_format, get_scaling_factor, get_weight_block_size, @@ -86,33 +88,6 @@ ] -# This path uses output_quantizer for KV cache quantization. -# The function below is the old version of get_kv_cache_scaling_factor which is now refactored to handle bmm_quantizer. -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - scaling_factor = ( - get_scaling_factor(kv_module.output_quantizer) - if hasattr(kv_module, "output_quantizer") - else None - ) - - if not scaling_factor: - return None - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - if scaling_factor.item() > 0.5: - warn( - f"!!!!Large KV activations detected: {scaling_factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop.\n!!!!" - ) - scaling_factor = torch.max( - scaling_factor, - torch.tensor([1.0], dtype=torch.float, device=scaling_factor.device), - ) - return scaling_factor - - class GPTModelExporter: """Megatron Core GPTModel Exporter. 
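With the clamp-to-1 logic removed above, the raw per-tensor k/v scales flow from the BMM quantizers straight into the exported checkpoint. A minimal usage sketch of the refactored helpers, mirroring the `_self_attention_scaling` rule added later in this diff (`collect_kv_scales`, `prefix`, and `state_dict` are illustrative names, not part of the change):

```python
import torch
from torch import nn

from modelopt.torch.export.quant_utils import get_kv_cache_dtype, get_kv_cache_scaling_factor


def collect_kv_scales(
    core_attention: nn.Module, prefix: str, state_dict: dict[str, torch.Tensor]
) -> str | None:
    """Copy per-tensor k/v scales from the BMM quantizers (if present) and report the KV cache dtype."""
    k_scale, v_scale = get_kv_cache_scaling_factor(core_attention)
    if k_scale is not None and v_scale is not None:
        state_dict[prefix + "k_scale"] = k_scale.detach().clone()
        state_dict[prefix + "v_scale"] = v_scale.detach().clone()
    # e.g. KV_CACHE_FP8, KV_CACHE_NVFP4, or None; the exporter records this for hf_quant_config.
    return get_kv_cache_dtype(core_attention)
```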
@@ -281,11 +256,6 @@ def save_pretrained( elif quantization_format == QUANTIZATION_NVFP4: quantization = "NVFP4" - kv_cache_quantization = None - kv_cache_dtype = get_kv_cache_dtype(self.model) - if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): - # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM - kv_cache_quantization = kv_cache_dtype # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. if is_last_stage_main_rank: @@ -320,6 +290,8 @@ def save_pretrained( pass if is_last_stage_main_rank and quantization is not None: + # TODO refactor to use mte.quant_utils.get_quant_config + # except layer names are different in MCore and HF hf_quant_config = { "producer": { "name": "modelopt", @@ -327,13 +299,25 @@ def save_pretrained( }, "quantization": { "quant_algo": quantization, - "kv_cache_quant_algo": kv_cache_quantization, - "exclude_modules": ["lm_head"], + "exclude_modules": ["lm_head"], # TODO update this dynamically }, } + if quantization == "NVFP4": # update block size + hf_quant_config["quantization"]["group_size"] = 16 + if hasattr(self, "kv_cache_dtype"): + hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4) + # Newer versions of VLLM expect config.json with hf_quant_config + config_file = save_directory + "/config.json" + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + config["quantization"] = hf_quant_config["quantization"] + with open(config_file, "w") as f: + json.dump(config, f, indent=4) + if ( is_first_stage_main_rank and self.is_multimodal @@ -473,6 +457,7 @@ def _custom_mapping_to_lambda(mapping): method_map = { "name_remapping": self._name_remapping, "qkv_slicing": self._qkv_slicing, + "self_attention_scaling": self._self_attention_scaling, "gated_mlp_slicing": self._gated_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, @@ -511,16 +496,16 @@ def _get_quantized_state( qformat: str = self._get_quantization_format(module) block_size = get_weight_block_size(module) - if hasattr(module, "weight") and module.weight is not None: + if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0: weight = module.weight.to(dtype).cpu() name_to_value["weight"] = weight else: return name_to_value, qformat, block_size - if hasattr(module, "bias") and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None and module.bias.numel() > 0: name_to_value["bias"] = module.bias.to(dtype).cpu() - if hasattr(module, "expert_bias") and module.expert_bias is not None: + if hasattr(module, "expert_bias") and module.expert_bias is not None and module.expert_bias.numel() > 0: name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu() if qformat == QUANTIZATION_NONE: @@ -541,12 +526,8 @@ def _get_quantized_state( # TODO (chenhany): support AWQ with pre_quant_scale if hasattr(module.input_quantizer, "_pre_quant_scale"): raise ValueError("Detect pre_quant_scale! 
SmoothQuant/AWQ are not yet supported!") - - if hasattr(module, "output_quantizer"): - output_scale = get_kv_cache_scaling_factor(module) - if output_scale is not None: - name_to_value["output_scale"] = output_scale - + + return name_to_value, qformat, block_size def _get_quantization_format(self, module: torch.nn.Module): @@ -674,9 +655,7 @@ def _qkv_slicing( q_proj_name="q_proj", k_proj_name="k_proj", v_proj_name="v_proj", - k_scale_name="k_scale", - v_scale_name="v_scale", - ): + ): name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) q_proj_prefix = prefix + q_proj_name + "." @@ -774,10 +753,7 @@ def _qkv_slicing( q_proj_key = q_proj_prefix + key k_proj_key = k_proj_prefix + key v_proj_key = v_proj_prefix + key - if key == "output_scale": - self._state_dict[prefix + k_scale_name] = val.detach().clone() - self._state_dict[prefix + v_scale_name] = val.detach().clone() - elif key == "bias": + if key == "bias": # Slice bias similar to weight bias = val.detach().clone() bias = bias.reshape([qkv_total_dim, head_size]) @@ -790,6 +766,21 @@ def _qkv_slicing( self._state_dict[k_proj_key] = val.detach().clone() self._state_dict[v_proj_key] = val.detach().clone() + def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"): + """KV cache scaling for CoreAttention module.""" + k_scale_key = prefix + k_scale_name + v_scale_key = prefix + v_scale_name + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): + kv_scales = get_kv_cache_scaling_factor(module) + if all(s is not None for s in kv_scales): + self._state_dict[k_scale_key] = kv_scales[0] + self._state_dict[v_scale_key] = kv_scales[1] + + kv_cache_dtype = get_kv_cache_dtype(module) + if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): + # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM + self.kv_cache_dtype = kv_cache_dtype + def _pack_name_remapping(self, module, prefix, layer_type=None): """Pack name remapping into one tensor.""" weight_list = [] @@ -1149,6 +1140,8 @@ def _get_state_dict(self): self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + if hasattr(layer.self_attention, "core_attention"): + self.rules["core_attention"](layer.self_attention.core_attention, layer_id) self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) if ( getattr(layer.self_attention.core_attention, "softmax_offset", None) @@ -1166,6 +1159,10 @@ def _get_state_dict(self): self.rules["router"]( layer.mlp.router, layer_id, dtype=self.moe_router_dtype ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) if ( hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None @@ -1197,6 +1194,27 @@ def _get_state_dict(self): self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) else: raise ValueError("Only TransformerLayer or MambaLayer are supported.") + + # MTP module + # Hacky version for now: copy MTP weights from pretrained model + if os.path.isdir(self._hf_pretrained_model_name): + safetensors_index_file = Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + else: + safetensors_index_file = 
hf_hub_download( + repo_id=self._hf_pretrained_model_name, + filename="model.safetensors.index.json") + + print(f"safetensors_index_file: {safetensors_index_file}") + if safetensors_index_file and os.path.exists(safetensors_index_file): + with open(safetensors_index_file, "r") as f: + safetensors_index = json.load(f) + model_dir = Path(safetensors_index_file).parent + for key in safetensors_index["weight_map"]: + if "mtp" in key: + self._state_dict[key] = get_safetensor(model_dir, key) + + # TODO implement actual MTP export + def export_mcore_gpt_to_hf( @@ -1205,7 +1223,7 @@ def export_mcore_gpt_to_hf( export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), - moe_router_dtype: torch.dtype | None = None, + moe_router_dtype: str | None = None, ): """Export Megatron Core GPTModel to unified checkpoint and save to export_dir. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3184f2a78..015a9a2de 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -81,6 +81,11 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis forward_loop(model) finish_stats_collection(model) + # Step 1: Sync amax across local experts in a SequentialMLP + for name, module in model.named_modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + if not distributed_sync: return @@ -95,13 +100,14 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group - # Step 1:Sync amax across data parallelism + + # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): sync_quantizer_amax_across_dp_ep(child, module.parallel_state) - # TP sync: + # Step 3: TP sync # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same # ColumnParallel: X @ [A_1, A_2] (weights split along Cout) @@ -156,7 +162,6 @@ def sync_quantizer_amax_across_tp( axes_for_sync=[None, -1], parallel_state=module.parallel_state, ) - sync_quantizer_amax_across_tp( module.weight_quantizer, name, @@ -182,10 +187,6 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) - # MOE Quantization - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - # KV Cache Quantization if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 803c9747f..521b73ea1 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,7 @@ import megatron.core.parallel_state as mcore_parallel import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp +import megatron.core.transformer.transformer_layer as megatron_transformer_layer import megatron.core.transformer.moe.experts as megatron_moe import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch @@ -40,6 +41,7 
@@
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState
+import torch.distributed as dist
 
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
@@ -48,6 +50,7 @@
 
 try:
     from megatron.core.extensions.transformer_engine import (
+        TELinear,
         TEColumnParallelGroupedLinear,
         TEColumnParallelLinear,
         TEDotProductAttention,
@@ -581,7 +584,6 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local experts s.t. all
         localexperts will share the same amax.
         """
-        torch.distributed.barrier()
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
@@ -619,6 +621,10 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
 
 if HAS_TE:
 
+    @QuantModuleRegistry.register({TELinear: "te_mcore_Linear"})
+    class _QuantTEMCoreLinear(_QuantTELinear):
+        pass
+
     @QuantModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"})
     class _QuantTEMCoreRowParallelLinear(_QuantTELinear, _MegatronRowParallelLinear):
         pass
@@ -756,6 +762,22 @@ def forward(self, hidden_states):
         if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
             original_top_k = self.router.topk
             self.router.topk = self.router.num_experts
-            super().forward(hidden_states)
+            super().forward(hidden_states)  # calibration pass over all experts
             self.router.topk = original_top_k
         return super().forward(hidden_states)
+
+# TODO: double-check whether the MoE calibration forward belongs in MoELayer or TransformerLayer;
+# only one of the two needs to be patched.
+
+@QuantModuleRegistry.register({megatron_transformer_layer.TransformerLayer: "megatron_transformer_layer_TransformerLayer"})
+class _QuantTransformerLayer(QuantModule):
+    def _setup(self):
+        pass  # no quantizers to set up on the TransformerLayer wrapper itself
+
+    def _forward_mlp_moe_preprocess(self, hidden_states):
+        if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()):
+            original_top_k = self.mlp.router.topk
+            self.mlp.router.topk = self.mlp.router.num_experts
+            super()._forward_mlp_moe_preprocess(hidden_states)  # calibration pass over all experts
+            self.mlp.router.topk = original_top_k
+        return super()._forward_mlp_moe_preprocess(hidden_states)
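Both patched classes repeat the same calibration trick: while any expert quantizer is collecting statistics, route tokens to every expert, then restore the router's top-k. A standalone sketch of that pattern as a context manager (duck-typed `router`/`experts` arguments; this is not the registered QuantModule itself, just the idea in isolation):

```python
from contextlib import contextmanager


@contextmanager
def route_to_all_experts(router, experts):
    """Temporarily set router.topk to router.num_experts while any expert is calibrating."""
    calibrating = any(getattr(m, "_if_calib", False) for m in experts.modules())
    original_top_k = router.topk
    if calibrating:
        router.topk = router.num_experts
    try:
        yield calibrating
    finally:
        router.topk = original_top_k


# Hypothetical usage inside a patched forward:
#     with route_to_all_experts(self.mlp.router, self.mlp.experts) as calibrating:
#         if calibrating:
#             super()._forward_mlp_moe_preprocess(hidden_states)  # calibration-only pass
#     return super()._forward_mlp_moe_preprocess(hidden_states)
```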