Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fms_fsdp/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def get_model_config(model_variant):
"num_heads_kv": 8,
"out_proj_bias": False,
"qkv_proj_bias": False,
"rotary_emb_dim": 64,
"rotary_emb_dim": 0,
},
"rms_norm": True,
"residual_in_fp32": True,
Expand Down
1 change: 1 addition & 0 deletions fms_fsdp/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def train(
output = output.logits if hasattr(output, "logits") else output
ce_loss = torch.nn.CrossEntropyLoss()
loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean()

loss.backward()
ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
Expand Down
31 changes: 29 additions & 2 deletions main_training_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(**kwargs):
auto_wrap_policy=wrapping_policy,
mixed_precision=mixed_precision_policy,
sharding_strategy=sharding_strategy_policy,
use_orig_params=cfg.use_torch_compile,
use_orig_params=True,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
param_init_fn=param_init_fn,
Expand All @@ -107,8 +107,35 @@ def main(**kwargs):
model = torch.compile(model)

# Optimizer
# optimizer = optim.AdamW(
# model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
# )
params_with_decay = []
params_without_decay = []
for name, param in model.named_parameters():
print(f'{name=}')
if 'A_log' in name or 'D' in name or 'dt_bias' in name:
params_without_decay.append(param)
else:
params_with_decay.append(param)


print(f'{params_with_decay=}')
print(f'{params_without_decay=}')

optimizer = optim.AdamW(
model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
[
{
"params": params_with_decay,
"weight_decay": 0.1,
},
{
"params": params_without_decay,
"weight_decay": 0.,
},
],
betas = (0.9, 0.95),
lr = cfg.learning_rate, # cfg.learning_rate,
)

# optionally load from checkpoint (when continue pretraining)
Expand Down
Loading