From d54fe0a3a2bad85cbdcdec44ac7cfb68098a9b30 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer
Date: Tue, 15 Apr 2025 16:52:51 -0400
Subject: [PATCH 1/4] Add partial wd

---
 main_training_mamba.py | 29 ++++++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/main_training_mamba.py b/main_training_mamba.py
index 3ff12a60..6cb1cd98 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -107,8 +107,35 @@ def main(**kwargs):
         model = torch.compile(model)
 
     # Optimizer
+    # optimizer = optim.AdamW(
+    #     model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
+    # )
+    params_with_decay = []
+    params_without_decay = []
+    for name, param in model.named_parameters():
+        print(f'{name=}')
+        if 'A_log' in name or 'D' in name or 'dt_bias' in name:
+            params_without_decay.append(param)
+        else:
+            params_with_decay.append(param)
+
+
+    print(f'{params_with_decay=}')
+    print(f'{params_without_decay=}')
+
     optimizer = optim.AdamW(
-        model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
+        [
+            {
+                "params": params_with_decay,
+                "weight_decay": 0.1,
+            },
+            {
+                "params": params_without_decay,
+                "weight_decay": 0.,
+            },
+        ],
+        betas = (0.9, 0.95),
+        lr = cfg.learning_rate, # cfg.learning_rate,
     )
 
     # optionally load from checkpoint (when continue pretraining)

From 79e1063b5982e1766fe2a29e33805c3ec979ba6a Mon Sep 17 00:00:00 2001
From: Davis Wertheimer
Date: Tue, 15 Apr 2025 16:53:47 -0400
Subject: [PATCH 2/4] Add zl

---
 fms_fsdp/utils/train_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py
index ef421f6f..593eb13a 100644
--- a/fms_fsdp/utils/train_utils.py
+++ b/fms_fsdp/utils/train_utils.py
@@ -89,6 +89,7 @@ def train(
         output = output.logits if hasattr(output, "logits") else output
         ce_loss = torch.nn.CrossEntropyLoss()
         loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
+        loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean()
 
         loss.backward()
         ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()

From f326e97e787e3c312dd6c9ba11f96cf6a0981aca Mon Sep 17 00:00:00 2001
From: Davis Wertheimer
Date: Tue, 15 Apr 2025 17:06:05 -0400
Subject: [PATCH 3/4] Hardcode orig_params True

---
 main_training_mamba.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/main_training_mamba.py b/main_training_mamba.py
index 6cb1cd98..e45b37cb 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -86,7 +86,7 @@ def main(**kwargs):
         auto_wrap_policy=wrapping_policy,
         mixed_precision=mixed_precision_policy,
         sharding_strategy=sharding_strategy_policy,
-        use_orig_params=cfg.use_torch_compile,
+        use_orig_params=True,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=True,
         param_init_fn=param_init_fn,

From c728781ed36dd35576acd142be21b81fe2a0c689 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer
Date: Tue, 15 Apr 2025 17:19:35 -0400
Subject: [PATCH 4/4] Turn off rope

---
 fms_fsdp/utils/config_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py
index f4e7628c..e18c4d8f 100644
--- a/fms_fsdp/utils/config_utils.py
+++ b/fms_fsdp/utils/config_utils.py
@@ -175,7 +175,7 @@ def get_model_config(model_variant):
             "num_heads_kv": 8,
             "out_proj_bias": False,
             "qkv_proj_bias": False,
-            "rotary_emb_dim": 64,
+            "rotary_emb_dim": 0,
         },
         "rms_norm": True,
         "residual_in_fp32": True,
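
The sketch below is a minimal, standalone illustration of the two training-recipe changes in this series: PATCH 1/4 gives AdamW two parameter groups so that Mamba's state-space parameters (A_log, D, dt_bias) are exempt from weight decay, and PATCH 2/4 adds a z-loss-style regularizer (1e-4 times the squared logsumexp of the logits, averaged over tokens) on top of the cross-entropy loss. It is not part of the patches: the ToyMamba module, the tensor shapes, and the 3e-4 learning rate are illustrative assumptions, not values taken from fms-fsdp.

import torch
import torch.nn as nn
import torch.optim as optim


class ToyMamba(nn.Module):
    # Tiny stand-in for the real model: it only needs parameters whose
    # names mimic Mamba's state-space terms (A_log, D, dt_bias).
    def __init__(self, vocab=128, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.in_proj = nn.Linear(dim, dim)
        self.A_log = nn.Parameter(torch.zeros(dim))
        self.D = nn.Parameter(torch.ones(dim))
        self.dt_bias = nn.Parameter(torch.zeros(dim))
        self.lm_head = nn.Linear(dim, vocab)

    def forward(self, x):
        h = self.in_proj(self.embed(x))
        h = h * torch.exp(-torch.exp(self.A_log)) + self.D + self.dt_bias
        return self.lm_head(h)


model = ToyMamba()

# PATCH 1/4: split parameters by name so the state-space parameters get
# weight_decay=0.0 while everything else keeps weight_decay=0.1.
params_with_decay, params_without_decay = [], []
for name, param in model.named_parameters():
    if "A_log" in name or "D" in name or "dt_bias" in name:
        params_without_decay.append(param)
    else:
        params_with_decay.append(param)

optimizer = optim.AdamW(
    [
        {"params": params_with_decay, "weight_decay": 0.1},
        {"params": params_without_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,  # assumed value; the patch reads this from cfg.learning_rate
    betas=(0.9, 0.95),
)

# PATCH 2/4: one training step with the z-loss-style term, which penalizes
# the squared log-partition function of the logits so their magnitude
# stays bounded.
input_ids = torch.randint(0, 128, (2, 16))
labels = torch.randint(0, 128, (2, 16))
logits = model(input_ids)

ce_loss = nn.CrossEntropyLoss()
loss = ce_loss(logits.view(-1, logits.size(-1)), labels.view(-1))
loss = loss + 1e-4 * torch.logsumexp(logits, dim=-1).pow(2).mean()

loss.backward()
optimizer.step()

In the actual training script the same name-based grouping runs on the FSDP-wrapped model; hardcoding use_orig_params=True (PATCH 3/4) keeps the original parameter names visible through named_parameters() after wrapping, which is what allows that filter to match.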