18 changes: 18 additions & 0 deletions modelopt/torch/export/plugins/mcore_custom.py
@@ -102,7 +102,16 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
            func_kwargs=func_kwargs,
        )


class GroupedMLPMerging(CustomModuleMapping):
    """A custom module mapping that merges up_proj and down_proj for Grouped MLP."""

    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
        """Create a custom module mapping that merges up_proj and down_proj for Grouped MLP."""
        super().__init__(
            func_name="grouped_mlp_merging",
            target_name_or_prefix=target_name_or_prefix,
            func_kwargs=func_kwargs,
        )


class GatedMLPMerging(CustomModuleMapping):
    """A custom module mapping that merges gate_proj and up_proj."""

@@ -126,6 +135,15 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
            func_kwargs=func_kwargs,
        )


class SelfAttentionScaling(CustomModuleMapping):
    """A custom module mapping that scales self attention."""

    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
        """Create a custom module mapping that scales self attention."""
        super().__init__(
            func_name="self_attention_scaling",
            target_name_or_prefix=target_name_or_prefix,
            func_kwargs=func_kwargs,
        )


class GatedMLPSlicing(CustomModuleMapping):
    """A custom module mapping that slices gate_proj and up_proj."""
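Note (editorial illustration, not part of the diff): both new wrappers only pin down func_name and pass the target prefix and kwargs straight through to CustomModuleMapping. A minimal usage sketch follows, assuming CustomModuleMapping keeps its constructor arguments as attributes of the same names and that the classes are importable from this module:

    # Hypothetical usage; import path assumed from this file's location.
    from modelopt.torch.export.plugins.mcore_custom import (
        GroupedMLPMerging,
        SelfAttentionScaling,
    )

    # Per-expert weight merging: "{}" is the layer index, "{{}}" stays a literal
    # "{}" so the expert index can be substituted later.
    experts_fc1 = GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.up_proj")
    assert experts_fc1.func_name == "grouped_mlp_merging"  # attribute name assumed

    # Scaling hook on the core attention module of a decoder layer.
    attn_scaling = SelfAttentionScaling("backbone.layers.{}.mixer.")
    assert attn_scaling.func_name == "self_attention_scaling"  # attribute name assumed
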
25 changes: 25 additions & 0 deletions modelopt/torch/export/plugins/mcore_nemotron.py
@@ -26,6 +26,8 @@
    NameRemapping,
    QKVMerging,
    QKVSlicing,
    GroupedMLPMerging,
    SelfAttentionScaling,
)

# Example on adding a new CausalLM.
@@ -81,8 +83,21 @@
    "shared_experts.linear_fc2": NameRemapping(
        "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
    ),
    # Latent MoE
    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE),
    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE),
    # Repeated MTP module
    "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}),
    "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}),
    "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}),
    "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}),
    # Grouped local experts in MTP
    "experts.linear_fc1": GroupedMLPMerging(
        "mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}
    ),
    "experts.linear_fc2": GroupedMLPMerging(
        "mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}
    ),
}
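
Note (editorial illustration, not part of the diff): the doubled braces in the grouped-expert targets survive the first str.format call, so the layer index can be filled in first and the expert index later; the kwargs combine a parallelism preset with the MTP flag via dict union. A quick check in plain Python, using a placeholder value for COL_ETP:

    template = "mtp.layers.{}.mixer.experts.{{}}.up_proj"
    per_layer = template.format(0)    # "mtp.layers.0.mixer.experts.{}.up_proj"
    per_expert = per_layer.format(7)  # "mtp.layers.0.mixer.experts.7.up_proj"

    COL_ETP = {"etp": "column"}           # placeholder, not the real preset
    kwargs = COL_ETP | {"is_mtp": True}   # {"etp": "column", "is_mtp": True}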

# TODO ADD MTP export

nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
    "word_embeddings": NameRemapping("backbone.embeddings."),
@@ -101,6 +116,7 @@
    "input_layernorm": NameRemapping("backbone.layers.{}.norm."),
    "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."),
    "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."),
    "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."),

Review comment (PR author) on the "core_attention" mapping above: double-check that this is only needed for export.

    # MLP
    "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."),
    "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."),
@@ -115,4 +131,13 @@
    "shared_experts.linear_fc2": NameRemapping(
        "backbone.layers.{}.mixer.shared_experts.down_proj."
    ),
    # Latent MoE
    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
    # MTP
    "mtp.enorm": NameRemapping("mtp.layers.{}.enorm."),
    "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."),
    "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."),
    "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."),
}