50 commits
da29217
Set llama2-1.4b to gqa
daviswer May 22, 2024
41ae740
Add singlefile ckp saving/conversion
daviswer May 28, 2024
5171b5d
Turn off GQA on 1.4B
daviswer May 31, 2024
abd5b19
GQA on, add for 7b
daviswer Jun 5, 2024
8c31a0c
Merge branch 'foundation-model-stack:main' into main
daviswer Jun 6, 2024
0ac0a5f
Add llama3 tele cfg
daviswer Jun 10, 2024
8caeaa2
Add missing paren
daviswer Jun 10, 2024
941e98f
Back to gqa4 for llama3
daviswer Jun 13, 2024
44edc0d
Nonstrict ckpt load
daviswer Jun 19, 2024
a48a055
If singlefile load, don't append "checkpoints" folder
daviswer Jun 19, 2024
9031328
Add reset stepcount field
daviswer Jun 21, 2024
0e3430a
Add reset stepcount support
daviswer Jun 21, 2024
45d7e41
Override optimizer LR values with desired
daviswer Jun 21, 2024
9cb0329
gqa16
daviswer Jun 24, 2024
756c3ee
GOTHERE
daviswer Jun 24, 2024
fd28fb7
No gothere
daviswer Jun 24, 2024
ffded35
Nonstrict fsdp load
daviswer Jun 25, 2024
1050d1d
Nonstrict fsdp load pt2
daviswer Jun 25, 2024
166c01d
Stop nonstrict fsdp load
daviswer Jun 25, 2024
fee4c48
Separate gqa4 and 16 cfgs
daviswer Jul 1, 2024
6f3fd09
Fix indent
daviswer Jul 1, 2024
f5a707e
Add mini llama cfg
daviswer Jul 16, 2024
57e3ffd
mini llama3 vsize
daviswer Jul 16, 2024
b2e6ae0
Add muP fields, auto-update model cfg
daviswer Jul 17, 2024
4a02c82
Add mup scaling to fsdp init params
daviswer Jul 18, 2024
22c54a6
Only set mup cfg if >0
daviswer Jul 18, 2024
af52614
1d init mup
daviswer Jul 19, 2024
57ed6f9
Attempt mup lrs
daviswer Jul 19, 2024
372e1d2
cleanup, typofix
daviswer Jul 19, 2024
c0d1d1f
diag print
daviswer Jul 19, 2024
2017a98
Non double list comp
daviswer Jul 19, 2024
9a77a2b
diag print
daviswer Jul 19, 2024
6c01a0b
Stop named params
daviswer Jul 19, 2024
101652b
List sum
daviswer Jul 19, 2024
49341e1
diag print
daviswer Jul 19, 2024
58c1662
diag print
daviswer Jul 19, 2024
a14f57e
diag print
daviswer Jul 19, 2024
5c8d8c4
diag print
daviswer Jul 19, 2024
d0e4888
diag print
daviswer Jul 19, 2024
e9701a1
Iterate over submodules explicitly
daviswer Jul 19, 2024
0c46c3a
linear submods only
daviswer Jul 19, 2024
58ce680
diag print
daviswer Jul 19, 2024
39c5832
diag print
daviswer Jul 19, 2024
476dca5
Use orig params
daviswer Jul 19, 2024
a11abf7
Remove default lr arg
daviswer Jul 19, 2024
f2c5590
Enlist param groups
daviswer Jul 19, 2024
63a834a
divide by mup scales
daviswer Jul 19, 2024
5887896
Remove tele configs
daviswer Jul 22, 2024
4dd3998
Don't change Llama2 small configs
daviswer Jul 22, 2024
1491706
linting
daviswer Jul 22, 2024
11 changes: 11 additions & 0 deletions fms_fsdp/config/training.py
@@ -36,6 +36,7 @@ class train_config:
learning_rate: float = 3e-4
grad_clip_thresh: float = 1.0
seed: int = 2023
reset_stepcount: bool = False

# profiling
use_profiler: bool = False
@@ -51,3 +52,13 @@ class train_config:

# compile
use_torch_compile: bool = True

# muP scale params
mup_emb_scale: float = 0
mup_head_scale: float = 0
mup_ffn_init: float = 0
mup_attn_init: float = 0
mup_attn_temp: float = 0
mup_0d_lr: float = 0
mup_1d_lr: float = 0
mup_2d_lr: float = 0
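These fields default to 0 and are only applied when set to a positive value (see `set_mup_from_cfg` below); the three `mup_*_lr` fields are consumed as per-group learning-rate multipliers in `main_training.py`. A hedged sketch of a job config that enables them, assuming `train_config` is a plain dataclass as the field syntax above suggests (values are illustrative, not tuned):

```python
# Illustrative only: the field names come from this diff, the values are made up.
from fms_fsdp import config

cfg = config.train_config(
    model_variant="llama3_194m_4k",  # new variant added in this PR
    learning_rate=3e-4,
    mup_emb_scale=10.0,   # copied onto the FMS model config by set_mup_from_cfg
    mup_attn_temp=1.0,
    mup_0d_lr=1.0,        # per-group LR multipliers used when building the optimizer
    mup_1d_lr=1.0,
    mup_2d_lr=1.0,
)
```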
25 changes: 16 additions & 9 deletions fms_fsdp/policies/param_init.py
@@ -1,18 +1,25 @@
import torch
from fms.modules.attention import MultiHeadAttention
from fms.modules.attention import QKV, MultiHeadAttention
from fms.modules.embedding import WordEmbedding
from fms.modules.feedforward import GatedLinearUnit
from fms.modules.layernorm import LayerNormParameterized


# for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64
def param_init_function(module):
if (
isinstance(module, MultiHeadAttention)
or isinstance(module, WordEmbedding)
or isinstance(module, GatedLinearUnit)
or isinstance(module, LayerNormParameterized)
):
def param_init_function(module, cfg):
scales = {
MultiHeadAttention: cfg.mup_attn_init,
QKV: cfg.mup_attn_init,
GatedLinearUnit: cfg.mup_ffn_init,
WordEmbedding: (cfg.mup_1d_init, cfg.mup_emb_scale, cfg.mup_head_scale),
LayerNormParameterized: 1,
}
scale_keys = list(scales.keys())
scale_vals = list(scales.values())
type_id = [isinstance(module, x) for x in scale_keys]
is_resettable = sum(type_id)
if is_resettable:
module_type_id = type_id.index(True)
module.to_empty(device=torch.cuda.current_device())
with torch.no_grad():
module.reset_parameters()
module.reset_parameters(scale=scale_vals[module_type_id])
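The initializer keeps the existing meta-device pattern: FSDP materializes each wrapped module via `to_empty()` and then calls `reset_parameters()`, now with a type-dependent muP `scale` argument (it is wired in through `partial(param_init_function, cfg=model_cfg)` in `get_policies` below). A self-contained sketch of that materialization pattern, using a plain `nn.Linear` as a stand-in for the FMS modules:

```python
# Minimal sketch of meta-device init; nn.Linear stands in for the FMS modules,
# whose reset_parameters() is assumed to accept the muP scale= keyword used above.
import torch
import torch.nn as nn

with torch.device("meta"):
    lin = nn.Linear(16, 16)          # parameters exist but have no real storage
assert lin.weight.is_meta

device = "cuda" if torch.cuda.is_available() else "cpu"
lin.to_empty(device=device)          # allocate uninitialized storage on the target device
with torch.no_grad():
    lin.reset_parameters()           # nn.Linear's own reset; the FMS version takes scale=...
assert not lin.weight.is_meta
```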
15 changes: 15 additions & 0 deletions fms_fsdp/utils/config_utils.py
@@ -115,7 +115,22 @@ def get_model_config(model_variant):
hidden_grow_factor=3.5,
max_expected_seq_len=4096,
)
elif model_variant == "llama3_194m_4k":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=1024,
nheads=8,
nlayers=10,
max_expected_seq_len=4096,
)
else:
raise ValueError(f"model variant {model_variant} not supported.")

return llama_config


def set_mup_from_cfg(job_cfg, model_cfg):
fields = {k: v for k, v in vars(job_cfg).items() if "mup" in k and v > 0}
for f in fields:
setattr(model_cfg, f, fields[f])
return model_cfg
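`set_mup_from_cfg` copies only the `mup_*` fields that were explicitly enabled; anything left at its default of 0 is skipped, so the model config keeps its own defaults. A self-contained check of that filtering with stand-in dataclasses (the real call sites use `train_config` and FMS's `LLaMAConfig`):

```python
# Stand-in configs for illustration; only the filtering behavior is being shown.
from dataclasses import dataclass


@dataclass
class JobCfg:                    # stands in for train_config
    learning_rate: float = 3e-4
    mup_emb_scale: float = 10.0  # explicitly set -> copied
    mup_head_scale: float = 0.0  # left at default -> skipped


@dataclass
class ModelCfg:                  # stands in for LLaMAConfig
    mup_emb_scale: float = 1.0
    mup_head_scale: float = 1.0


def set_mup_from_cfg(job_cfg, model_cfg):  # same body as the function above
    fields = {k: v for k, v in vars(job_cfg).items() if "mup" in k and v > 0}
    for f in fields:
        setattr(model_cfg, f, fields[f])
    return model_cfg


m = set_mup_from_cfg(JobCfg(), ModelCfg())
assert m.mup_emb_scale == 10.0 and m.mup_head_scale == 1.0
```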
5 changes: 3 additions & 2 deletions fms_fsdp/utils/train_utils.py
@@ -77,6 +77,7 @@ def train(

start = time.time()
loop_start = time.time()
train_loss = -1
for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
if batch_idx > cfg.num_steps:
break
@@ -186,7 +187,7 @@ def setup_environ_flags():
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)


def get_policies(cfg, rank, block):
def get_policies(cfg, rank, block, model_cfg):
"""Get policies for mixed precision, wrapping, sharding, ac and param init function."""

# mixed precision
@@ -230,7 +231,7 @@ def get_policies(cfg, rank, block):

# param init function
if cfg.low_cpu_fsdp:
param_init_fn = param_init_function
param_init_fn = partial(param_init_function, cfg=model_cfg)
else:
param_init_fn = None

80 changes: 66 additions & 14 deletions main_training.py
@@ -3,15 +3,24 @@

import fire
import torch
import torch.nn as nn
import torch.optim as optim
from fms.models.llama import LLaMA, LLaMABlock
from fms.modules.attention import MultiHeadAttention
from fms.modules.embedding import WordEmbedding
from fms.modules.feedforward import GatedLinearUnit
from fms.modules.layernorm import LayerNormParameterized
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.lr_scheduler import LambdaLR

from fms_fsdp import config
from fms_fsdp.utils.checkpointing_utils import Checkpointer
from fms_fsdp.utils.config_utils import get_model_config, update_config
from fms_fsdp.utils.config_utils import (
get_model_config,
set_mup_from_cfg,
update_config,
)
from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader
from fms_fsdp.utils.train_utils import (
get_policies,
@@ -45,18 +54,9 @@ def main(**kwargs):
torch.cuda.empty_cache()
setup_environ_flags()

# get policy
block = LLaMABlock
(
mixed_precision_policy,
wrapping_policy,
sharding_strategy_policy,
apply_selective_ac,
param_init_fn,
) = get_policies(cfg, rank, block)

# get fms model
llama_config = get_model_config(cfg.model_variant)
llama_config = set_mup_from_cfg(cfg, llama_config)
if cfg.low_cpu_fsdp:
with torch.device("meta"):
model = LLaMA(llama_config)
@@ -78,13 +78,23 @@
if rank == 0:
print("Datasets constructed!")

# get policy
block = LLaMABlock
(
mixed_precision_policy,
wrapping_policy,
sharding_strategy_policy,
apply_selective_ac,
param_init_fn,
) = get_policies(cfg, rank, block, llama_config)

# FSDP
model = FSDP(
model,
auto_wrap_policy=wrapping_policy,
mixed_precision=mixed_precision_policy,
sharding_strategy=sharding_strategy_policy,
use_orig_params=cfg.use_torch_compile,
use_orig_params=True,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
param_init_fn=param_init_fn,
@@ -110,8 +120,40 @@
model = torch.compile(model)

# Optimizer
params_0d = [p for name, p in model.named_parameters() if "bias" in name] + [
m.weight for m in model.modules() if isinstance(m, LayerNormParameterized)
]
params_1d = []
params_2d = []
for m in model.modules():
if isinstance(m, WordEmbedding):
params_1d.append(m.emb.weight)
if m.abs_pos:
params_1d.append(m.pos_emb.weight)
if m.reversible and not m.tie_weights:
params_1d.append(m.head.weight)
elif isinstance(m, MultiHeadAttention):
params_2d += [
m.dense.weight,
] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)]
elif isinstance(m, GatedLinearUnit):
params_2d += [m.wg1_fused.weight, m.w2.weight]
optimizer = optim.AdamW(
model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
[
{"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr},
{
"params": params_1d,
"lr": cfg.learning_rate
* llama_config.mup_1d_lr
/ llama_config.emb_dim**0.5,
},
{
"params": params_2d,
"lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim,
},
],
betas=(0.9, 0.95),
weight_decay=0.1,
)

# optionally load from checkpoint (when continue pretraining)
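For reference, the three param groups built above give each class of weights its own learning rate: biases and norm weights (`params_0d`) keep the base LR times `mup_0d_lr`, embedding-type weights (`params_1d`) are additionally divided by `sqrt(emb_dim)`, and attention/FFN matrices (`params_2d`) by `emb_dim`. A worked example with assumed values (multipliers of 1.0, `emb_dim=1024` as in the new `llama3_194m_4k` config):

```python
# Assumed values for illustration; only the scaling rules come from the diff above.
learning_rate = 3e-4
emb_dim = 1024
mup_0d_lr = mup_1d_lr = mup_2d_lr = 1.0

lr_0d = learning_rate * mup_0d_lr                 # biases, LayerNorm weights -> 3.0e-4
lr_1d = learning_rate * mup_1d_lr / emb_dim**0.5  # embeddings, untied head   -> ~9.4e-6
lr_2d = learning_rate * mup_2d_lr / emb_dim       # attention/FFN matrices    -> ~2.9e-7
print(lr_0d, lr_1d, lr_2d)
```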
@@ -122,8 +164,16 @@
model,
optimizer,
None,
path=os.path.join(cfg.ckpt_load_path, "checkpoints/"),
path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
if not os.path.isfile(cfg.ckpt_load_path)
else cfg.ckpt_load_path,
strict=False,
)
if cfg.reset_stepcount:
start_step = 0
# Override loaded optim hyperparams with the current values
for g in optimizer.param_groups:
g["initial_lr"] = cfg.learning_rate

# LR schedule
warmup_interval = min(2000, cfg.num_steps // 20)
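When `reset_stepcount` is set, the step counter restarts and the per-group `initial_lr` values restored with the optimizer checkpoint are overwritten, so the warmup schedule built here appears to ramp toward the freshly configured learning rate rather than the checkpointed one. A self-contained sketch of that assumed interaction (scheduler internals, not code from this PR):

```python
# Sketch of why overriding param_groups[...]["initial_lr"] matters: LambdaLR keeps a
# pre-existing "initial_lr" as its base LR instead of re-reading group["lr"].
import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([p], lr=1e-4)
opt.param_groups[0]["initial_lr"] = 1e-4   # as if restored from a checkpointed optimizer
opt.param_groups[0]["initial_lr"] = 3e-4   # the reset_stepcount override above

sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, (step + 1) / 2000))
assert sched.base_lrs == [3e-4]            # schedule now scales the overridden base LR
```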
@@ -156,6 +206,8 @@ def main(**kwargs):
tokens_seen,
)

checkpointer.save_single_file(cfg.num_steps, model)

dist.barrier()
dist.destroy_process_group()
