From da29217d88efc7c504383a159fb21baeece8bd0e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 22 May 2024 15:06:25 -0400 Subject: [PATCH 01/49] Set llama2-1.4b to gqa --- fms_fsdp/utils/config_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 9d3f0386..13384031 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -54,6 +54,8 @@ def get_model_config(model_variant): emb_dim=2048, nheads=16, nlayers=24, + hidden_grow_factor=3, + kvheads=4, ) elif model_variant == "llama3_8b": llama_config = LLaMAConfig( From 41ae740d00838701afe4936f0b17501cd1767ebb Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 28 May 2024 14:58:55 -0400 Subject: [PATCH 02/49] Add singlefile ckp saving/conversion --- fms_fsdp/utils/train_utils.py | 1 + main_training.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 9b507e6b..1a75bd44 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -77,6 +77,7 @@ def train( start = time.time() loop_start = time.time() + train_loss = -1 for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1): if batch_idx > cfg.num_steps: break diff --git a/main_training.py b/main_training.py index 9c4c5e44..21a24429 100644 --- a/main_training.py +++ b/main_training.py @@ -156,6 +156,8 @@ def main(**kwargs): tokens_seen, ) + checkpointer.save_single_file(cfg.num_steps, model) + dist.barrier() dist.destroy_process_group() From 5171b5d427acdc4723b61e32fde8c50c24f946f6 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 31 May 2024 18:39:08 -0400 Subject: [PATCH 03/49] Turn off GQA on 1.4B --- fms_fsdp/utils/config_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 13384031..3e046f63 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -54,8 +54,8 @@ def get_model_config(model_variant): emb_dim=2048, nheads=16, nlayers=24, - hidden_grow_factor=3, - kvheads=4, + # hidden_grow_factor=3, + # kvheads=4, ) elif model_variant == "llama3_8b": llama_config = LLaMAConfig( From abd5b197cda7ec69dc32b6f9c8fb8f251247860f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 5 Jun 2024 14:34:40 -0400 Subject: [PATCH 04/49] GQA on, add for 7b --- fms_fsdp/utils/config_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 3e046f63..b63f9b10 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -48,14 +48,17 @@ def get_model_config(model_variant): hidden_grow_factor=13824 / 5120, ) elif model_variant == "llama2_7b": - llama_config = LLaMAConfig() + llama_config = LLaMAConfig( + hidden_grow_factor=3, + kvheads=8, + ) elif model_variant == "llama2_1.4b": llama_config = LLaMAConfig( emb_dim=2048, nheads=16, nlayers=24, - # hidden_grow_factor=3, - # kvheads=4, + hidden_grow_factor=3, + kvheads=4, ) elif model_variant == "llama3_8b": llama_config = LLaMAConfig( From 0ac0a5f00ef6b06e1648b7cf738a813dff93b27e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Jun 2024 19:15:47 -0400 Subject: [PATCH 05/49] Add llama3 tele cfg --- fms_fsdp/utils/config_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index b63f9b10..3546cd5b 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -100,6 +100,15 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) + elif model_variant == "llama3_1.8b_tele": + llama_config = LLaMAConfig( + src_vocab_size=128256, + emb_dim=2048, + nheads=32, + kvheads=2, + nlayers=24, + hidden_grow_factor=3.75, + max_expected_seq_len=4096, elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From 8caeaa24bda8bf7e92727bd6386e1b842983689f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Jun 2024 19:37:57 -0400 Subject: [PATCH 06/49] Add missing paren --- fms_fsdp/utils/config_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 3546cd5b..fdd31780 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -109,6 +109,7 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=3.75, max_expected_seq_len=4096, + ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From 941e98fe5c57f8380485977b574fed6dd73391e7 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 13 Jun 2024 11:21:20 -0400 Subject: [PATCH 07/49] Back to gqa4 for llama3 --- fms_fsdp/utils/config_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index fdd31780..5e0ae2e7 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -104,10 +104,10 @@ def get_model_config(model_variant): llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, - nheads=32, - kvheads=2, + nheads=16, + kvheads=4, nlayers=24, - hidden_grow_factor=3.75, + hidden_grow_factor=3.5, max_expected_seq_len=4096, ) elif model_variant == "llama3_70b": From 44edc0d57feb4520c5b5b2d20a1583b3041030fd Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 18 Jun 2024 22:36:34 -0400 Subject: [PATCH 08/49] Nonstrict ckpt load --- main_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main_training.py b/main_training.py index 21a24429..4eabf152 100644 --- a/main_training.py +++ b/main_training.py @@ -123,6 +123,7 @@ def main(**kwargs): optimizer, None, path=os.path.join(cfg.ckpt_load_path, "checkpoints/"), + strict=False, ) # LR schedule From a48a055288d6e6ec74356048553ce65ab66bb9a5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 18 Jun 2024 22:56:52 -0400 Subject: [PATCH 09/49] If singlefile load, don't append "checkpoints" folder --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 4eabf152..afae1b49 100644 --- a/main_training.py +++ b/main_training.py @@ -122,7 +122,7 @@ def main(**kwargs): model, optimizer, None, - path=os.path.join(cfg.ckpt_load_path, "checkpoints/"), + path=os.path.join(cfg.ckpt_load_path, "checkpoints/") if not os.path.isfile(cfg.ckpt_load_path) else cfg.ckpt_load_path, strict=False, ) From 9031328050b05d3d1f3a12ed1c9b637c8c8a43b5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 21 Jun 2024 12:47:48 -0400 Subject: [PATCH 10/49] Add reset stepcount field --- fms_fsdp/config/training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index e8b4df06..6b1f3888 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -36,6 +36,7 @@ class train_config: learning_rate: float = 3e-4 grad_clip_thresh: float = 1.0 seed: int = 2023 + reset_stepcount: bool = False # profiling use_profiler: bool = False From 0e3430a451200b116aa69093a03e3a7dd252de6c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 21 Jun 2024 12:48:48 -0400 Subject: [PATCH 11/49] Add reset stepcount support --- main_training.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/main_training.py b/main_training.py index afae1b49..cd858156 100644 --- a/main_training.py +++ b/main_training.py @@ -125,6 +125,8 @@ def main(**kwargs): path=os.path.join(cfg.ckpt_load_path, "checkpoints/") if not os.path.isfile(cfg.ckpt_load_path) else cfg.ckpt_load_path, strict=False, ) + if cfg.reset_stepcount: + start_step = 0 # LR schedule warmup_interval = min(2000, cfg.num_steps // 20) From 45d7e414089233e4f605c3b6cbf24a54efa80309 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 21 Jun 2024 15:07:07 -0400 Subject: [PATCH 12/49] Override optimizer LR values with desired --- main_training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main_training.py b/main_training.py index cd858156..fc3625ee 100644 --- a/main_training.py +++ b/main_training.py @@ -127,6 +127,9 @@ def main(**kwargs): ) if cfg.reset_stepcount: start_step = 0 + # Override loaded optim hyperparams with the current values + for g in optimizer.param_groups: + g["initial_lr"] = cfg.learning_rate # LR schedule warmup_interval = min(2000, cfg.num_steps // 20) From 9cb032986cfa3bb641aa1726159e651be503f46c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Jun 2024 11:41:57 -0400 Subject: [PATCH 13/49] gqa16 --- fms_fsdp/utils/config_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 5e0ae2e7..fdd31780 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -104,10 +104,10 @@ def get_model_config(model_variant): llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, - nheads=16, - kvheads=4, + nheads=32, + kvheads=2, nlayers=24, - hidden_grow_factor=3.5, + hidden_grow_factor=3.75, max_expected_seq_len=4096, ) elif model_variant == "llama3_70b": From 756c3eea36696b642d1d3b2ecb003f47e9e1d49d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Jun 2024 11:48:01 -0400 Subject: [PATCH 14/49] GOTHERE --- fms_fsdp/utils/train_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 1a75bd44..a5fc9b40 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -89,6 +89,7 @@ def train( output = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) + print("GOTHERE") loss.backward() ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() From fd28fb72a1d77e8f8d535e91e5a356839f680ba8 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Jun 2024 11:50:48 -0400 Subject: [PATCH 15/49] No gothere --- fms_fsdp/utils/train_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index a5fc9b40..1a75bd44 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -89,7 +89,6 @@ def train( output = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) - print("GOTHERE") loss.backward() ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() From ffded3584e8111012292d2d5c78427f2fb0aa4be Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 25 Jun 2024 13:20:15 -0400 Subject: [PATCH 16/49] Nonstrict fsdp load --- fms_fsdp/utils/checkpointing_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 41dd8e2d..d299baed 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -208,6 +208,7 @@ def load( state_dict=model_ckp, storage_reader=FileSystemReader(load_path), planner=DefaultLoadPlanner(), + strict=strict, ) model.load_state_dict(model_ckp["model_state"]) model.to(self.local_rank) From 1050d1dc68809cbee1d5ef950a4b233adc576d1c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 25 Jun 2024 13:23:23 -0400 Subject: [PATCH 17/49] Nonstrict fsdp load pt2 --- fms_fsdp/utils/checkpointing_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index d299baed..b27b75a4 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -208,9 +208,8 @@ def load( state_dict=model_ckp, storage_reader=FileSystemReader(load_path), planner=DefaultLoadPlanner(), - strict=strict, ) - model.load_state_dict(model_ckp["model_state"]) + model.load_state_dict(model_ckp["model_state"], strict=strict) model.to(self.local_rank) self.report(model_load_time=time.time() - model_load_time) step = 0 From 166c01de04dca59315d019326f054c51102bede1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 25 Jun 2024 13:26:34 -0400 Subject: [PATCH 18/49] Stop nonstrict fsdp load --- fms_fsdp/utils/checkpointing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index b27b75a4..41dd8e2d 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -209,7 +209,7 @@ def load( storage_reader=FileSystemReader(load_path), planner=DefaultLoadPlanner(), ) - model.load_state_dict(model_ckp["model_state"], strict=strict) + model.load_state_dict(model_ckp["model_state"]) model.to(self.local_rank) self.report(model_load_time=time.time() - model_load_time) step = 0 From fee4c4878805a22ccb186921c41322b3d98a5603 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 1 Jul 2024 13:45:39 -0400 Subject: [PATCH 19/49] Separate gqa4 and 16 cfgs --- fms_fsdp/utils/config_utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index fdd31780..a96fb66b 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -100,7 +100,7 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) - elif model_variant == "llama3_1.8b_tele": + elif model_variant == "llama3_1.8b_tele16": llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, @@ -110,6 +110,16 @@ def get_model_config(model_variant): hidden_grow_factor=3.75, max_expected_seq_len=4096, ) + elif model_variant == "llama3_1.8b_tele4": + llama_config = LLaMAConfig( + src_vocab_size=128256, + emb_dim=2048, + nheads=16, + kvheads=4, + nlayers=24, + hidden_grow_factor=3.5, + max_expected_seq_len=4096, + ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From 6f3fd09a4661ad8b2e702dec4002d46f139c0a65 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 1 Jul 2024 13:48:20 -0400 Subject: [PATCH 20/49] Fix indent --- fms_fsdp/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index a96fb66b..e4270892 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -110,7 +110,7 @@ def get_model_config(model_variant): hidden_grow_factor=3.75, max_expected_seq_len=4096, ) - elif model_variant == "llama3_1.8b_tele4": + elif model_variant == "llama3_1.8b_tele4": llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, From f5a707e92fe12a1d8378d76fb7a137b52344674d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 16 Jul 2024 11:48:21 -0400 Subject: [PATCH 21/49] Add mini llama cfg --- fms_fsdp/utils/config_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index e4270892..8f47b765 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -140,6 +140,13 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) + elif model_variant == "llama3_194m_4k": + llama_config = LLaMAConfig( + emb_dim=1024, + nheads=8, + nlayers=10, + max_expected_seq_len=4096, + ) else: raise ValueError(f"model variant {model_variant} not supported.") From 57e3ffd598c1eb1ff6e8e7f8dc52f5d9f9a26c02 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 16 Jul 2024 11:56:03 -0400 Subject: [PATCH 22/49] mini llama3 vsize --- fms_fsdp/utils/config_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 8f47b765..3eb7104f 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -142,6 +142,7 @@ def get_model_config(model_variant): ) elif model_variant == "llama3_194m_4k": llama_config = LLaMAConfig( + src_vocab_size=128256, emb_dim=1024, nheads=8, nlayers=10, From b2e6ae04863397a28914a3706c21e52170cae222 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 17 Jul 2024 18:09:53 -0400 Subject: [PATCH 23/49] Add muP fields, auto-update model cfg --- fms_fsdp/config/training.py | 10 ++++++++++ fms_fsdp/utils/config_utils.py | 7 +++++++ main_training.py | 3 ++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 6b1f3888..86df2fce 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -52,3 +52,13 @@ class train_config: # compile use_torch_compile: bool = True + + # muP scale params + mup_emb_scale: float = 0 + mup_head_scale: float = 0 + mup_ffn_init: float = 0 + mup_attn_init: float = 0 + mup_attn_temp: float = 0 + mup_0d_lr: float = 0 + mup_1d_lr: float = 0 + mup_2d_lr: float = 0 diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 3eb7104f..b74de69a 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -152,3 +152,10 @@ def get_model_config(model_variant): raise ValueError(f"model variant {model_variant} not supported.") return llama_config + + +def set_mup_from_cfg(job_cfg, model_cfg): + fields = {k:v for k,v in vars(job_cfg).items() if "mup" in k and v >= 0} + for f in fields: + setattr(model_cfg, f, fields[f]) + return model_cfg \ No newline at end of file diff --git a/main_training.py b/main_training.py index fc3625ee..4f5a6864 100644 --- a/main_training.py +++ b/main_training.py @@ -11,7 +11,7 @@ from fms_fsdp import config from fms_fsdp.utils.checkpointing_utils import Checkpointer -from fms_fsdp.utils.config_utils import get_model_config, update_config +from fms_fsdp.utils.config_utils import get_model_config, set_mup_from_cfg, update_config from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader from fms_fsdp.utils.train_utils import ( get_policies, @@ -57,6 +57,7 @@ def main(**kwargs): # get fms model llama_config = get_model_config(cfg.model_variant) + llama_config = set_mup_from_cfg(cfg, llama_config) if cfg.low_cpu_fsdp: with torch.device("meta"): model = LLaMA(llama_config) From 4a02c822ca1ec423f6e876832c241a7f537aa34d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 18 Jul 2024 13:14:58 -0400 Subject: [PATCH 24/49] Add mup scaling to fsdp init params --- fms_fsdp/policies/param_init.py | 25 ++++++++++++++++--------- fms_fsdp/utils/train_utils.py | 4 ++-- main_training.py | 20 ++++++++++---------- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/fms_fsdp/policies/param_init.py b/fms_fsdp/policies/param_init.py index 49655b26..86306730 100644 --- a/fms_fsdp/policies/param_init.py +++ b/fms_fsdp/policies/param_init.py @@ -1,18 +1,25 @@ import torch -from fms.modules.attention import MultiHeadAttention +from fms.modules.attention import MultiHeadAttention, QKV from fms.modules.embedding import WordEmbedding from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized # for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64 -def param_init_function(module): - if ( - isinstance(module, MultiHeadAttention) - or isinstance(module, WordEmbedding) - or isinstance(module, GatedLinearUnit) - or isinstance(module, LayerNormParameterized) - ): +def param_init_function(module, cfg): + scales = { + MultiHeadAttention: cfg.mup_attn_init, + QKV: cfg.mup_attn_init, + GatedLinearUnit: cfg.mup_ffn_init, + WordEmbedding: 1, + LayerNormParameterized: 1, + } + scale_keys = list(scales.keys()) + scale_vals = list(scales.values()) + type_id = [isinstance(module, x) for x in scale_keys] + is_resettable = sum(type_id) + if is_resettable: + module_type_id = type_id.index(True) module.to_empty(device=torch.cuda.current_device()) with torch.no_grad(): - module.reset_parameters() + module.reset_parameters(scale=scale_vals[module_type_id]) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 1a75bd44..80db6362 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -187,7 +187,7 @@ def setup_environ_flags(): os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) -def get_policies(cfg, rank, block): +def get_policies(cfg, rank, block, model_cfg): """Get policies for mixed precision, wrapping, sharding, ac and param init function.""" # mixed precision @@ -231,7 +231,7 @@ def get_policies(cfg, rank, block): # param init function if cfg.low_cpu_fsdp: - param_init_fn = param_init_function + param_init_fn = partial(param_init_function, cfg=model_cfg) else: param_init_fn = None diff --git a/main_training.py b/main_training.py index 4f5a6864..b16d244c 100644 --- a/main_training.py +++ b/main_training.py @@ -45,16 +45,6 @@ def main(**kwargs): torch.cuda.empty_cache() setup_environ_flags() - # get policy - block = LLaMABlock - ( - mixed_precision_policy, - wrapping_policy, - sharding_strategy_policy, - apply_selective_ac, - param_init_fn, - ) = get_policies(cfg, rank, block) - # get fms model llama_config = get_model_config(cfg.model_variant) llama_config = set_mup_from_cfg(cfg, llama_config) @@ -79,6 +69,16 @@ def main(**kwargs): if rank == 0: print("Datasets constructed!") + # get policy + block = LLaMABlock + ( + mixed_precision_policy, + wrapping_policy, + sharding_strategy_policy, + apply_selective_ac, + param_init_fn, + ) = get_policies(cfg, rank, block, llama_config) + # FSDP model = FSDP( model, From 22c54a61b09c52533ed2e997f8aaefc0c078afe0 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 18 Jul 2024 13:24:21 -0400 Subject: [PATCH 25/49] Only set mup cfg if >0 --- fms_fsdp/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index b74de69a..ba8c1f27 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -155,7 +155,7 @@ def get_model_config(model_variant): def set_mup_from_cfg(job_cfg, model_cfg): - fields = {k:v for k,v in vars(job_cfg).items() if "mup" in k and v >= 0} + fields = {k:v for k,v in vars(job_cfg).items() if "mup" in k and v > 0} for f in fields: setattr(model_cfg, f, fields[f]) return model_cfg \ No newline at end of file From af5261453c60fd14b8b2e6c335113e2afa404a68 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 15:26:26 -0400 Subject: [PATCH 26/49] 1d init mup --- fms_fsdp/policies/param_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/policies/param_init.py b/fms_fsdp/policies/param_init.py index 86306730..b82fd3c2 100644 --- a/fms_fsdp/policies/param_init.py +++ b/fms_fsdp/policies/param_init.py @@ -11,7 +11,7 @@ def param_init_function(module, cfg): MultiHeadAttention: cfg.mup_attn_init, QKV: cfg.mup_attn_init, GatedLinearUnit: cfg.mup_ffn_init, - WordEmbedding: 1, + WordEmbedding: (cfg.mup_1d_init, cfg.mup_emb_scale, cfg.mup_head_scale), LayerNormParameterized: 1, } scale_keys = list(scales.keys()) From 57ed6f9585a69f833cd3a87b4a24ea3b2e7312fa Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 15:45:42 -0400 Subject: [PATCH 27/49] Attempt mup lrs --- main_training.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index b16d244c..9c89aa8e 100644 --- a/main_training.py +++ b/main_training.py @@ -5,6 +5,8 @@ import torch import torch.optim as optim from fms.models.llama import LLaMA, LLaMABlock +from fms.modules.layernorm import LayerNormParameterized +from fms.modules.embedding import WordEmbedding from torch import distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim.lr_scheduler import LambdaLR @@ -111,8 +113,20 @@ def main(**kwargs): model = torch.compile(model) # Optimizer + params_0d = ( + [p for name,p in model.named_parameters() if "bias" in name] + + [m.weight for m in model.modules if isinstance(m, LayerNormParameterized)] + ) + params_1d = [p for m in model.modules() for name,p in m.named_parameters() if isinstance(m, WordEmbedding) and "bias" not in name] + params_2d = [p for m in model.modules() for name,p in m.named_parameters() if (isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit)) and "bias" not in name] + params_all = set(sum([params_0d, params_1d, params_2d])) + for p in model.parameters(): + assert p in params_all, p.shape optimizer = optim.AdamW( - model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 + {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr}, + {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr}, + {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr}, + lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 ) # optionally load from checkpoint (when continue pretraining) From 372e1d200448513a8998eca9e7815da701621537 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 15:53:44 -0400 Subject: [PATCH 28/49] cleanup, typofix --- fms_fsdp/policies/param_init.py | 2 +- fms_fsdp/utils/config_utils.py | 4 ++-- main_training.py | 40 ++++++++++++++++++++++++--------- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/fms_fsdp/policies/param_init.py b/fms_fsdp/policies/param_init.py index b82fd3c2..81504539 100644 --- a/fms_fsdp/policies/param_init.py +++ b/fms_fsdp/policies/param_init.py @@ -1,5 +1,5 @@ import torch -from fms.modules.attention import MultiHeadAttention, QKV +from fms.modules.attention import QKV, MultiHeadAttention from fms.modules.embedding import WordEmbedding from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index ba8c1f27..4dd45ff3 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -155,7 +155,7 @@ def get_model_config(model_variant): def set_mup_from_cfg(job_cfg, model_cfg): - fields = {k:v for k,v in vars(job_cfg).items() if "mup" in k and v > 0} + fields = {k: v for k, v in vars(job_cfg).items() if "mup" in k and v > 0} for f in fields: setattr(model_cfg, f, fields[f]) - return model_cfg \ No newline at end of file + return model_cfg diff --git a/main_training.py b/main_training.py index 9c89aa8e..443dcbff 100644 --- a/main_training.py +++ b/main_training.py @@ -5,15 +5,21 @@ import torch import torch.optim as optim from fms.models.llama import LLaMA, LLaMABlock -from fms.modules.layernorm import LayerNormParameterized +from fms.modules.attention import MultiHeadAttention from fms.modules.embedding import WordEmbedding +from fms.modules.feedforward import GatedLinearUnit +from fms.modules.layernorm import LayerNormParameterized from torch import distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim.lr_scheduler import LambdaLR from fms_fsdp import config from fms_fsdp.utils.checkpointing_utils import Checkpointer -from fms_fsdp.utils.config_utils import get_model_config, set_mup_from_cfg, update_config +from fms_fsdp.utils.config_utils import ( + get_model_config, + set_mup_from_cfg, + update_config, +) from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader from fms_fsdp.utils.train_utils import ( get_policies, @@ -113,12 +119,22 @@ def main(**kwargs): model = torch.compile(model) # Optimizer - params_0d = ( - [p for name,p in model.named_parameters() if "bias" in name] + - [m.weight for m in model.modules if isinstance(m, LayerNormParameterized)] - ) - params_1d = [p for m in model.modules() for name,p in m.named_parameters() if isinstance(m, WordEmbedding) and "bias" not in name] - params_2d = [p for m in model.modules() for name,p in m.named_parameters() if (isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit)) and "bias" not in name] + params_0d = [p for name, p in model.named_parameters() if "bias" in name] + [ + m.weight for m in model.modules() if isinstance(m, LayerNormParameterized) + ] + params_1d = [ + p + for m in model.modules() + for name, p in m.named_parameters() + if isinstance(m, WordEmbedding) and "bias" not in name + ] + params_2d = [ + p + for m in model.modules() + for name, p in m.named_parameters() + if (isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit)) + and "bias" not in name + ] params_all = set(sum([params_0d, params_1d, params_2d])) for p in model.parameters(): assert p in params_all, p.shape @@ -126,7 +142,9 @@ def main(**kwargs): {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr}, {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr}, {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr}, - lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 + lr=cfg.learning_rate, + betas=(0.9, 0.95), + weight_decay=0.1, ) # optionally load from checkpoint (when continue pretraining) @@ -137,7 +155,9 @@ def main(**kwargs): model, optimizer, None, - path=os.path.join(cfg.ckpt_load_path, "checkpoints/") if not os.path.isfile(cfg.ckpt_load_path) else cfg.ckpt_load_path, + path=os.path.join(cfg.ckpt_load_path, "checkpoints/") + if not os.path.isfile(cfg.ckpt_load_path) + else cfg.ckpt_load_path, strict=False, ) if cfg.reset_stepcount: From c0d1d1ffae69ec3f3531a13d0f5c58615c82d9ea Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:00:23 -0400 Subject: [PATCH 29/49] diag print --- main_training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main_training.py b/main_training.py index 443dcbff..912857f8 100644 --- a/main_training.py +++ b/main_training.py @@ -135,6 +135,9 @@ def main(**kwargs): if (isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit)) and "bias" not in name ] + print("0d", type(params_0d), len(params_0d)) + print("1d", type(params_1d), len(params_1d)) + print("2d", type(params_2d), len(params_2d)) params_all = set(sum([params_0d, params_1d, params_2d])) for p in model.parameters(): assert p in params_all, p.shape From 2017a9857cadfd79a741341157b490e1a6f35bf1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:08:51 -0400 Subject: [PATCH 30/49] Non double list comp --- main_training.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/main_training.py b/main_training.py index 912857f8..0b6906c6 100644 --- a/main_training.py +++ b/main_training.py @@ -122,19 +122,13 @@ def main(**kwargs): params_0d = [p for name, p in model.named_parameters() if "bias" in name] + [ m.weight for m in model.modules() if isinstance(m, LayerNormParameterized) ] - params_1d = [ - p - for m in model.modules() - for name, p in m.named_parameters() - if isinstance(m, WordEmbedding) and "bias" not in name - ] - params_2d = [ - p - for m in model.modules() - for name, p in m.named_parameters() - if (isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit)) - and "bias" not in name - ] + params_1d = [] + params_2d = [] + for m in model.modules(): + if isinstance(m, WordEmbedding): + params_1d += [p for name, p in m.named_parameters() if "bias" not in name] + elif isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit): + params_2d += [p for name, p in m.named_parameters() if "bias" not in name] print("0d", type(params_0d), len(params_0d)) print("1d", type(params_1d), len(params_1d)) print("2d", type(params_2d), len(params_2d)) From 9a77a2b3d96b96a3598379d25f4a4f0d4356ce1d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:12:35 -0400 Subject: [PATCH 31/49] diag print --- main_training.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/main_training.py b/main_training.py index 0b6906c6..0451cca0 100644 --- a/main_training.py +++ b/main_training.py @@ -126,8 +126,10 @@ def main(**kwargs): params_2d = [] for m in model.modules(): if isinstance(m, WordEmbedding): + print("GOTHERE: 1D", [name for name,_ in m.named_parameters()]) params_1d += [p for name, p in m.named_parameters() if "bias" not in name] elif isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit): + print("GOTHERE: 2D", [name for name,_ in m.named_parameters()]) params_2d += [p for name, p in m.named_parameters() if "bias" not in name] print("0d", type(params_0d), len(params_0d)) print("1d", type(params_1d), len(params_1d)) From 6c01a0b68c572d94650fefc34d20c614dcb9806e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:17:40 -0400 Subject: [PATCH 32/49] Stop named params --- main_training.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/main_training.py b/main_training.py index 0451cca0..42df9ce7 100644 --- a/main_training.py +++ b/main_training.py @@ -126,11 +126,15 @@ def main(**kwargs): params_2d = [] for m in model.modules(): if isinstance(m, WordEmbedding): - print("GOTHERE: 1D", [name for name,_ in m.named_parameters()]) - params_1d += [p for name, p in m.named_parameters() if "bias" not in name] - elif isinstance(m, MultiHeadAttention) or isinstance(m, GatedLinearUnit): - print("GOTHERE: 2D", [name for name,_ in m.named_parameters()]) - params_2d += [p for name, p in m.named_parameters() if "bias" not in name] + params_1d.append(m.emb.weight) + if m.abs_pos: + params_1d.append(m.pos_emb.weight) + if m.reversible and not m.tie_weights: + params_1d.append(m.head.weight) + elif isinstance(m, MultiHeadAttention): + params_2d += [m.dense.weight,] + list(m.in_proj.parameters()) + elif isinstance(m, GatedLinearUnit): + params_2d += [m.wg1_fused.weight, m.w2.weight] print("0d", type(params_0d), len(params_0d)) print("1d", type(params_1d), len(params_1d)) print("2d", type(params_2d), len(params_2d)) From 101652bbd5fe6edb659e293f42a9882dcdc5c8b9 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:20:15 -0400 Subject: [PATCH 33/49] List sum --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 42df9ce7..4bb8bc05 100644 --- a/main_training.py +++ b/main_training.py @@ -138,7 +138,7 @@ def main(**kwargs): print("0d", type(params_0d), len(params_0d)) print("1d", type(params_1d), len(params_1d)) print("2d", type(params_2d), len(params_2d)) - params_all = set(sum([params_0d, params_1d, params_2d])) + params_all = set(sum([params_0d, params_1d, params_2d], [])) for p in model.parameters(): assert p in params_all, p.shape optimizer = optim.AdamW( From 49341e1b4479df1e945924befd23105349d0c7f3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:25:33 -0400 Subject: [PATCH 34/49] diag print --- main_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main_training.py b/main_training.py index 4bb8bc05..97f0bf2a 100644 --- a/main_training.py +++ b/main_training.py @@ -132,6 +132,7 @@ def main(**kwargs): if m.reversible and not m.tie_weights: params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): + print(len(list(m.in_proj.parameters()))) params_2d += [m.dense.weight,] + list(m.in_proj.parameters()) elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] From 58c1662d3d9dc594ee7b364aac6bc85ac5d42c9d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:29:23 -0400 Subject: [PATCH 35/49] diag print --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 97f0bf2a..43bf4d39 100644 --- a/main_training.py +++ b/main_training.py @@ -132,7 +132,7 @@ def main(**kwargs): if m.reversible and not m.tie_weights: params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): - print(len(list(m.in_proj.parameters()))) + print(len(list(m.in_proj))) params_2d += [m.dense.weight,] + list(m.in_proj.parameters()) elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] From a14f57ee3a0af888555dbdf1a3df1f1499fe2708 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:29:38 -0400 Subject: [PATCH 36/49] diag print --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 43bf4d39..0d92ce35 100644 --- a/main_training.py +++ b/main_training.py @@ -132,7 +132,7 @@ def main(**kwargs): if m.reversible and not m.tie_weights: params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): - print(len(list(m.in_proj))) + print(list(m.in_proj)) params_2d += [m.dense.weight,] + list(m.in_proj.parameters()) elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] From 5c8d8c4bfd3512a079845c6b606acbc8c8ebc217 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:32:18 -0400 Subject: [PATCH 37/49] diag print --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 0d92ce35..22d14a4e 100644 --- a/main_training.py +++ b/main_training.py @@ -132,7 +132,7 @@ def main(**kwargs): if m.reversible and not m.tie_weights: params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): - print(list(m.in_proj)) + print(list(m.in_proj.parameters())) params_2d += [m.dense.weight,] + list(m.in_proj.parameters()) elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] From d0e4888aac7f36d05ab41e9ff9d4d2ea3f5b3868 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:36:16 -0400 Subject: [PATCH 38/49] diag print --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 22d14a4e..8287d43a 100644 --- a/main_training.py +++ b/main_training.py @@ -132,7 +132,7 @@ def main(**kwargs): if m.reversible and not m.tie_weights: params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): - print(list(m.in_proj.parameters())) + print(m.in_proj) params_2d += [m.dense.weight,] + list(m.in_proj.parameters()) elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] From e9701a196e1ea7d77aeebf695718ae2eb264dbd1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:41:16 -0400 Subject: [PATCH 39/49] Iterate over submodules explicitly --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 8287d43a..97f0cf9c 100644 --- a/main_training.py +++ b/main_training.py @@ -133,7 +133,7 @@ def main(**kwargs): params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): print(m.in_proj) - params_2d += [m.dense.weight,] + list(m.in_proj.parameters()) + params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules()] elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] print("0d", type(params_0d), len(params_0d)) From 0c46c3ac40cd54a3c069a585a4ab50276fc5359c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:46:34 -0400 Subject: [PATCH 40/49] linear submods only --- main_training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 97f0cf9c..6900a7fe 100644 --- a/main_training.py +++ b/main_training.py @@ -3,6 +3,7 @@ import fire import torch +import torch.nn as nn import torch.optim as optim from fms.models.llama import LLaMA, LLaMABlock from fms.modules.attention import MultiHeadAttention @@ -133,7 +134,7 @@ def main(**kwargs): params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): print(m.in_proj) - params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules()] + params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)] elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] print("0d", type(params_0d), len(params_0d)) From 58ce6807682807d7df50d7bc7d1380476012af9c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:50:41 -0400 Subject: [PATCH 41/49] diag print --- main_training.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/main_training.py b/main_training.py index 6900a7fe..a175a81e 100644 --- a/main_training.py +++ b/main_training.py @@ -133,14 +133,12 @@ def main(**kwargs): if m.reversible and not m.tie_weights: params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): - print(m.in_proj) params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)] elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] - print("0d", type(params_0d), len(params_0d)) - print("1d", type(params_1d), len(params_1d)) - print("2d", type(params_2d), len(params_2d)) params_all = set(sum([params_0d, params_1d, params_2d], [])) + print("Mup:", sum([p.numel() for p in params_all])) + print("Base:", sum([p.numel() for p in model.parameters()])) for p in model.parameters(): assert p in params_all, p.shape optimizer = optim.AdamW( From 39c5832b5f38f284d85d753b851cf8132c11ac68 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 16:55:04 -0400 Subject: [PATCH 42/49] diag print --- main_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main_training.py b/main_training.py index a175a81e..5b36e3c8 100644 --- a/main_training.py +++ b/main_training.py @@ -125,6 +125,7 @@ def main(**kwargs): ] params_1d = [] params_2d = [] + print(model) for m in model.modules(): if isinstance(m, WordEmbedding): params_1d.append(m.emb.weight) From 476dca59fc52bf8a048166e82f8318e3a2fa7481 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 17:03:30 -0400 Subject: [PATCH 43/49] Use orig params --- main_training.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/main_training.py b/main_training.py index 5b36e3c8..3121f97d 100644 --- a/main_training.py +++ b/main_training.py @@ -94,7 +94,7 @@ def main(**kwargs): auto_wrap_policy=wrapping_policy, mixed_precision=mixed_precision_policy, sharding_strategy=sharding_strategy_policy, - use_orig_params=cfg.use_torch_compile, + use_orig_params=True, device_id=torch.cuda.current_device(), limit_all_gathers=True, param_init_fn=param_init_fn, @@ -125,7 +125,6 @@ def main(**kwargs): ] params_1d = [] params_2d = [] - print(model) for m in model.modules(): if isinstance(m, WordEmbedding): params_1d.append(m.emb.weight) From a11abf73d35f9919f5dfdb396de71f2d77445e37 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 17:08:04 -0400 Subject: [PATCH 44/49] Remove default lr arg --- main_training.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/main_training.py b/main_training.py index 3121f97d..7b128960 100644 --- a/main_training.py +++ b/main_training.py @@ -136,16 +136,10 @@ def main(**kwargs): params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)] elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] - params_all = set(sum([params_0d, params_1d, params_2d], [])) - print("Mup:", sum([p.numel() for p in params_all])) - print("Base:", sum([p.numel() for p in model.parameters()])) - for p in model.parameters(): - assert p in params_all, p.shape optimizer = optim.AdamW( {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr}, {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr}, {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr}, - lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1, ) From f2c5590c1c7e6322265eb61cdca6be5ab2321b2e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 17:11:09 -0400 Subject: [PATCH 45/49] Enlist param groups --- main_training.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/main_training.py b/main_training.py index 7b128960..6c13ee3b 100644 --- a/main_training.py +++ b/main_training.py @@ -137,9 +137,11 @@ def main(**kwargs): elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] optimizer = optim.AdamW( - {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr}, - {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr}, - {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr}, + [ + {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr}, + {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr}, + {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr}, + ], betas=(0.9, 0.95), weight_decay=0.1, ) From 63a834a38b25fd7f922abc55ad75095eebaa6a6e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 19 Jul 2024 17:14:40 -0400 Subject: [PATCH 46/49] divide by mup scales --- main_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main_training.py b/main_training.py index 6c13ee3b..02fa2919 100644 --- a/main_training.py +++ b/main_training.py @@ -139,8 +139,8 @@ def main(**kwargs): optimizer = optim.AdamW( [ {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr}, - {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr}, - {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr}, + {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr / llama_config.emb_dim**.5}, + {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim}, ], betas=(0.9, 0.95), weight_decay=0.1, From 5887896cf18c0a878e9f514d4b7f4703adba9c05 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 22 Jul 2024 10:48:34 -0400 Subject: [PATCH 47/49] Remove tele configs --- fms_fsdp/utils/config_utils.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 4dd45ff3..9ac84608 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -100,26 +100,6 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) - elif model_variant == "llama3_1.8b_tele16": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=2048, - nheads=32, - kvheads=2, - nlayers=24, - hidden_grow_factor=3.75, - max_expected_seq_len=4096, - ) - elif model_variant == "llama3_1.8b_tele4": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=2048, - nheads=16, - kvheads=4, - nlayers=24, - hidden_grow_factor=3.5, - max_expected_seq_len=4096, - ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From 4dd399820ab01bc013a1d36f0a30eaac18b3856a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 22 Jul 2024 10:50:01 -0400 Subject: [PATCH 48/49] Don't change Llama2 small configs --- fms_fsdp/utils/config_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 9ac84608..23e300fa 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -48,17 +48,12 @@ def get_model_config(model_variant): hidden_grow_factor=13824 / 5120, ) elif model_variant == "llama2_7b": - llama_config = LLaMAConfig( - hidden_grow_factor=3, - kvheads=8, - ) + llama_config = LLaMAConfig() elif model_variant == "llama2_1.4b": llama_config = LLaMAConfig( emb_dim=2048, nheads=16, nlayers=24, - hidden_grow_factor=3, - kvheads=4, ) elif model_variant == "llama3_8b": llama_config = LLaMAConfig( From 1491706c7a47cf16e53f7d6c3864810d8f9b65e7 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 22 Jul 2024 10:51:31 -0400 Subject: [PATCH 49/49] linting --- main_training.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/main_training.py b/main_training.py index 02fa2919..6ddf28a6 100644 --- a/main_training.py +++ b/main_training.py @@ -133,14 +133,24 @@ def main(**kwargs): if m.reversible and not m.tie_weights: params_1d.append(m.head.weight) elif isinstance(m, MultiHeadAttention): - params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)] + params_2d += [ + m.dense.weight, + ] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)] elif isinstance(m, GatedLinearUnit): params_2d += [m.wg1_fused.weight, m.w2.weight] optimizer = optim.AdamW( [ {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr}, - {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr / llama_config.emb_dim**.5}, - {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim}, + { + "params": params_1d, + "lr": cfg.learning_rate + * llama_config.mup_1d_lr + / llama_config.emb_dim**0.5, + }, + { + "params": params_2d, + "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim, + }, ], betas=(0.9, 0.95), weight_decay=0.1,