45 commits
cd9c74b
add peft configs to fast moe augmentation
willmj Mar 24, 2025
f99ae71
fmt
willmj Mar 25, 2025
7b453cd
fix: lora constants
willmj Mar 26, 2025
c3e6a48
fix: check
willmj Mar 27, 2025
57f2a37
feat: lora case (draft)
willmj Mar 28, 2025
ca863d1
fix: revert min count
willmj Apr 2, 2025
a26db68
fix: regex
willmj Apr 2, 2025
da4ccf7
feat: lora in fsdp utils save
willmj Apr 2, 2025
25e9155
fix: lora keys to map to original dict
willmj Apr 2, 2025
6ae1e1c
feat: handle lora A and B for converting checkpoint
willmj Apr 3, 2025
866c2e0
fix: scatter keys fqdn -> scatter keys
willmj Apr 4, 2025
cbb222b
fix: save for adapter model (draft)
willmj Apr 4, 2025
86f9d8b
fix: associate w1, w2, w3 lora keys to input output linear lora layers
willmj Apr 8, 2025
a5bdea2
fix: comment
willmj Apr 8, 2025
5870dae
Merge branch 'main' into lora-fast-moe-v1
willmj Apr 9, 2025
6b030aa
fix: block off lora on w1, w2, w3
willmj Apr 9, 2025
587f2a4
fix: if condition flip
willmj Apr 9, 2025
f247d26
fix: ignore weights for lora
willmj Apr 9, 2025
f2bb29f
fix: if lora in scattermoe prepare, don't put weights in map
willmj Apr 9, 2025
ebbcd57
fix: modules
willmj Apr 9, 2025
4a838ca
fix: with lora self w1 and w2 are not gauranteed to exist
willmj Apr 9, 2025
582379b
fix: if w1, w2, w3 exist
willmj Apr 9, 2025
eb60537
fix: mapping to be router.layer
willmj Apr 9, 2025
5e1bb52
fix: lora condition
willmj Apr 9, 2025
8b0144a
fix: pass lora into infer
willmj Apr 9, 2025
f8e83a2
lora utils
willmj Apr 9, 2025
227c8b9
fix: .layer
willmj Apr 9, 2025
57a6818
add .layer
willmj Apr 9, 2025
809a917
fix: update state dict when loaded with lora before operations
willmj Apr 9, 2025
a0887cd
fix: remove duplicative code
willmj Apr 9, 2025
b754a92
fix: use new state dict
willmj Apr 9, 2025
1ef1ebb
lint + fmt
willmj Apr 9, 2025
b185ebf
fix: trailing whitespacE
willmj Apr 9, 2025
6f3852f
fix: target modules dictate which scatterMoE layers are trained
willmj Apr 10, 2025
8162d99
lint + fmt
willmj Apr 10, 2025
8ce09d8
fix: cleanup
willmj Apr 10, 2025
4d05135
lint
willmj Apr 10, 2025
2d959fc
fmt
willmj Apr 11, 2025
1eea165
fix: type
willmj Apr 11, 2025
bff9b04
fix: default target modules
willmj Apr 11, 2025
07d38b0
fix: ft logic
willmj Apr 11, 2025
f9176c5
fix: fmt + lint
willmj Apr 11, 2025
d98b2c9
fix: logic for lora
willmj Apr 11, 2025
570bf34
fix: logic + docs
willmj Apr 11, 2025
4cd4288
fix: mistype
willmj Apr 11, 2025
2 changes: 1 addition & 1 deletion plugins/accelerated-moe/.pylintrc
@@ -476,7 +476,7 @@ notes-rgx=
[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=5
max-nested-blocks=6

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
7 changes: 4 additions & 3 deletions plugins/accelerated-moe/README.md
@@ -8,7 +8,7 @@ This library contains plugins to accelerate finetuning with the following optimi

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | ✅ | | ✅
[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | | ✅ | ✅


## Adding New Models
@@ -33,6 +33,8 @@ python -m fms_acceleration_moe.utils.checkpoint_utils \
mistralai/Mixtral-8x7B-Instruct-v0.1
```

If running with fms-hf-tuning, this script runs automatically if the `fast_moe` parameter is set.
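For reference, the same conversion can also be driven from Python. A minimal sketch, assuming `recover_safetensors_from_dcp` (defined in `checkpoint_utils.py` below) takes the same three values as the CLI; the argument order and the paths are illustrative and should be verified against the script:

```python
# Sketch only: argument order is assumed to mirror the CLI invocation above
# (DCP checkpoint dir, output dir, original model name); paths are made up.
from fms_acceleration_moe.utils.checkpoint_utils import recover_safetensors_from_dcp

recover_safetensors_from_dcp(
    "hf/checkpoint-10/pytorch_model_fsdp_0",  # DCP checkpoint written during training
    "converted-checkpoint",                   # destination for sharded safetensors
    "mistralai/Mixtral-8x7B-Instruct-v0.1",   # original architecture to recover into
)
```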

## Code Extracted from Megablocks

Notes on code extraction:
@@ -81,9 +83,8 @@ Triton Kernels are copied into [scattermoe_utils](./src/fms_acceleration_moe/uti
### Known Issues

These are currently some known issues not yet resolved:
- should eventually remove the dependency on an external `kernel-hyperdrive` repository.
- currently, only *sharded* `safetensor` (non-GGUF) MoE checkpoints are supported for loading. This is a reasonable assumption, since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
- when used together with FSDP, the FSDP's `clip_grad_norm` will not properly compute for `ScatterMoE`, see [issue here](https://github.com/foundation-model-stack/fms-acceleration/issues/109).
- when used to LoRA train a model, if the experts are trained in the adapter model, the model will fail to run inference in vLLM/vanilla HF because of restrictions on parameter types. If you intend to run inference, do not select `input_linear` and `output_linear` as target modules when LoRA training; see the sketch below.
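A minimal sketch of a LoRA configuration that respects this restriction, assuming standard `peft` usage; the attention projection names are illustrative (Mixtral-style) and depend on the model being tuned:

```python
# Sketch only: avoid the expert layers (`input_linear`, `output_linear`) so the
# resulting adapter can still be served by vLLM / vanilla HF.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    # attention projections only; `router` may also be added, but the expert
    # `input_linear` / `output_linear` modules are deliberately left out
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
```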



plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -77,6 +77,7 @@ def augmentation(
modifiable_args: Tuple[LoraConfig],
):
rank, world_size = 0, 1
(peft_config,) = modifiable_args
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
# we do not need to use the fallback as this is wrapped in an `is_initialized` block
@@ -97,6 +98,7 @@
ep_degree=self._ep_degree,
disable_distributed=self._disable_distributed,
mixed_precision=False, # Currently this is hardcoded to OFF
lora_config=peft_config,
)
return model, modifiable_args

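To make the data flow concrete: `augmentation` unpacks the single `LoraConfig` from `modifiable_args` and forwards it as `lora_config` to the ScatterMoE preparation call shown above. A small self-contained sketch of just that unpacking; the config values are illustrative:

```python
# Sketch only: mirrors `(peft_config,) = modifiable_args` from the diff above.
from peft import LoraConfig

modifiable_args = (LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"]),)

# augmentation() pulls the config out of the one-element tuple ...
(peft_config,) = modifiable_args

# ... and passes it through as `lora_config=peft_config` when preparing ScatterMoE
print(peft_config.r, peft_config.target_modules)
```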
plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -53,6 +53,10 @@
KEY_MODEL = "model"
KEY_OPTIMIZER = "optimizer"

ADAPTER_CONFIG_NAME = "adapter_config.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"

# Below are rewrite of HF FSDP model saving functions to be able to handle
# that the parameters are now a mixture of regular and Dtensors.
# - these functions are found in accelerate.utils.fsdp_utils.py
@@ -110,15 +114,30 @@ def save_fsdp_optimizer(
# get the state dicts for model and optimizer
(model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)

# - save model
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: model_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)
# filter out lora state dict
lora_state_dict = {
k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
}

# - save model
if lora_state_dict:
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving lora model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: lora_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)
else:
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving ft model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: model_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)
logger.info(f"Model saved to {ckpt_model}")

# - save optimizer
@@ -303,6 +322,7 @@ def get_state_dict_from_safe_checkpoint(safe_checkpoint_dir: str):
# can restore the checkpoint to be loaded by the original architecture.
def recover_original_state_dict_from_checkpoint(
sd: Dict,
lora: bool,
pretrained_model_name_or_path: str = None,
):
"""
@@ -323,6 +343,19 @@
# config
config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)

# if lora, check for input/output layers
ip_op_layers = False
router_layer = False
if lora:
for name, _ in sd.items():
if "w1" in name:
ip_op_layers = True
break
for name, _ in sd.items():
if "router" in name:
router_layer = True
break

(
_,
router_name,
@@ -340,11 +373,13 @@

def _infer_prefixes_and_module_names(
sd_keys: List[str],
min_count: int = 3,
lora: bool,
):
min_count = 2 if lora else 3

_name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
# pylint: disable=anomalous-backslash-in-string
_reg = re.compile(f"(.*)\.({_name})\.weight")
_reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A|lora_B)")
found = {}

for k in sd_keys:
@@ -364,7 +399,7 @@

return results

for prefix in _infer_prefixes_and_module_names(sd.keys()):
for prefix in _infer_prefixes_and_module_names(sd.keys(), lora):
prefix = prefix.split(".")
prefix, module_name = ".".join(prefix[:-1]), prefix[-1]

@@ -390,6 +425,8 @@ def _infer_prefixes_and_module_names(
module_name,
router_name,
expert_name,
ip_op_layers=ip_op_layers,
router_layer=router_layer,
)

model2scatter = defaultdict(dict)
@@ -398,6 +435,7 @@ def _infer_prefixes_and_module_names(
# model param and they need to be cat
for scatter_key, list_of_params in checkpoint_metadata.items():
scatter_key_fqdn = ".".join([prefix, module_name, scatter_key])

scatter_param = sd[scatter_key_fqdn]

# remove from state dict
@@ -443,8 +481,36 @@ def _infer_prefixes_and_module_names(
len(scatter_keys) > 0
), f"Obtained zero scatter keys for model_key '{model_key}'"

if len(scatter_keys) == 1:
if lora:
for i, lora_key in enumerate(scatter_keys):
model_key_parts = model_key.split(".")
weight_index = model_key_parts.index("weight")
# Replace the "layer.weight" part with "layer.lora_A.weight" or
# "layer.lora_B.weight"
if "lora_A" in lora_key:
model_key_parts[weight_index] = "lora_A.weight"
elif "lora_B" in lora_key:
model_key_parts[weight_index] = "lora_B.weight"
# Rebuild the model_key and assign the corresponding scatter_param
new_model_key = ".".join(model_key_parts)
if len(scatter_keys) == 2:
sd[new_model_key] = scatter_params[lora_key]
else:
if "lora_A" in new_model_key:
filtered_keys = [k for k in scatter_keys if "lora_A" in k]
elif "lora_B" in new_model_key:
filtered_keys = [k for k in scatter_keys if "lora_B" in k]
else:
raise ValueError(
f"Unexpected LoRA key type in {new_model_key}"
)
sd[new_model_key] = torch.cat(
[scatter_params[k] for k in filtered_keys], dim=1
)

elif len(scatter_keys) == 1:
sd[model_key] = scatter_params[scatter_keys[0]]

else:
# unfortunately, there this is a in
# scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights
@@ -466,31 +532,56 @@ def save_sharded_safetensors(
input_state_dict: Dict,
save_directory: str,
metadata: Dict,
lora: bool,
max_shard_size: Union[int, str] = "5GB",
):
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(
os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
) as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
if not lora:
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)

index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(
os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
) as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {
tensor: input_state_dict[tensor].contiguous() for tensor in tensors
}
save_file(
shard, os.path.join(save_directory, shard_file), metadata=metadata
)
else:
filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(
".bin", "{suffix}.bin"
).replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {
tensor: input_state_dict[tensor].contiguous() for tensor in tensors
}
save_file(
shard, os.path.join(save_directory, shard_file), metadata=metadata
)


# --------------------------- SCRIPT -------------------------
@@ -540,14 +631,28 @@ def recover_safetensors_from_dcp(
# get the state_dict
state_dict = loader(checkpoint_dir)

lora = False
new_state_dict = {}
for name, param in state_dict.items():
if "lora_A" in name or "lora_B" in name:
lora = True
if "base_model.model." in name:
name = name.replace("base_model.model.", "", 1)
if "default." in name:
name = name.replace("default.", "", 1)
new_state_dict[name] = param

# recover the original state dict
state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
state_dict = recover_original_state_dict_from_checkpoint(
new_state_dict, lora, _name_or_path
)

# save it as a safetensors file
save_sharded_safetensors(
{k: v.contiguous() for k, v in state_dict.items()},
output_dir,
metadata={"format": "pt"},
lora=lora,
)


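To make the LoRA-aware key handling above more concrete, a small self-contained illustration of the relaxed pattern in `_infer_prefixes_and_module_names` and of the `lora_A`/`lora_B` filtering used on the FSDP save path. Assumptions: `PARAM_NAME_ROUTER_SCATTERMOE` and `PARAM_NAME_WEIGHT_SCATTERMOE` resolve to `router` and `w1`/`w2`/`w3`, and the example key names are made up:

```python
# Illustrative only: which keys the lora-aware regex picks up, plus the
# adapter-only filter applied when saving the FSDP model state dict.
import re

_name = "|".join(["router", "w1", "w2", "w3"])  # assumed values of the PARAM_NAME_* constants
_reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A|lora_B)")

example_keys = [
    "model.layers.0.block_sparse_moe.router.weight",
    "model.layers.0.block_sparse_moe.w1.lora_A.weight",
    "model.layers.0.block_sparse_moe.w2.lora_B.weight",
    "model.layers.0.self_attn.q_proj.lora_A.weight",  # not a ScatterMoE parameter
]

for key in example_keys:
    m = _reg.match(key)
    print(key, "->", (m.group(1), m.group(2)) if m else "no match")

# the FSDP save path keeps only adapter tensors once LoRA keys are detected
lora_keys = [k for k in example_keys if "lora_A" in k or "lora_B" in k]
print(lora_keys)
```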