From 0a43eddfecf422e7a619313ff9d92a99780d59f8 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Thu, 23 Jan 2025 14:01:17 -0500 Subject: [PATCH 01/10] enable mamba moe --- fms_fsdp/utils/config_utils.py | 35 ++++++++++++++++++++++++++++++++++ fms_fsdp/utils/train_utils.py | 21 ++++++++++++++++++-- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index f4e7628c..cc1f2be6 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -183,6 +183,41 @@ def get_model_config(model_variant): "pad_vocab_size_multiple": 16, "tie_embeddings": False, } + elif model_variant == "mamba_9.8b_moe": + model_config = { + "d_model": 4096, + "d_intermediate": 448, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": { + "layer": "Mamba2" + }, + "attn_layer_idx": [ + 9, + 18, + 27 + ], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64 + }, + "mlp_cfg": { + "n_expert": 32, + "load_balancing_loss": True, + "top_k": 8 + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False + } else: raise ValueError(f"model variant {model_variant} not supported.") diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index ef421f6f..362ab447 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -2,6 +2,7 @@ from dataclasses import asdict from functools import partial +from transformers.models.granitemoe.modeling_granitemoe import load_balancing_loss_func try: import packaging.version @@ -86,9 +87,23 @@ def train( optimizer.zero_grad() output = model(input) - output = output.logits if hasattr(output, "logits") else output + logits = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() - loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) + loss = ce_loss(logits.view(-1, logits.size(-1)), label.view(-1).long()) + if "moe" in cfg.model_variant: + aux_outputs = output.aux_outputs + if aux_outputs is not None: + top_k = model.config.mlp_cfg.get('top_k',2) + aux_loss = load_balancing_loss_func( + aux_outputs, + num_experts=model.config.mlp_cfg['n_expert'], + top_k=top_k, + ) + _, distribution = aux_outputs[0].topk( + top_k + ).indices.unique(return_counts=True) + distribution = distribution.detach().cpu().tolist() + loss += 0.2 * aux_loss loss.backward() ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() @@ -145,6 +160,8 @@ def train( "overall token per day:", int(new_tokens_seen / elapsed_time * 3600 * 24), ) + if "moe" in cfg.model_variant: + print("distribution", distribution) if cfg.tracker: vals_to_track = { "learning rate": current_lr, From 184539282611c30d1762aea44dfa834e6d27a3bb Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Thu, 23 Jan 2025 16:20:25 -0500 Subject: [PATCH 02/10] add 30b mamba moe and 120b mamba moe --- fms_fsdp/utils/config_utils.py | 62 ++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index cc1f2be6..bdafc2a5 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -218,6 +218,68 @@ def get_model_config(model_variant): "pad_vocab_size_multiple": 16, "tie_embeddings": False } + elif model_variant == "mamba_30b_moe": + model_config = { + "d_model": 6144, + 
"d_intermediate": 336, + "n_layer": 48, + "vocab_size": 128256, + "ssm_cfg": { + "layer": "Mamba2" + }, + "attn_layer_idx": [9, 18, 27, 36, 45], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 48, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64 + }, + "mlp_cfg": { + "n_expert": 64, + "load_balancing_loss": True, + "top_k": 8 + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False + } + elif model_variant == "mamba_120b_moe": + model_config = { + "d_model": 8192, + "d_intermediate": 112, + "n_layer": 108, + "vocab_size": 128256, + "ssm_cfg": { + "layer": "Mamba2" + }, + "attn_layer_idx": [9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 64, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64 + }, + "mlp_cfg": { + "n_expert": 256, + "load_balancing_loss": True, + "top_k": 16 + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False + } else: raise ValueError(f"model variant {model_variant} not supported.") From c6cca86c93680c7fb2533ace5a581ce3f0d04f52 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Thu, 23 Jan 2025 16:22:07 -0500 Subject: [PATCH 03/10] lint --- fms_fsdp/utils/config_utils.py | 48 +++++++++------------------------- fms_fsdp/utils/train_utils.py | 11 ++++---- 2 files changed, 19 insertions(+), 40 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index bdafc2a5..40b3b4b9 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -189,14 +189,8 @@ def get_model_config(model_variant): "d_intermediate": 448, "n_layer": 32, "vocab_size": 128256, - "ssm_cfg": { - "layer": "Mamba2" - }, - "attn_layer_idx": [ - 9, - 18, - 27 - ], + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], "attn_cfg": { "causal": True, "d_conv": 0, @@ -205,18 +199,14 @@ def get_model_config(model_variant): "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False, - "rotary_emb_dim": 64 - }, - "mlp_cfg": { - "n_expert": 32, - "load_balancing_loss": True, - "top_k": 8 + "rotary_emb_dim": 64, }, + "mlp_cfg": {"n_expert": 32, "load_balancing_loss": True, "top_k": 8}, "rms_norm": True, "residual_in_fp32": True, "fused_add_norm": True, "pad_vocab_size_multiple": 16, - "tie_embeddings": False + "tie_embeddings": False, } elif model_variant == "mamba_30b_moe": model_config = { @@ -224,9 +214,7 @@ def get_model_config(model_variant): "d_intermediate": 336, "n_layer": 48, "vocab_size": 128256, - "ssm_cfg": { - "layer": "Mamba2" - }, + "ssm_cfg": {"layer": "Mamba2"}, "attn_layer_idx": [9, 18, 27, 36, 45], "attn_cfg": { "causal": True, @@ -236,18 +224,14 @@ def get_model_config(model_variant): "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False, - "rotary_emb_dim": 64 - }, - "mlp_cfg": { - "n_expert": 64, - "load_balancing_loss": True, - "top_k": 8 + "rotary_emb_dim": 64, }, + "mlp_cfg": {"n_expert": 64, "load_balancing_loss": True, "top_k": 8}, "rms_norm": True, "residual_in_fp32": True, "fused_add_norm": True, "pad_vocab_size_multiple": 16, - "tie_embeddings": False + "tie_embeddings": False, } elif model_variant == "mamba_120b_moe": model_config = { @@ -255,9 +239,7 @@ def get_model_config(model_variant): "d_intermediate": 112, "n_layer": 108, "vocab_size": 
128256, - "ssm_cfg": { - "layer": "Mamba2" - }, + "ssm_cfg": {"layer": "Mamba2"}, "attn_layer_idx": [9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99], "attn_cfg": { "causal": True, @@ -267,18 +249,14 @@ def get_model_config(model_variant): "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False, - "rotary_emb_dim": 64 - }, - "mlp_cfg": { - "n_expert": 256, - "load_balancing_loss": True, - "top_k": 16 + "rotary_emb_dim": 64, }, + "mlp_cfg": {"n_expert": 256, "load_balancing_loss": True, "top_k": 16}, "rms_norm": True, "residual_in_fp32": True, "fused_add_norm": True, "pad_vocab_size_multiple": 16, - "tie_embeddings": False + "tie_embeddings": False, } else: raise ValueError(f"model variant {model_variant} not supported.") diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 362ab447..4260f51c 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -4,6 +4,7 @@ from transformers.models.granitemoe.modeling_granitemoe import load_balancing_loss_func + try: import packaging.version except ImportError: @@ -93,15 +94,15 @@ def train( if "moe" in cfg.model_variant: aux_outputs = output.aux_outputs if aux_outputs is not None: - top_k = model.config.mlp_cfg.get('top_k',2) + top_k = model.config.mlp_cfg.get("top_k", 2) aux_loss = load_balancing_loss_func( aux_outputs, - num_experts=model.config.mlp_cfg['n_expert'], + num_experts=model.config.mlp_cfg["n_expert"], top_k=top_k, ) - _, distribution = aux_outputs[0].topk( - top_k - ).indices.unique(return_counts=True) + _, distribution = ( + aux_outputs[0].topk(top_k).indices.unique(return_counts=True) + ) distribution = distribution.detach().cpu().tolist() loss += 0.2 * aux_loss From 7d26b453714c095f6f3f6dd1aefb9c8bdda1a184 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Thu, 23 Jan 2025 19:00:31 -0500 Subject: [PATCH 04/10] remove distribution report --- fms_fsdp/utils/train_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 4260f51c..6923dadf 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -100,10 +100,6 @@ def train( num_experts=model.config.mlp_cfg["n_expert"], top_k=top_k, ) - _, distribution = ( - aux_outputs[0].topk(top_k).indices.unique(return_counts=True) - ) - distribution = distribution.detach().cpu().tolist() loss += 0.2 * aux_loss loss.backward() @@ -161,8 +157,6 @@ def train( "overall token per day:", int(new_tokens_seen / elapsed_time * 3600 * 24), ) - if "moe" in cfg.model_variant: - print("distribution", distribution) if cfg.tracker: vals_to_track = { "learning rate": current_lr, From 0bbfabe0f54f3d0cb8ba2b2e6a0c3b8ca1ea9e1f Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Thu, 23 Jan 2025 21:11:25 -0500 Subject: [PATCH 05/10] add both losses to report --- fms_fsdp/utils/train_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 6923dadf..4c1ff5d6 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -75,7 +75,7 @@ def train( run["hparams"] = asdict(cfg) model.train() - ddp_stats = torch.zeros(3).to(local_rank) + ddp_stats = torch.zeros(4).to(local_rank) start = time.time() loop_start = time.time() @@ -91,6 +91,7 @@ def train( logits = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() loss = ce_loss(logits.view(-1, logits.size(-1)), label.view(-1).long()) + ddp_stats[3] += loss.item() if "moe" in 
cfg.model_variant: aux_outputs = output.aux_outputs if aux_outputs is not None: @@ -117,6 +118,7 @@ def train( dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM) train_loss = ddp_stats[0] / ddp_stats[2] g_norm = ddp_stats[1] / ddp_stats[2] + original_loss = ddp_stats[3] / ddp_stats[2] elapsed_time = time.time() - loop_start world_size = int(os.environ["WORLD_SIZE"]) new_tokens_seen = ( @@ -125,6 +127,7 @@ def train( if rank == 0: total_tokens_seen = tokens_seen + new_tokens_seen current_loss = train_loss.item() + current_original_loss = original_loss.item() current_lr = scheduler.get_last_lr()[0] current_gnorm = g_norm.item() current_step_time = (time.time() - start) / cfg.report_interval @@ -144,6 +147,7 @@ def train( print("step:", batch_idx) print("loss:", current_loss) + print("original loss:", current_original_loss) print("LR:", current_lr) print("tokens seen:", total_tokens_seen) print("gradient norm:", current_gnorm) @@ -161,6 +165,7 @@ def train( vals_to_track = { "learning rate": current_lr, "loss": current_loss, + "original_loss": current_original_loss, "gradient norm": current_gnorm, "token seen": total_tokens_seen, "current throughput (token per gpu per sec)": current_throughput, From de631af875c4fd4f078f54f2808caf30be5960c1 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Fri, 24 Jan 2025 11:44:02 -0500 Subject: [PATCH 06/10] update 30b and 120b config for mamba moe --- fms_fsdp/utils/config_utils.py | 43 +++++++--------------------------- 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 40b3b4b9..c96ef25d 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -183,10 +183,10 @@ def get_model_config(model_variant): "pad_vocab_size_multiple": 16, "tie_embeddings": False, } - elif model_variant == "mamba_9.8b_moe": + elif model_variant == "mamba_30b_moe": model_config = { - "d_model": 4096, - "d_intermediate": 448, + "d_model": 3072, + "d_intermediate": 1344, "n_layer": 32, "vocab_size": 128256, "ssm_cfg": {"layer": "Mamba2"}, @@ -195,32 +195,7 @@ def get_model_config(model_variant): "causal": True, "d_conv": 0, "head_dim": 128, - "num_heads": 32, - "num_heads_kv": 8, - "out_proj_bias": False, - "qkv_proj_bias": False, - "rotary_emb_dim": 64, - }, - "mlp_cfg": {"n_expert": 32, "load_balancing_loss": True, "top_k": 8}, - "rms_norm": True, - "residual_in_fp32": True, - "fused_add_norm": True, - "pad_vocab_size_multiple": 16, - "tie_embeddings": False, - } - elif model_variant == "mamba_30b_moe": - model_config = { - "d_model": 6144, - "d_intermediate": 336, - "n_layer": 48, - "vocab_size": 128256, - "ssm_cfg": {"layer": "Mamba2"}, - "attn_layer_idx": [9, 18, 27, 36, 45], - "attn_cfg": { - "causal": True, - "d_conv": 0, - "head_dim": 128, - "num_heads": 48, + "num_heads": 24, "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False, @@ -235,17 +210,17 @@ def get_model_config(model_variant): } elif model_variant == "mamba_120b_moe": model_config = { - "d_model": 8192, - "d_intermediate": 112, - "n_layer": 108, + "d_model": 4096, + "d_intermediate": 896, + "n_layer": 40, "vocab_size": 128256, "ssm_cfg": {"layer": "Mamba2"}, - "attn_layer_idx": [9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99], + "attn_layer_idx": [9, 18, 27, 36], "attn_cfg": { "causal": True, "d_conv": 0, "head_dim": 128, - "num_heads": 64, + "num_heads": 32, "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False, From c5271d28929abd0061d985312e0701ea67b3da96 Mon Sep 17 00:00:00 2001 From: 
Linsong Chu Date: Tue, 28 Jan 2025 14:30:36 -0500 Subject: [PATCH 07/10] add 236b config --- fms_fsdp/utils/config_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index c96ef25d..19cb541b 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -233,6 +233,8 @@ def get_model_config(model_variant): "pad_vocab_size_multiple": 16, "tie_embeddings": False, } + elif model_variant == "mamba_236b_moe": + model_config = {'d_model': 5120, 'd_intermediate': 1536, 'n_layer': 60, 'vocab_size': 128256, 'ssm_cfg': {'layer': 'Mamba2'}, 'attn_layer_idx': [9, 18, 27, 36, 45, 54], 'attn_cfg': {'causal': True, 'd_conv': 0, 'head_dim': 128, 'num_heads': 40, 'num_heads_kv': 8, 'out_proj_bias': False, 'qkv_proj_bias': False, 'rotary_emb_dim': 64}, "mlp_cfg": {"n_expert": 160, "load_balancing_loss": True, "top_k": 8},'rms_norm': True, 'residual_in_fp32': True, 'fused_add_norm': True, 'pad_vocab_size_multiple': 16, 'tie_embeddings': False} else: raise ValueError(f"model variant {model_variant} not supported.") From 0ed114e5e3c7cfcd736ca09ddf1e580997cfaadc Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Tue, 28 Jan 2025 14:31:07 -0500 Subject: [PATCH 08/10] lint --- fms_fsdp/utils/config_utils.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 19cb541b..5e372825 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -234,7 +234,30 @@ def get_model_config(model_variant): "tie_embeddings": False, } elif model_variant == "mamba_236b_moe": - model_config = {'d_model': 5120, 'd_intermediate': 1536, 'n_layer': 60, 'vocab_size': 128256, 'ssm_cfg': {'layer': 'Mamba2'}, 'attn_layer_idx': [9, 18, 27, 36, 45, 54], 'attn_cfg': {'causal': True, 'd_conv': 0, 'head_dim': 128, 'num_heads': 40, 'num_heads_kv': 8, 'out_proj_bias': False, 'qkv_proj_bias': False, 'rotary_emb_dim': 64}, "mlp_cfg": {"n_expert": 160, "load_balancing_loss": True, "top_k": 8},'rms_norm': True, 'residual_in_fp32': True, 'fused_add_norm': True, 'pad_vocab_size_multiple': 16, 'tie_embeddings': False} + model_config = { + "d_model": 5120, + "d_intermediate": 1536, + "n_layer": 60, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27, 36, 45, 54], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 40, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "mlp_cfg": {"n_expert": 160, "load_balancing_loss": True, "top_k": 8}, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } else: raise ValueError(f"model variant {model_variant} not supported.") From 7a29b60abca2fddf915469e2e7da16e596361991 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Thu, 30 Jan 2025 20:49:27 -0500 Subject: [PATCH 09/10] make meta device init --- main_training_mamba.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main_training_mamba.py b/main_training_mamba.py index 3619ea25..7b66536a 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -64,7 +64,8 @@ def main(**kwargs): # get model config_data = get_model_config(cfg.model_variant) mamba_config = MambaConfig(**config_data) - model = MambaLMHeadModel(mamba_config) + with torch.device("meta"): + model = MambaLMHeadModel(mamba_config) if rank == 0: total_params 
= sum(p.numel() for p in model.parameters() if p.requires_grad) From eae33c37965ada428967f9ff06087cd12b2a1040 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Thu, 30 Jan 2025 21:08:01 -0500 Subject: [PATCH 10/10] make meta device init --- main_training_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training_mamba.py b/main_training_mamba.py index 7b66536a..87b76ff0 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -90,7 +90,7 @@ def main(**kwargs): use_orig_params=cfg.use_torch_compile, device_id=torch.cuda.current_device(), limit_all_gathers=True, - param_init_fn=param_init_fn, + param_init_fn=lambda x: x.to_empty(device=torch.cuda.current_device(), recurse=False), ) # fsdp activation checkpointing
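
Taken together, the loss handling introduced across PATCH 01/04/05 reduces to the minimal sketch below. It is an illustration, not code from the repo: names such as moe_training_loss and router_logits are placeholders for the model's output.aux_outputs, while the granitemoe load_balancing_loss_func import and the 0.2 auxiliary-loss weight mirror the patches.

    import torch
    from transformers.models.granitemoe.modeling_granitemoe import load_balancing_loss_func

    def moe_training_loss(logits, labels, router_logits, n_expert, top_k=8):
        # Standard next-token cross entropy; reported separately as "original loss" since PATCH 05.
        lm_loss = torch.nn.CrossEntropyLoss()(
            logits.view(-1, logits.size(-1)), labels.view(-1).long()
        )
        # Router load-balancing auxiliary loss over the per-layer gate logits
        # (the model's aux_outputs), weighted by 0.2 as in PATCH 01.
        aux_loss = load_balancing_loss_func(
            router_logits, num_experts=n_expert, top_k=top_k
        )
        return lm_loss + 0.2 * aux_loss, lm_loss

Along the same lines, the meta-device change in PATCH 09/10 amounts to constructing the model under torch.device("meta") so no parameter storage is allocated at build time, then letting FSDP materialize each module per rank via module.to_empty(device=torch.cuda.current_device(), recurse=False); the resulting tensors hold uninitialized memory until weights are loaded or reinitialized.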