diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index e8b4df06..86df2fce 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -36,6 +36,7 @@ class train_config:
     learning_rate: float = 3e-4
     grad_clip_thresh: float = 1.0
     seed: int = 2023
+    reset_stepcount: bool = False
 
     # profiling
     use_profiler: bool = False
@@ -51,3 +52,13 @@ class train_config:
 
     # compile
     use_torch_compile: bool = True
+
+    # muP scale params
+    mup_emb_scale: float = 0
+    mup_head_scale: float = 0
+    mup_ffn_init: float = 0
+    mup_attn_init: float = 0
+    mup_attn_temp: float = 0
+    mup_0d_lr: float = 0
+    mup_1d_lr: float = 0
+    mup_2d_lr: float = 0
diff --git a/fms_fsdp/policies/param_init.py b/fms_fsdp/policies/param_init.py
index 49655b26..81504539 100644
--- a/fms_fsdp/policies/param_init.py
+++ b/fms_fsdp/policies/param_init.py
@@ -1,18 +1,25 @@
 import torch
-from fms.modules.attention import MultiHeadAttention
+from fms.modules.attention import QKV, MultiHeadAttention
 from fms.modules.embedding import WordEmbedding
 from fms.modules.feedforward import GatedLinearUnit
 from fms.modules.layernorm import LayerNormParameterized
 
 
 # for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64
-def param_init_function(module):
-    if (
-        isinstance(module, MultiHeadAttention)
-        or isinstance(module, WordEmbedding)
-        or isinstance(module, GatedLinearUnit)
-        or isinstance(module, LayerNormParameterized)
-    ):
+def param_init_function(module, cfg):
+    scales = {
+        MultiHeadAttention: cfg.mup_attn_init,
+        QKV: cfg.mup_attn_init,
+        GatedLinearUnit: cfg.mup_ffn_init,
+        WordEmbedding: (cfg.mup_1d_init, cfg.mup_emb_scale, cfg.mup_head_scale),
+        LayerNormParameterized: 1,
+    }
+    scale_keys = list(scales.keys())
+    scale_vals = list(scales.values())
+    type_id = [isinstance(module, x) for x in scale_keys]
+    is_resettable = sum(type_id)
+    if is_resettable:
+        module_type_id = type_id.index(True)
         module.to_empty(device=torch.cuda.current_device())
         with torch.no_grad():
-            module.reset_parameters()
+            module.reset_parameters(scale=scale_vals[module_type_id])
diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py
index 9d3f0386..23e300fa 100644
--- a/fms_fsdp/utils/config_utils.py
+++ b/fms_fsdp/utils/config_utils.py
@@ -115,7 +115,22 @@ def get_model_config(model_variant):
             hidden_grow_factor=3.5,
             max_expected_seq_len=4096,
         )
+    elif model_variant == "llama3_194m_4k":
+        llama_config = LLaMAConfig(
+            src_vocab_size=128256,
+            emb_dim=1024,
+            nheads=8,
+            nlayers=10,
+            max_expected_seq_len=4096,
+        )
     else:
         raise ValueError(f"model variant {model_variant} not supported.")
 
     return llama_config
+
+
+def set_mup_from_cfg(job_cfg, model_cfg):
+    fields = {k: v for k, v in vars(job_cfg).items() if "mup" in k and v > 0}
+    for f in fields:
+        setattr(model_cfg, f, fields[f])
+    return model_cfg
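The hunks above add the muP knobs to `train_config`, thread a config object into `param_init_function`, and copy the knobs onto the model config via `set_mup_from_cfg`. A minimal sketch of that last step, using hypothetical stand-in dataclasses rather than the real `train_config`/`LLaMAConfig` (field values are illustrative): only `mup*` fields raised above their `0` default are propagated, so unset knobs keep whatever the model config already defines.

```python
# Sketch of set_mup_from_cfg with stand-in dataclasses (JobCfg/ModelCfg are hypothetical,
# not the real train_config/LLaMAConfig; the numbers are illustrative only).
from dataclasses import dataclass


@dataclass
class JobCfg:
    learning_rate: float = 3e-4
    mup_emb_scale: float = 10.0  # explicitly set by the user -> propagated
    mup_attn_init: float = 0     # left at the 0 default -> ignored
    mup_2d_lr: float = 1.0


@dataclass
class ModelCfg:
    emb_dim: int = 1024


def set_mup_from_cfg(job_cfg, model_cfg):
    # same filtering logic as the helper added in fms_fsdp/utils/config_utils.py
    fields = {k: v for k, v in vars(job_cfg).items() if "mup" in k and v > 0}
    for f in fields:
        setattr(model_cfg, f, fields[f])
    return model_cfg


model_cfg = set_mup_from_cfg(JobCfg(), ModelCfg())
print(vars(model_cfg))  # {'emb_dim': 1024, 'mup_emb_scale': 10.0, 'mup_2d_lr': 1.0}
```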
diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py
index 9b507e6b..80db6362 100644
--- a/fms_fsdp/utils/train_utils.py
+++ b/fms_fsdp/utils/train_utils.py
@@ -77,6 +77,7 @@ def train(
 
     start = time.time()
     loop_start = time.time()
+    train_loss = -1
     for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
         if batch_idx > cfg.num_steps:
             break
@@ -186,7 +187,7 @@ def setup_environ_flags():
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
 
 
-def get_policies(cfg, rank, block):
+def get_policies(cfg, rank, block, model_cfg):
     """Get policies for mixed precision, wrapping, sharding, ac and param init function."""
 
     # mixed precision
@@ -230,7 +231,7 @@ def get_policies(cfg, rank, block):
 
     # param init function
     if cfg.low_cpu_fsdp:
-        param_init_fn = param_init_function
+        param_init_fn = partial(param_init_function, cfg=model_cfg)
     else:
         param_init_fn = None
 
diff --git a/main_training.py b/main_training.py
index 9c4c5e44..6ddf28a6 100644
--- a/main_training.py
+++ b/main_training.py
@@ -3,15 +3,24 @@
 
 import fire
 import torch
+import torch.nn as nn
 import torch.optim as optim
 from fms.models.llama import LLaMA, LLaMABlock
+from fms.modules.attention import MultiHeadAttention
+from fms.modules.embedding import WordEmbedding
+from fms.modules.feedforward import GatedLinearUnit
+from fms.modules.layernorm import LayerNormParameterized
 from torch import distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim.lr_scheduler import LambdaLR
 
 from fms_fsdp import config
 from fms_fsdp.utils.checkpointing_utils import Checkpointer
-from fms_fsdp.utils.config_utils import get_model_config, update_config
+from fms_fsdp.utils.config_utils import (
+    get_model_config,
+    set_mup_from_cfg,
+    update_config,
+)
 from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader
 from fms_fsdp.utils.train_utils import (
     get_policies,
@@ -45,18 +54,9 @@ def main(**kwargs):
     torch.cuda.empty_cache()
     setup_environ_flags()
 
-    # get policy
-    block = LLaMABlock
-    (
-        mixed_precision_policy,
-        wrapping_policy,
-        sharding_strategy_policy,
-        apply_selective_ac,
-        param_init_fn,
-    ) = get_policies(cfg, rank, block)
-
     # get fms model
     llama_config = get_model_config(cfg.model_variant)
+    llama_config = set_mup_from_cfg(cfg, llama_config)
     if cfg.low_cpu_fsdp:
         with torch.device("meta"):
             model = LLaMA(llama_config)
@@ -78,13 +78,23 @@ def main(**kwargs):
     if rank == 0:
         print("Datasets constructed!")
 
+    # get policy
+    block = LLaMABlock
+    (
+        mixed_precision_policy,
+        wrapping_policy,
+        sharding_strategy_policy,
+        apply_selective_ac,
+        param_init_fn,
+    ) = get_policies(cfg, rank, block, llama_config)
+
     # FSDP
     model = FSDP(
         model,
         auto_wrap_policy=wrapping_policy,
         mixed_precision=mixed_precision_policy,
         sharding_strategy=sharding_strategy_policy,
-        use_orig_params=cfg.use_torch_compile,
+        use_orig_params=True,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=True,
         param_init_fn=param_init_fn,
@@ -110,8 +120,40 @@ def main(**kwargs):
         model = torch.compile(model)
 
     # Optimizer
+    params_0d = [p for name, p in model.named_parameters() if "bias" in name] + [
+        m.weight for m in model.modules() if isinstance(m, LayerNormParameterized)
+    ]
+    params_1d = []
+    params_2d = []
+    for m in model.modules():
+        if isinstance(m, WordEmbedding):
+            params_1d.append(m.emb.weight)
+            if m.abs_pos:
+                params_1d.append(m.pos_emb.weight)
+            if m.reversible and not m.tie_weights:
+                params_1d.append(m.head.weight)
+        elif isinstance(m, MultiHeadAttention):
+            params_2d += [
+                m.dense.weight,
+            ] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)]
+        elif isinstance(m, GatedLinearUnit):
+            params_2d += [m.wg1_fused.weight, m.w2.weight]
     optimizer = optim.AdamW(
-        model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
+        [
+            {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr},
+            {
+                "params": params_1d,
+                "lr": cfg.learning_rate
+                * llama_config.mup_1d_lr
+                / llama_config.emb_dim**0.5,
+            },
+            {
+                "params": params_2d,
+                "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim,
+            },
+        ],
+        betas=(0.9, 0.95),
+        weight_decay=0.1,
     )
 
     # optionally load from checkpoint (when continue pretraining)
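The optimizer hunk above splits the parameters into three muP groups: biases and LayerNorm weights (0d) keep the base learning rate, embedding and untied head weights (1d) are scaled by 1/sqrt(emb_dim), and attention/FFN matrices (2d) by 1/emb_dim, each further multiplied by its `mup_*_lr` knob. A quick numeric sketch (not part of the patch) for `emb_dim=1024`, the new `llama3_194m_4k` variant, with illustrative multiplier values:

```python
# Effective per-group learning rates under the muP scaling above. The mup_*_lr values
# here are illustrative placeholders; in a real run they come from the train_config knobs.
learning_rate = 3e-4
emb_dim = 1024                                     # llama3_194m_4k
mup_0d_lr, mup_1d_lr, mup_2d_lr = 1.0, 2.0, 4.0    # illustrative only

lr_0d = learning_rate * mup_0d_lr                  # biases, LayerNorm weights
lr_1d = learning_rate * mup_1d_lr / emb_dim**0.5   # embeddings (and untied head)
lr_2d = learning_rate * mup_2d_lr / emb_dim        # attention and FFN matrices

print(f"0d {lr_0d:.2e}  1d {lr_1d:.2e}  2d {lr_2d:.2e}")
# roughly 3.0e-04, 1.9e-05, and 1.2e-06 respectively
```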
@@ -122,8 +164,16 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/"),
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
+        strict=False,
     )
+    if cfg.reset_stepcount:
+        start_step = 0
+        # Override loaded optim hyperparams with the current values
+        for g in optimizer.param_groups:
+            g["initial_lr"] = cfg.learning_rate
 
     # LR schedule
     warmup_interval = min(2000, cfg.num_steps // 20)
@@ -156,6 +206,8 @@ def main(**kwargs):
         tokens_seen,
     )
 
+    checkpointer.save_single_file(cfg.num_steps, model)
+
     dist.barrier()
     dist.destroy_process_group()
 
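On resume, `reset_stepcount` zeroes `start_step` and overwrites each param group's `initial_lr` before the LR schedule is built. A standalone sketch of why that override matters, using a plain `nn.Linear` and a toy warmup lambda instead of the FSDP-wrapped LLaMA and the repo's schedule: PyTorch LR schedulers take their base LRs from the `initial_lr` entry of a param group when one is already present, so a freshly constructed `LambdaLR` follows the current `learning_rate` rather than the value restored with the checkpointed optimizer.

```python
# Sketch: overriding "initial_lr" before building the scheduler makes base_lrs follow the
# new value (the model, the 1e-4 "checkpointed" lr, and the warmup lambda are toy stand-ins).
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # pretend 1e-4 was loaded from an old run

for g in optimizer.param_groups:
    g["initial_lr"] = 3e-4  # override with the current job's learning_rate

scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / 10))
print(scheduler.base_lrs)               # [0.0003] -- the override, not the "checkpointed" 1e-4
print(optimizer.param_groups[0]["lr"])  # ~3e-05 after the scheduler's initial warmup step
```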