Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions fms_fsdp/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,81 @@ def get_model_config(model_variant):
"pad_vocab_size_multiple": 16,
"tie_embeddings": False,
}
elif model_variant == "mamba_30b_moe":
model_config = {
"d_model": 3072,
"d_intermediate": 1344,
"n_layer": 32,
"vocab_size": 128256,
"ssm_cfg": {"layer": "Mamba2"},
"attn_layer_idx": [9, 18, 27],
"attn_cfg": {
"causal": True,
"d_conv": 0,
"head_dim": 128,
"num_heads": 24,
"num_heads_kv": 8,
"out_proj_bias": False,
"qkv_proj_bias": False,
"rotary_emb_dim": 64,
},
"mlp_cfg": {"n_expert": 64, "load_balancing_loss": True, "top_k": 8},
"rms_norm": True,
"residual_in_fp32": True,
"fused_add_norm": True,
"pad_vocab_size_multiple": 16,
"tie_embeddings": False,
}
elif model_variant == "mamba_120b_moe":
model_config = {
"d_model": 4096,
"d_intermediate": 896,
"n_layer": 40,
"vocab_size": 128256,
"ssm_cfg": {"layer": "Mamba2"},
"attn_layer_idx": [9, 18, 27, 36],
"attn_cfg": {
"causal": True,
"d_conv": 0,
"head_dim": 128,
"num_heads": 32,
"num_heads_kv": 8,
"out_proj_bias": False,
"qkv_proj_bias": False,
"rotary_emb_dim": 64,
},
"mlp_cfg": {"n_expert": 256, "load_balancing_loss": True, "top_k": 16},
"rms_norm": True,
"residual_in_fp32": True,
"fused_add_norm": True,
"pad_vocab_size_multiple": 16,
"tie_embeddings": False,
}
elif model_variant == "mamba_236b_moe":
model_config = {
"d_model": 5120,
"d_intermediate": 1536,
"n_layer": 60,
"vocab_size": 128256,
"ssm_cfg": {"layer": "Mamba2"},
"attn_layer_idx": [9, 18, 27, 36, 45, 54],
"attn_cfg": {
"causal": True,
"d_conv": 0,
"head_dim": 128,
"num_heads": 40,
"num_heads_kv": 8,
"out_proj_bias": False,
"qkv_proj_bias": False,
"rotary_emb_dim": 64,
},
"mlp_cfg": {"n_expert": 160, "load_balancing_loss": True, "top_k": 8},
"rms_norm": True,
"residual_in_fp32": True,
"fused_add_norm": True,
"pad_vocab_size_multiple": 16,
"tie_embeddings": False,
}
else:
raise ValueError(f"model variant {model_variant} not supported.")

Expand Down
23 changes: 20 additions & 3 deletions fms_fsdp/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from dataclasses import asdict
from functools import partial

from transformers.models.granitemoe.modeling_granitemoe import load_balancing_loss_func

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should copy in this code later on



try:
import packaging.version
Expand Down Expand Up @@ -73,7 +75,7 @@ def train(
run["hparams"] = asdict(cfg)

model.train()
ddp_stats = torch.zeros(3).to(local_rank)
ddp_stats = torch.zeros(4).to(local_rank)

start = time.time()
loop_start = time.time()
Expand All @@ -86,9 +88,20 @@ def train(

optimizer.zero_grad()
output = model(input)
output = output.logits if hasattr(output, "logits") else output
logits = output.logits if hasattr(output, "logits") else output
ce_loss = torch.nn.CrossEntropyLoss()
loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
loss = ce_loss(logits.view(-1, logits.size(-1)), label.view(-1).long())
ddp_stats[3] += loss.item()
if "moe" in cfg.model_variant:
aux_outputs = output.aux_outputs
if aux_outputs is not None:
top_k = model.config.mlp_cfg.get("top_k", 2)
aux_loss = load_balancing_loss_func(
aux_outputs,
num_experts=model.config.mlp_cfg["n_expert"],
top_k=top_k,
)
loss += 0.2 * aux_loss

loss.backward()
ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
Expand All @@ -105,6 +118,7 @@ def train(
dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM)
train_loss = ddp_stats[0] / ddp_stats[2]
g_norm = ddp_stats[1] / ddp_stats[2]
original_loss = ddp_stats[3] / ddp_stats[2]
elapsed_time = time.time() - loop_start
world_size = int(os.environ["WORLD_SIZE"])
new_tokens_seen = (
Expand All @@ -113,6 +127,7 @@ def train(
if rank == 0:
total_tokens_seen = tokens_seen + new_tokens_seen
current_loss = train_loss.item()
current_original_loss = original_loss.item()
current_lr = scheduler.get_last_lr()[0]
current_gnorm = g_norm.item()
current_step_time = (time.time() - start) / cfg.report_interval
Expand All @@ -132,6 +147,7 @@ def train(

print("step:", batch_idx)
print("loss:", current_loss)
print("original loss:", current_original_loss)
print("LR:", current_lr)
print("tokens seen:", total_tokens_seen)
print("gradient norm:", current_gnorm)
Expand All @@ -149,6 +165,7 @@ def train(
vals_to_track = {
"learning rate": current_lr,
"loss": current_loss,
"original_loss": current_original_loss,
"gradient norm": current_gnorm,
"token seen": total_tokens_seen,
"current throughput (token per gpu per sec)": current_throughput,
Expand Down
5 changes: 3 additions & 2 deletions main_training_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def main(**kwargs):
# get model
config_data = get_model_config(cfg.model_variant)
mamba_config = MambaConfig(**config_data)
model = MambaLMHeadModel(mamba_config)
with torch.device("meta"):
model = MambaLMHeadModel(mamba_config)

if rank == 0:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
Expand All @@ -89,7 +90,7 @@ def main(**kwargs):
use_orig_params=cfg.use_torch_compile,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
param_init_fn=param_init_fn,
param_init_fn=lambda x: x.to_empty(device=torch.cuda.current_device(), recurse=False),
)

# fsdp activation checkpointing
Expand Down
Loading