From cd9c74b79041902c6fc5947aaea6df9992d06cf1 Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Mon, 24 Mar 2025 14:57:05 -0400
Subject: [PATCH 01/44] add peft configs to fast moe augmentation

Signed-off-by: Will Johnson
---
 .../src/fms_acceleration_moe/framework_plugin_scattermoe.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 1bb33871..88ef83e5 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -59,6 +59,7 @@ def augmentation(
         modifiable_args: Tuple[LoraConfig],
     ):
         rank, world_size = 0, 1
+        (peft_config,) = modifiable_args
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             # we do not need to use the fallback as this is wrapped in an `is_initialized` block
@@ -78,6 +79,7 @@ def augmentation(
             world_size=world_size,
             ep_degree=self._ep_degree,
             mixed_precision=False, # Currently this is hardcoded to OFF
+            lora_config=peft_config
         )

         return model, modifiable_args

From f99ae710821b28869c1f7d4acdb8f0c045e8d17e Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Tue, 25 Mar 2025 11:17:41 -0400
Subject: [PATCH 02/44] fmt

Signed-off-by: Will Johnson
---
 .../src/fms_acceleration_moe/framework_plugin_scattermoe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 88ef83e5..b2acc96f 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -79,7 +79,7 @@ def augmentation(
             world_size=world_size,
             ep_degree=self._ep_degree,
             mixed_precision=False, # Currently this is hardcoded to OFF
-            lora_config=peft_config
+            lora_config=peft_config,
         )

         return model, modifiable_args

From 7b453cde43501065ae4f6eebcabbe426015a7cc2 Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 26 Mar 2025 15:34:35 -0400
Subject: [PATCH 03/44] fix: lora constants

Signed-off-by: Will Johnson
---
 .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +-
 .../src/fms_acceleration_moe/utils/scattermoe_constants.py | 4 +++-
 .../src/fms_acceleration_moe/utils/scattermoe_prepare.py | 4 +++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index e6fe1ba6..73d00975 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -344,7 +344,7 @@ def _infer_prefixes_and_module_names(
 ):
     _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
     # pylint: disable=anomalous-backslash-in-string
-    _reg = re.compile(f"(.*)\.({_name})\.weight")
+    _reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A\.weight|lora_B\.weight)")

     found = {}
     for k in sd_keys:
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py
index 2a6847be..a6b95541 100644
---
a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py @@ -24,7 +24,9 @@ KEY_EXPERT_PARALLEL = "expert_parallel" DIM_EXPERT = 0 -KEY_SCATTERMOE_ROUTER = PARAM_NAME_ROUTER_SCATTERMOE + ".weight" +KEY_SCATTERMOE_ROUTER = "router.weight" +KEY_SCATTERMOE_LORA_A_ROUTER = "router.lora_A.weight" +KEY_SCATTERMOE_LORA_B_ROUTER = "router.lora_B.weight" # Currently out ScatterMoE drop supports an up/down proj, and # and optional gate_proj. diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 348c7b77..888f3487 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -34,6 +34,8 @@ KEY_EXPERT_PARALLEL, KEY_REPLICATE, KEY_SCATTERMOE_ROUTER, + KEY_SCATTERMOE_LORA_A_ROUTER, + KEY_SCATTERMOE_LORA_B_ROUTER, get_scattermoe_conv_spec_from_archs, ) from .scattermoe_state_dict import ( @@ -66,7 +68,7 @@ def _hook(grad): for weight_name, param in state_dict.items(): - if KEY_SCATTERMOE_ROUTER in weight_name: + if KEY_SCATTERMOE_ROUTER in weight_name or KEY_SCATTERMOE_LORA_A_ROUTER in weight_name or KEY_SCATTERMOE_LORA_B_ROUTER in weight_name: # if its the router, replicate param = distribute_tensor(param, device_mesh, reps + [Replicate()]) elif param.shape[0] > num_experts_per_device: From c3e6a48060a6fffd0adced60dd994d7c12b3340c Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 27 Mar 2025 16:09:21 -0400 Subject: [PATCH 04/44] fix: check Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index e13f6ba5..118795a1 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -27,6 +27,8 @@ from .scattermoe_constants import ( DIM_EXPERT, KEY_SCATTERMOE_ROUTER, + KEY_SCATTERMOE_LORA_A_ROUTER, + KEY_SCATTERMOE_LORA_B_ROUTER, PARAM_NAME_WEIGHT_SCATTERMOE, ) @@ -295,7 +297,7 @@ def get_state_dict_from_checkpoint_metadata( # go by one weight at a time. 
for scatter_key, vs in checkpoint_metadata.items(): - if KEY_SCATTERMOE_ROUTER in scatter_key: + if KEY_SCATTERMOE_ROUTER in scatter_key or KEY_SCATTERMOE_LORA_A_ROUTER in scatter_key or KEY_SCATTERMOE_LORA_B_ROUTER in scatter_key: k, fi = vs[0] # only one item param = files[fi].get_tensor(k) From 57f2a37838ffa11330ab8cd0730ae8c10d232da0 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 28 Mar 2025 12:57:18 -0400 Subject: [PATCH 05/44] feat: lora case (draft) Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 3 +- .../utils/scattermoe_prepare.py | 4 +-- .../utils/scattermoe_state_dict.py | 31 ++++++++++++++----- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 73d00975..cafad381 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -340,7 +340,7 @@ def recover_original_state_dict_from_checkpoint( def _infer_prefixes_and_module_names( sd_keys: List[str], - min_count: int = 3, + min_count: int = 1, ): _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE]) # pylint: disable=anomalous-backslash-in-string @@ -398,6 +398,7 @@ def _infer_prefixes_and_module_names( # model param and they need to be cat for scatter_key, list_of_params in checkpoint_metadata.items(): scatter_key_fqdn = ".".join([prefix, module_name, scatter_key]) + scatter_param = sd[scatter_key_fqdn] # remove from state dict diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 888f3487..b935d4ff 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -93,8 +93,8 @@ def _hook(grad): ) # install gradient scaling hook - if KEY_SCATTERMOE_ROUTER not in weight_name: - param.register_hook(_hook) + if KEY_SCATTERMOE_ROUTER not in weight_name: # does this need to look for lora a and b as well? + param.register_hook(_hook) # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 118795a1..3eab484a 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -123,6 +123,8 @@ def _insert(L: List, i: int, v): n -= 1 L[i] = v + lora = False + # if expert_name = input_linear|output_linear|input_linear # - in this case will map # - input_linear: [w1, w3], output_linear: {w2} @@ -149,6 +151,12 @@ def _insert(L: List, i: int, v): # `w1.weight`: [...] _map = defaultdict(list) prefix = f"{prefix}.{instance_name}." + # Lora case where it prefix looks like base_model.model.model... + # instead of model... 
+ if not prefix.startswith("model."): + prefix=prefix.replace("base_model.model.", "", 1) + lora=True + for k, stfile in weight_map.items(): if not k.startswith(prefix): continue @@ -165,15 +173,22 @@ def _insert(L: List, i: int, v): f"'{router_name}' or expert_name '{expert_name}'" ) if m.group(1) == router_name: - _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) + if lora: + k_lora_a = k.replace(".layer.", ".lora_A.") + _map[KEY_SCATTERMOE_LORA_A_ROUTER].append((k_lora_a, stfile)) + k_lora_b = k.replace(".layer.", ".lora_B.") + _map[KEY_SCATTERMOE_LORA_B_ROUTER].append((k_lora_b, stfile)) + else: + _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: - index = m.group(2) - index = 0 if index is None else int(index) - mod = None - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.weight"], index, (k, stfile)) - - assert mod is not None, f"cannot map '{rel_k}'" + if not lora: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + + assert mod is not None, f"cannot map '{rel_k}'" if len(_map) == 0: raise ValueError( From ca863d12f24f499956e52fa344ea169c70ec5da0 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 2 Apr 2025 14:32:00 -0400 Subject: [PATCH 06/44] fix: revert min count Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- .../src/fms_acceleration_moe/utils/scattermoe_prepare.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index cafad381..4fadffa8 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -340,7 +340,7 @@ def recover_original_state_dict_from_checkpoint( def _infer_prefixes_and_module_names( sd_keys: List[str], - min_count: int = 1, + min_count: int = 3, ): _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE]) # pylint: disable=anomalous-backslash-in-string diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index b935d4ff..20e69b1d 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -93,8 +93,9 @@ def _hook(grad): ) # install gradient scaling hook - if KEY_SCATTERMOE_ROUTER not in weight_name: # does this need to look for lora a and b as well? 
-            param.register_hook(_hook)
+        if KEY_SCATTERMOE_ROUTER not in weight_name and KEY_SCATTERMOE_LORA_A_ROUTER not in weight_name and KEY_SCATTERMOE_LORA_B_ROUTER not in weight_name:
+            if param.requires_grad:
+                param.register_hook(_hook)

         # register the sharded parameter onto the megablocks.dmoe
         mod.register_parameter(name, param)

From a26db682d6d1149e8bde77e5f02978e7a32b99ec Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 2 Apr 2025 14:53:28 -0400
Subject: [PATCH 07/44] fix: regex

Signed-off-by: Will Johnson
---
 .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 4fadffa8..bd26440a 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -344,7 +344,7 @@ def _infer_prefixes_and_module_names(
 ):
     _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
     # pylint: disable=anomalous-backslash-in-string
-    _reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A\.weight|lora_B\.weight)")
+    _reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A|lora_B)")

     found = {}
     for k in sd_keys:

From da4ccf7e81c42d24b186e08218a110a50b2c55b5 Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 2 Apr 2025 15:21:03 -0400
Subject: [PATCH 08/44] feat: lora in fsdp utils save

Signed-off-by: Will Johnson
---
 .../utils/checkpoint_utils.py | 33 ++++++++++++++-----
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index bd26440a..3df14bf6 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -110,15 +110,30 @@ def save_fsdp_optimizer(
     # get the state dicts for model and optimize
     (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)

-    # - save model
-    ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
-    os.makedirs(ckpt_model, exist_ok=True)
-    logger.info(f"Saving model to {ckpt_model}")
-    dcp.save(
-        state_dict={KEY_MODEL: model_state_dict},
-        storage_writer=dcp.FileSystemWriter(ckpt_model),
-        planner=DefaultSavePlanner(),
-    )
+    # filter out lora state dict
+    lora_state_dict = {
+        k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
+    }
+
+    # - save model
+    if lora_state_dict:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving lora model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: lora_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
+    else:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving ft model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: model_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
     logger.info(f"Model saved to {ckpt_model}")

     # - save optimizer

From 25e91553f184190bf3663e8757ade6600c44731d Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 2 Apr 2025 16:25:48 -0400
Subject: [PATCH 09/44] fix: lora
keys to map to original dict Signed-off-by: Will Johnson --- .../utils/scattermoe_state_dict.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 3eab484a..0893bae8 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -174,21 +174,23 @@ def _insert(L: List, i: int, v): ) if m.group(1) == router_name: if lora: - k_lora_a = k.replace(".layer.", ".lora_A.") - _map[KEY_SCATTERMOE_LORA_A_ROUTER].append((k_lora_a, stfile)) - k_lora_b = k.replace(".layer.", ".lora_B.") - _map[KEY_SCATTERMOE_LORA_B_ROUTER].append((k_lora_b, stfile)) + _map["router.lora_A.default.weight"].append((k, stfile)) + _map["router.lora_B.default.weight"].append((k, stfile)) else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None if not lora: - index = m.group(2) - index = 0 if index is None else int(index) - mod = None for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): _insert(_map[f"{mod}.weight"], index, (k, stfile)) + else: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) + _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) - assert mod is not None, f"cannot map '{rel_k}'" + assert mod is not None, f"cannot map '{rel_k}'" if len(_map) == 0: raise ValueError( From 6ae1e1cea4c16fea0a550234b8f0f8ee6c025737 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 3 Apr 2025 11:18:44 -0400 Subject: [PATCH 10/44] feat: handle lora A and B for converting checkpoint Signed-off-by: Will Johnson --- .../fms_acceleration_moe/utils/checkpoint_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 3df14bf6..6702da57 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -461,6 +461,20 @@ def _infer_prefixes_and_module_names( if len(scatter_keys) == 1: sd[model_key] = scatter_params[scatter_keys[0]] + + elif any("lora_A" in k for k in scatter_keys) and any("lora_B" in k for k in scatter_keys): + lora_A_key = next((k for k in scatter_keys if "lora_A" in k), None) + lora_B_key = next((k for k in scatter_keys if "lora_B" in k), None) + + if lora_A_key and lora_B_key: + lora_A = scatter_params[lora_A_key] + lora_B = scatter_params[lora_B_key] + + # Multiply matrices + lora_weight = torch.matmul(lora_B, lora_A) + + sd[model_key] = lora_weight + else: # unfortunately, there this is a in # scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights From 866c2e0c86521c0e5e148c18dda57ed55ea7aa2c Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 4 Apr 2025 11:40:41 -0400 Subject: [PATCH 11/44] fix: scatter keys fqdn -> scatter keys Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py 
index 6702da57..4ec0c94c 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -455,6 +455,8 @@ def _infer_prefixes_and_module_names( # it will go by order of scatter keys scatter_keys = sorted(scatter_params.keys()) + scatter_keys_fqdn = [".".join([prefix, module_name, scatter_key]) for scatter_key in scatter_keys] + assert ( len(scatter_keys) > 0 ), f"Obtained zero scatter keys for model_key '{model_key}'" @@ -463,17 +465,10 @@ def _infer_prefixes_and_module_names( sd[model_key] = scatter_params[scatter_keys[0]] elif any("lora_A" in k for k in scatter_keys) and any("lora_B" in k for k in scatter_keys): - lora_A_key = next((k for k in scatter_keys if "lora_A" in k), None) - lora_B_key = next((k for k in scatter_keys if "lora_B" in k), None) - - if lora_A_key and lora_B_key: - lora_A = scatter_params[lora_A_key] - lora_B = scatter_params[lora_B_key] - - # Multiply matrices - lora_weight = torch.matmul(lora_B, lora_A) - - sd[model_key] = lora_weight + # If lora, do not associate to model keys but keep scatter keys + for i, lora_key in enumerate(scatter_keys): + lora = scatter_params[lora_key] + sd[scatter_keys_fqdn[i]] = lora else: # unfortunately, there this is a in @@ -506,6 +501,9 @@ def save_sharded_safetensors( filename_pattern=filename_pattern, max_shard_size=max_shard_size, ) + + # If input state dict includes lora_A and lora_B params, need to save as a lora adapter_model.safetensors file + index = { "metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename, From cbb222ba04ead906897314a019605a6ab78b48de Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 4 Apr 2025 15:25:14 -0400 Subject: [PATCH 12/44] fix: save for adapter model (draft) Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 69 ++++++++++++------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 4ec0c94c..a75b2b37 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -33,6 +33,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType from transformers import PretrainedConfig from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.peft_utils import ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME import torch import torch.distributed.checkpoint as dcp @@ -493,32 +494,50 @@ def save_sharded_safetensors( metadata: Dict, max_shard_size: Union[int, str] = "5GB", ): - filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( - ".safetensors", "{suffix}.safetensors" - ) - state_dict_split = split_torch_state_dict_into_shards( - input_state_dict, - filename_pattern=filename_pattern, - max_shard_size=max_shard_size, - ) - - # If input state dict includes lora_A and lora_B params, need to save as a lora adapter_model.safetensors file + lora = False + for name, _ in input_state_dict.items(): + if "lora_A" or "lora_B" in name: + lora = True + break + + if not lora: + filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + input_state_dict, + 
filename_pattern=filename_pattern, + max_shard_size=max_shard_size, + ) - index = { - "metadata": state_dict_split.metadata, - "weight_map": state_dict_split.tensor_to_filename, - } - # Save the index - with open( - os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8" - ) as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - - filename_to_tensors = state_dict_split.filename_to_tensors.items() - for shard_file, tensors in filename_to_tensors: - shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors} - save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Save the index + with open( + os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8" + ) as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors} + save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + else: + filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + input_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size, + ) + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors} + save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) # --------------------------- SCRIPT ------------------------- From 86f9d8bc793dd7a3d2ea34ee327d0c362e965f92 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 8 Apr 2025 11:48:52 -0400 Subject: [PATCH 13/44] fix: associate w1, w2, w3 lora keys to input output linear lora layers Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 49 ++++++++++++++++--- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index a75b2b37..dce51230 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -33,7 +33,6 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType from transformers import PretrainedConfig from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME -from transformers.peft_utils import ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME import torch import torch.distributed.checkpoint as dcp @@ -54,6 +53,10 @@ KEY_MODEL = "model" KEY_OPTIMIZER = "optimizer" +ADAPTER_CONFIG_NAME = "adapter_config.json" +ADAPTER_WEIGHTS_NAME = "adapter_model.bin" +ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" + # Below are rewrite of HF FSDP model saving functions to be able to handle # that the parameters are now a mixture of regular and Dtensors. 
# - these functions are found in accelerate.utils.fsdp_utils.py @@ -462,14 +465,46 @@ def _infer_prefixes_and_module_names( len(scatter_keys) > 0 ), f"Obtained zero scatter keys for model_key '{model_key}'" - if len(scatter_keys) == 1: - sd[model_key] = scatter_params[scatter_keys[0]] - - elif any("lora_A" in k for k in scatter_keys) and any("lora_B" in k for k in scatter_keys): + if any("lora_A" in k for k in scatter_keys) and any("lora_B" in k for k in scatter_keys): # If lora, do not associate to model keys but keep scatter keys + # TODO: Actually these need to be associated with an + # input-linear and output-linear layer much like the FT case + # so, do that. Seperate the cases out. Use torch.cat if len + # scatter keys is greater than 2 + def transform_model_key(model_key, lora_key): + lora_parts = lora_key.split(".") + model_parts = model_key.split(".") + + try: + lora_index = lora_parts.index("block_sparse_moe") + model_index = model_parts.index("block_sparse_moe") + except ValueError: + raise ValueError("Both keys must contain 'block_sparse_moe'") + + # Replace the component after 'block_sparse_moe' in lora_key + updated_lora_parts = lora_parts[:] + updated_lora_parts[lora_index + 1] = model_parts[model_index + 1] + + # Return the updated lora parts as the model key + return ".".join(updated_lora_parts) for i, lora_key in enumerate(scatter_keys): - lora = scatter_params[lora_key] - sd[scatter_keys_fqdn[i]] = lora + new_model_key = transform_model_key(model_key, scatter_keys_fqdn[i]) + if len(scatter_keys) == 2: + sd[new_model_key] = scatter_params[lora_key] + else: + if "lora_A" in new_model_key: + filtered_keys = [k for k in scatter_keys if "lora_A" in k] + elif "lora_B" in new_model_key: + filtered_keys = [k for k in scatter_keys if "lora_B" in k] + else: + raise ValueError(f"Unexpected LoRA key type in {new_model_key}") + + sd[new_model_key] = torch.cat( + [scatter_params[k] for k in filtered_keys], dim=1 + ) + + elif len(scatter_keys) == 1: + sd[model_key] = scatter_params[scatter_keys[0]] else: # unfortunately, there this is a in From a5bdea2848e269f80e24e200f96e86a1e32caf42 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 8 Apr 2025 14:48:56 -0400 Subject: [PATCH 14/44] fix: comment Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index dce51230..690a838b 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -466,11 +466,7 @@ def _infer_prefixes_and_module_names( ), f"Obtained zero scatter keys for model_key '{model_key}'" if any("lora_A" in k for k in scatter_keys) and any("lora_B" in k for k in scatter_keys): - # If lora, do not associate to model keys but keep scatter keys - # TODO: Actually these need to be associated with an - # input-linear and output-linear layer much like the FT case - # so, do that. Seperate the cases out. 
Use torch.cat if len - # scatter keys is greater than 2 + # If lora, split input linear and output linear into lora layers def transform_model_key(model_key, lora_key): lora_parts = lora_key.split(".") model_parts = model_key.split(".") From 6b030aaa86ab083203fa1a27baaa553ddb3dcccc Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 10:04:07 -0400 Subject: [PATCH 15/44] fix: block off lora on w1, w2, w3 Signed-off-by: Will Johnson --- .../fms_acceleration_moe/utils/scattermoe.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index e16943b1..37020d26 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -278,28 +278,10 @@ def __init__( # - w1: the up_projection. # - w2: the down_projection. # - w3 (optional): the gate projection. - self.w1 = ScatteredExperts( - in_features=self.hidden_size, - out_features=self.intermediate_size, - num_experts=self.num_experts, - fan_out=self.top_k if not self.all_to_all else 1, - grouped_out=True, - dtype=dtype, - device=device, - lora_config=lora_config, - ) - self.w2 = ScatteredExperts( - in_features=self.intermediate_size, - out_features=self.hidden_size, - num_experts=self.num_experts, - fan_out=1, - grouped_in=True, - dtype=dtype, - device=device, - lora_config=lora_config, - ) - if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: - self.w3 = ScatteredExperts( + # TODO: Custom non-linear layers not supported in vLLM, + # must be investigated further before enabling + if lora_config is not None: + self.w1 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, num_experts=self.num_experts, @@ -309,6 +291,27 @@ def __init__( device=device, lora_config=lora_config, ) + self.w2 = ScatteredExperts( + in_features=self.intermediate_size, + out_features=self.hidden_size, + num_experts=self.num_experts, + fan_out=1, + grouped_in=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) + if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: + self.w3 = ScatteredExperts( + in_features=self.hidden_size, + out_features=self.intermediate_size, + num_experts=self.num_experts, + fan_out=self.top_k if not self.all_to_all else 1, + grouped_out=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) # referenced from dolomite-engine def _compute_routing_weights(self, hidden_states: torch.Tensor): From 587f2a4b10918e0ff4ef17df5764fcc284935050 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 10:17:40 -0400 Subject: [PATCH 16/44] fix: if condition flip Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 37020d26..60be5484 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -280,7 +280,7 @@ def __init__( # - w3 (optional): the gate projection. 
# TODO: Custom non-linear layers not supported in vLLM, # must be investigated further before enabling - if lora_config is not None: + if lora_config is None: self.w1 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, From f247d26f81bd037d4681f5319d4179979e8d58aa Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 10:29:32 -0400 Subject: [PATCH 17/44] fix: ignore weights for lora Signed-off-by: Will Johnson --- .../utils/scattermoe_prepare.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 800859d3..4252bcb9 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -53,6 +53,7 @@ def load_experts_onto_device( state_dict: OrderedDict, device_mesh: DeviceMesh, num_experts_per_device: int, + lora: bool, ): # hook for scaling the gradient @@ -81,21 +82,22 @@ def _hook(grad): param, device_mesh=device_mesh, placements=reps + [Shard(0)] ) - # get the module we want to shard - name = weight_name.split(".") - path, name = ".".join(name[:-1]), name[-1] - mod = module.get_submodule(path) - requires_grad = getattr(mod, name).requires_grad + if not lora: + # get the module we want to shard + name = weight_name.split(".") + path, name = ".".join(name[:-1]), name[-1] + mod = module.get_submodule(path) + requires_grad = getattr(mod, name).requires_grad - param = torch.nn.Parameter( - param, - requires_grad=requires_grad, - ) + param = torch.nn.Parameter( + param, + requires_grad=requires_grad, + ) - # install gradient scaling hook - if KEY_SCATTERMOE_ROUTER not in weight_name and KEY_SCATTERMOE_LORA_A_ROUTER not in weight_name and KEY_SCATTERMOE_LORA_B_ROUTER not in weight_name: - if param.requires_grad: - param.register_hook(_hook) + # install gradient scaling hook + if KEY_SCATTERMOE_ROUTER not in weight_name: + if param.requires_grad: + param.register_hook(_hook) # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) @@ -120,6 +122,10 @@ def prepare_scattermoe( # pylint: disable=import-outside-toplevel from .scattermoe import ScatterMoE + lora = False + if lora_config is not None: + lora = True + if disable_distributed and ep_degree > 1: raise ValueError( "expert sharding can not be deferred to top level sharding" @@ -341,7 +347,7 @@ def prepare_scattermoe( else: # - otherwise, we need to distribtue and will # replace the parameters - load_experts_onto_device(moe, sd, device_mesh, num_experts_per_device) + load_experts_onto_device(moe, sd, device_mesh, num_experts_per_device, lora) # module swap setattr(parent, module_name, moe) From f2bb29f1171e2f92252bdd200e63f54eb6c5a67e Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 10:49:07 -0400 Subject: [PATCH 18/44] fix: if lora in scattermoe prepare, don't put weights in map Signed-off-by: Will Johnson --- .../utils/scattermoe_prepare.py | 1 + .../utils/scattermoe_state_dict.py | 31 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 4252bcb9..a705bf04 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ 
b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -247,6 +247,7 @@ def prepare_scattermoe( weight_map, prefix, module_name, + lora, router_name, "|".join(expert_name), ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 0893bae8..912c3b71 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -85,6 +85,7 @@ def get_checkpoint_meta_from_sharded_safetensor( weight_map: Dict, prefix: str, # e.g., 'model.layers.0, instance_name: str, # e.g., block_sparse_moe + lora_start: bool, # if lora is detected in prepare_scattermoe.py router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] @@ -123,8 +124,6 @@ def _insert(L: List, i: int, v): n -= 1 L[i] = v - lora = False - # if expert_name = input_linear|output_linear|input_linear # - in this case will map # - input_linear: [w1, w3], output_linear: {w2} @@ -151,8 +150,9 @@ def _insert(L: List, i: int, v): # `w1.weight`: [...] _map = defaultdict(list) prefix = f"{prefix}.{instance_name}." - # Lora case where it prefix looks like base_model.model.model... - # instead of model... + # Lora case in checkpoint_utils where it prefix looks like + # `base_model.model.model...` instead of `model...` + lora = False if not prefix.startswith("model."): prefix=prefix.replace("base_model.model.", "", 1) lora=True @@ -179,18 +179,19 @@ def _insert(L: List, i: int, v): else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: - index = m.group(2) - index = 0 if index is None else int(index) - mod = None - if not lora: - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.weight"], index, (k, stfile)) - else: - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) - _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) + if not lora_start: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + if not lora: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + else: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) + _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) - assert mod is not None, f"cannot map '{rel_k}'" + assert mod is not None, f"cannot map '{rel_k}'" if len(_map) == 0: raise ValueError( From ebbcd570e2d11edf9b703eb019beb6cdabb3b1e2 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 10:53:52 -0400 Subject: [PATCH 19/44] fix: modules Signed-off-by: Will Johnson --- .../utils/scattermoe_prepare.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index a705bf04..f16e7e63 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -82,22 +82,21 @@ def _hook(grad): param, device_mesh=device_mesh, placements=reps + [Shard(0)] ) - if 
not lora:
-            # get the module we want to shard
-            name = weight_name.split(".")
-            path, name = ".".join(name[:-1]), name[-1]
-            mod = module.get_submodule(path)
-            requires_grad = getattr(mod, name).requires_grad
-
-            param = torch.nn.Parameter(
-                param,
-                requires_grad=requires_grad,
-            )
+        # get the module we want to shard
+        name = weight_name.split(".")
+        path, name = ".".join(name[:-1]), name[-1]
+        mod = module.get_submodule(path)
+        requires_grad = getattr(mod, name).requires_grad
+
+        param = torch.nn.Parameter(
+            param,
+            requires_grad=requires_grad,
+        )

-            # install gradient scaling hook
-            if KEY_SCATTERMOE_ROUTER not in weight_name:
-                if param.requires_grad:
-                    param.register_hook(_hook)
+        # install gradient scaling hook
+        if KEY_SCATTERMOE_ROUTER not in weight_name:
+            if param.requires_grad:
+                param.register_hook(_hook)

         # register the sharded parameter onto the megablocks.dmoe
         mod.register_parameter(name, param)

From 4a838ca45f90ed9b3f6446bd0b4449c1e3998deb Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 9 Apr 2025 10:57:29 -0400
Subject: [PATCH 20/44] fix: with lora self w1 and w2 are not guaranteed to exist

Signed-off-by: Will Johnson
---
 .../fms_acceleration_moe/utils/scattermoe.py | 46 ++++++++++---------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
index 60be5484..e8847e9f 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
@@ -460,14 +460,15 @@ def forward(self, hidden_states: torch.Tensor):
         )

         # compute the up projection
-        out = self.w1(
-            hidden_states,
-            sorted_expert_idxs,
-            sorted_scattered_idxs,
-            padded_block_idxs,
-            expert_offsets,
-        )
-        out = self.activation(out)
+        if self.w1:
+            out = self.w1(
+                hidden_states,
+                sorted_expert_idxs,
+                sorted_scattered_idxs,
+                padded_block_idxs,
+                expert_offsets,
+            )
+            out = self.activation(out)

         # - if the arch has a seperate gate projection
         if self.w3:
@@ -482,21 +483,22 @@ def forward(self, hidden_states: torch.Tensor):

         # compute the down projection
         # - if no all-to-all processing, then depend on
         # scattermoe kernel to perform the final scattering
-        hidden_states = self.w2(
-            out,
-            sorted_expert_idxs,
-            sorted_scattered_idxs,
-            padded_block_idxs,
-            expert_offsets,
-            gates=(None if self.all_to_all else routing_weights),
-        )
+        if self.w2:
+            hidden_states = self.w2(
+                out,
+                sorted_expert_idxs,
+                sorted_scattered_idxs,
+                padded_block_idxs,
+                expert_offsets,
+                gates=(None if self.all_to_all else routing_weights),
+            )
-        # maybe scatter
-        hidden_states = self._maybe_scatter(
-            hidden_states,
-            routing_weights,
-            _gather_products,
-        )
+            # maybe scatter
+            hidden_states = self._maybe_scatter(
+                hidden_states,
+                routing_weights,
+                _gather_products,
+            )

         # return hidden states and router logits
         return (hidden_states.view(original_shape), router_logits)

From 582379b0281a6486ece14b3eda9bf5d62f8e64ed Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 9 Apr 2025 11:02:16 -0400
Subject: [PATCH 21/44] fix: if w1, w2, w3 exist

Signed-off-by: Will Johnson
---
 .../fms_acceleration_moe/utils/scattermoe.py | 33 ++++++++++---------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
index e8847e9f..6fe1c5fc 100644
---
a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -460,7 +460,7 @@ def forward(self, hidden_states: torch.Tensor): ) # compute the up projection - if self.w1: + if hasattr(self, "w1"): out = self.w1( hidden_states, sorted_expert_idxs, @@ -471,19 +471,20 @@ def forward(self, hidden_states: torch.Tensor): out = self.activation(out) # - if the arch has a seperate gate projection - if self.w3: - out *= self.w3( - hidden_states, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - ) + if hasattr(self, "w3"): + if self.w3: + out *= self.w3( + hidden_states, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + ) # compute the down projection # - if no all-to-all processing, then depend on # scattermoe kernel to perform the final scattering - if self.w2: + if hasattr(self, "w2"): hidden_states = self.w2( out, sorted_expert_idxs, @@ -493,12 +494,12 @@ def forward(self, hidden_states: torch.Tensor): gates=(None if self.all_to_all else routing_weights), ) - # maybe scatter - hidden_states = self._maybe_scatter( - hidden_states, - routing_weights, - _gather_products, - ) + # maybe scatter + hidden_states = self._maybe_scatter( + hidden_states, + routing_weights, + _gather_products, + ) # return hidden states and router logits return (hidden_states.view(original_shape), router_logits) From eb605371c21d36aca617ea63426599c96897afff Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 11:04:03 -0400 Subject: [PATCH 22/44] fix: mapping to be router.layer Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 912c3b71..27cbf43d 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -174,8 +174,8 @@ def _insert(L: List, i: int, v): ) if m.group(1) == router_name: if lora: - _map["router.lora_A.default.weight"].append((k, stfile)) - _map["router.lora_B.default.weight"].append((k, stfile)) + _map["router.layer.lora_A.default.weight"].append((k, stfile)) + _map["router.layer.lora_B.default.weight"].append((k, stfile)) else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: From 5e1bb5248c5effe6771f6df16833ec6c531c9235 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 11:34:34 -0400 Subject: [PATCH 23/44] fix: lora condition Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 690a838b..89edc1b1 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -322,6 +322,7 @@ def get_state_dict_from_safe_checkpoint(safe_checkpoint_dir: str): # can restore the checkpoint to be loaded by the original architecture. 
def recover_original_state_dict_from_checkpoint( sd: Dict, + lora: bool, pretrained_model_name_or_path: str = None, ): """ @@ -359,8 +360,10 @@ def recover_original_state_dict_from_checkpoint( def _infer_prefixes_and_module_names( sd_keys: List[str], - min_count: int = 3, + lora: bool, ): + min_count = 2 if lora else 3 + _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE]) # pylint: disable=anomalous-backslash-in-string _reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A|lora_B)") @@ -523,13 +526,9 @@ def save_sharded_safetensors( input_state_dict: Dict, save_directory: str, metadata: Dict, + lora: bool, max_shard_size: Union[int, str] = "5GB", ): - lora = False - for name, _ in input_state_dict.items(): - if "lora_A" or "lora_B" in name: - lora = True - break if not lora: filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( @@ -618,14 +617,21 @@ def recover_safetensors_from_dcp( # get the state_dict state_dict = loader(checkpoint_dir) + lora = False + for name, _ in state_dict.items(): + if "lora_A" or "lora_B" in name: + lora = True + break + # recover the original state dict - state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path) + state_dict = recover_original_state_dict_from_checkpoint(state_dict, lora, _name_or_path) # save it as a safetensors file save_sharded_safetensors( {k: v.contiguous() for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}, + lora=lora, ) From 8b0144a08fed1057a842d5d889d778b12ae5a215 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 11:38:12 -0400 Subject: [PATCH 24/44] fix: pass lora into infer Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 89edc1b1..362acf37 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -386,7 +386,7 @@ def _infer_prefixes_and_module_names( return results - for prefix in _infer_prefixes_and_module_names(sd.keys()): + for prefix in _infer_prefixes_and_module_names(sd.keys(), lora): prefix = prefix.split(".") prefix, module_name = ".".join(prefix[:-1]), prefix[-1] From f8e83a23f46629e9d2f5d25f77d36fed7d9c7fec Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 11:56:31 -0400 Subject: [PATCH 25/44] lora utils Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 1 + .../utils/scattermoe_prepare.py | 2 +- .../utils/scattermoe_state_dict.py | 19 +++++++------------ 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 362acf37..136e0960 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -412,6 +412,7 @@ def _infer_prefixes_and_module_names( module_name, router_name, expert_name, + lora_utils=lora, ) model2scatter = defaultdict(dict) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index f16e7e63..1ac747fd 100644 --- 
a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -246,9 +246,9 @@ def prepare_scattermoe( weight_map, prefix, module_name, - lora, router_name, "|".join(expert_name), + lora_start=lora ) # the parent module diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 27cbf43d..ae1620eb 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -85,10 +85,11 @@ def get_checkpoint_meta_from_sharded_safetensor( weight_map: Dict, prefix: str, # e.g., 'model.layers.0, instance_name: str, # e.g., block_sparse_moe - lora_start: bool, # if lora is detected in prepare_scattermoe.py router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] + lora_start: bool = False, # if lora is detected in prepare_scattermoe.py + lora_utils: bool = False, # if lora is detected in checkpoint_utils.py ) -> Dict[str, List[Tuple]]: """ utilty function to infer the mapping of ScatterMoe parameters @@ -152,10 +153,8 @@ def _insert(L: List, i: int, v): prefix = f"{prefix}.{instance_name}." # Lora case in checkpoint_utils where it prefix looks like # `base_model.model.model...` instead of `model...` - lora = False if not prefix.startswith("model."): prefix=prefix.replace("base_model.model.", "", 1) - lora=True for k, stfile in weight_map.items(): if not k.startswith(prefix): @@ -173,23 +172,19 @@ def _insert(L: List, i: int, v): f"'{router_name}' or expert_name '{expert_name}'" ) if m.group(1) == router_name: - if lora: + if lora_utils: _map["router.layer.lora_A.default.weight"].append((k, stfile)) _map["router.layer.lora_B.default.weight"].append((k, stfile)) else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: - if not lora_start: + # Custom w1, w2, w3 are not supported for lora + if not lora_start and not lora_utils: index = m.group(2) index = 0 if index is None else int(index) mod = None - if not lora: - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.weight"], index, (k, stfile)) - else: - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) - _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) assert mod is not None, f"cannot map '{rel_k}'" From 227c8b9bb82086b9f806660656094ec6f7f8d794 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 12:21:08 -0400 Subject: [PATCH 26/44] fix: .layer Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index ae1620eb..4fcf1007 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -173,8 +173,8 @@ def _insert(L: List, i: int, v): ) 
if m.group(1) == router_name: if lora_utils: - _map["router.layer.lora_A.default.weight"].append((k, stfile)) - _map["router.layer.lora_B.default.weight"].append((k, stfile)) + _map["router.lora_A.default.weight"].append((k, stfile)) + _map["router.lora_B.default.weight"].append((k, stfile)) else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: From 57a68189a90b2313efac9fa765acfbb43257919c Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 12:28:02 -0400 Subject: [PATCH 27/44] add .layer Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 136e0960..46a27df9 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -483,7 +483,7 @@ def transform_model_key(model_key, lora_key): # Replace the component after 'block_sparse_moe' in lora_key updated_lora_parts = lora_parts[:] - updated_lora_parts[lora_index + 1] = model_parts[model_index + 1] + updated_lora_parts[lora_index + 1] = model_parts[model_index + 1] + ".layer" # Return the updated lora parts as the model key return ".".join(updated_lora_parts) From 809a917fa96bcfa64eee8907497e47ca38ffe230 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 13:46:07 -0400 Subject: [PATCH 28/44] fix: update state dict when loaded with lora before operations Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 53 +++++++------------ .../utils/scattermoe_state_dict.py | 4 +- 2 files changed, 21 insertions(+), 36 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 46a27df9..f4157fa8 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -463,45 +463,25 @@ def _infer_prefixes_and_module_names( # it will go by order of scatter keys scatter_keys = sorted(scatter_params.keys()) - scatter_keys_fqdn = [".".join([prefix, module_name, scatter_key]) for scatter_key in scatter_keys] - assert ( len(scatter_keys) > 0 ), f"Obtained zero scatter keys for model_key '{model_key}'" - if any("lora_A" in k for k in scatter_keys) and any("lora_B" in k for k in scatter_keys): - # If lora, split input linear and output linear into lora layers - def transform_model_key(model_key, lora_key): - lora_parts = lora_key.split(".") - model_parts = model_key.split(".") - - try: - lora_index = lora_parts.index("block_sparse_moe") - model_index = model_parts.index("block_sparse_moe") - except ValueError: - raise ValueError("Both keys must contain 'block_sparse_moe'") - - # Replace the component after 'block_sparse_moe' in lora_key - updated_lora_parts = lora_parts[:] - updated_lora_parts[lora_index + 1] = model_parts[model_index + 1] + ".layer" - - # Return the updated lora parts as the model key - return ".".join(updated_lora_parts) + if lora: for i, lora_key in enumerate(scatter_keys): - new_model_key = transform_model_key(model_key, scatter_keys_fqdn[i]) if len(scatter_keys) == 2: + model_key_parts = model_key.split(".") + layer_index = model_key_parts.index("layer") + + # Replace the "layer.weight" part 
with "layer.lora_A.weight" or "layer.lora_B.weight" + if "lora_A" in lora_key: + model_key_parts[layer_index + 1] = "lora_A.weight" + elif "lora_B" in lora_key: + model_key_parts[layer_index + 1] = "lora_B.weight" + + # Rebuild the model_key and assign the corresponding scatter_param + new_model_key = ".".join(model_key_parts) sd[new_model_key] = scatter_params[lora_key] - else: - if "lora_A" in new_model_key: - filtered_keys = [k for k in scatter_keys if "lora_A" in k] - elif "lora_B" in new_model_key: - filtered_keys = [k for k in scatter_keys if "lora_B" in k] - else: - raise ValueError(f"Unexpected LoRA key type in {new_model_key}") - - sd[new_model_key] = torch.cat( - [scatter_params[k] for k in filtered_keys], dim=1 - ) elif len(scatter_keys) == 1: sd[model_key] = scatter_params[scatter_keys[0]] @@ -619,10 +599,15 @@ def recover_safetensors_from_dcp( state_dict = loader(checkpoint_dir) lora = False - for name, _ in state_dict.items(): + new_state_dict = {} # To store the modified state_dict + for name, param in state_dict.items(): if "lora_A" or "lora_B" in name: lora = True - break + if "base_model.model." in name: + name = name.replace("base_model.model.", "", 1) + if "default." in name: + name = name.replace("default.", "", 1) + new_state_dict[name] = param # recover the original state dict state_dict = recover_original_state_dict_from_checkpoint(state_dict, lora, _name_or_path) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 4fcf1007..2328c3ec 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -173,8 +173,8 @@ def _insert(L: List, i: int, v): ) if m.group(1) == router_name: if lora_utils: - _map["router.lora_A.default.weight"].append((k, stfile)) - _map["router.lora_B.default.weight"].append((k, stfile)) + _map[KEY_SCATTERMOE_LORA_A_ROUTER].append((k, stfile)) + _map[KEY_SCATTERMOE_LORA_B_ROUTER].append((k, stfile)) else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: From a0887cd862788ac1122623dac13921eb846b536d Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 13:47:52 -0400 Subject: [PATCH 29/44] fix: remove duplicative code Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 2328c3ec..271050d1 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -151,10 +151,6 @@ def _insert(L: List, i: int, v): # `w1.weight`: [...] _map = defaultdict(list) prefix = f"{prefix}.{instance_name}." 
- # Lora case in checkpoint_utils where it prefix looks like - # `base_model.model.model...` instead of `model...` - if not prefix.startswith("model."): - prefix=prefix.replace("base_model.model.", "", 1) for k, stfile in weight_map.items(): if not k.startswith(prefix): From b754a927a35bae28a390197866948c720e30d591 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 14:07:23 -0400 Subject: [PATCH 30/44] fix: use new state dict Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index f4157fa8..0e418a39 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -610,7 +610,7 @@ def recover_safetensors_from_dcp( new_state_dict[name] = param # recover the original state dict - state_dict = recover_original_state_dict_from_checkpoint(state_dict, lora, _name_or_path) + state_dict = recover_original_state_dict_from_checkpoint(new_state_dict, lora, _name_or_path) # save it as a safetensors file save_sharded_safetensors( From 1ef1ebb7a829d614670fd41e635a2ac546fb4875 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 15:22:13 -0400 Subject: [PATCH 31/44] lint + fmt Signed-off-by: Will Johnson --- plugins/accelerated-moe/.pylintrc | 2 +- .../utils/checkpoint_utils.py | 39 ++++++++++++------- .../utils/scattermoe_prepare.py | 16 +++++--- .../utils/scattermoe_state_dict.py | 12 ++++-- 4 files changed, 45 insertions(+), 24 deletions(-) diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc index 4ddccea1..4141cba3 100644 --- a/plugins/accelerated-moe/.pylintrc +++ b/plugins/accelerated-moe/.pylintrc @@ -476,7 +476,7 @@ notes-rgx= [REFACTORING] # Maximum number of nested blocks for function / method body -max-nested-blocks=5 +max-nested-blocks=6 # Complete name of functions that never returns. 
When checking for # inconsistent-return-statements if a never returning function is called then diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 0e418a39..25c42acd 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -118,7 +118,7 @@ def save_fsdp_optimizer( lora_state_dict = { k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k } - + # - save mode if lora_state_dict: ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") @@ -129,7 +129,7 @@ def save_fsdp_optimizer( storage_writer=dcp.FileSystemWriter(ckpt_model), planner=DefaultSavePlanner(), ) - else: + else: ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") os.makedirs(ckpt_model, exist_ok=True) logger.info(f"Saving ft model to {ckpt_model}") @@ -421,7 +421,7 @@ def _infer_prefixes_and_module_names( # model param and they need to be cat for scatter_key, list_of_params in checkpoint_metadata.items(): scatter_key_fqdn = ".".join([prefix, module_name, scatter_key]) - + scatter_param = sd[scatter_key_fqdn] # remove from state dict @@ -472,8 +472,9 @@ def _infer_prefixes_and_module_names( if len(scatter_keys) == 2: model_key_parts = model_key.split(".") layer_index = model_key_parts.index("layer") - - # Replace the "layer.weight" part with "layer.lora_A.weight" or "layer.lora_B.weight" + + # Replace the "layer.weight" part with "layer.lora_A.weight" or + # "layer.lora_B.weight" if "lora_A" in lora_key: model_key_parts[layer_index + 1] = "lora_A.weight" elif "lora_B" in lora_key: @@ -534,12 +535,16 @@ def save_sharded_safetensors( filename_to_tensors = state_dict_split.filename_to_tensors.items() for shard_file, tensors in filename_to_tensors: - shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors} - save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + shard = { + tensor: input_state_dict[tensor].contiguous() for tensor in tensors + } + save_file( + shard, os.path.join(save_directory, shard_file), metadata=metadata + ) else: - filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( - ".safetensors", "{suffix}.safetensors" - ) + filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace( + ".bin", "{suffix}.bin" + ).replace(".safetensors", "{suffix}.safetensors") state_dict_split = split_torch_state_dict_into_shards( input_state_dict, filename_pattern=filename_pattern, @@ -547,8 +552,12 @@ def save_sharded_safetensors( ) filename_to_tensors = state_dict_split.filename_to_tensors.items() for shard_file, tensors in filename_to_tensors: - shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors} - save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + shard = { + tensor: input_state_dict[tensor].contiguous() for tensor in tensors + } + save_file( + shard, os.path.join(save_directory, shard_file), metadata=metadata + ) # --------------------------- SCRIPT ------------------------- @@ -601,7 +610,7 @@ def recover_safetensors_from_dcp( lora = False new_state_dict = {} # To store the modified state_dict for name, param in state_dict.items(): - if "lora_A" or "lora_B" in name: + if "lora_A" in name or "lora_B" in name: lora = True if "base_model.model." 
in name: name = name.replace("base_model.model.", "", 1) @@ -610,7 +619,9 @@ def recover_safetensors_from_dcp( new_state_dict[name] = param # recover the original state dict - state_dict = recover_original_state_dict_from_checkpoint(new_state_dict, lora, _name_or_path) + state_dict = recover_original_state_dict_from_checkpoint( + new_state_dict, lora, _name_or_path + ) # save it as a safetensors file save_sharded_safetensors( diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 1ac747fd..d0f0d445 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -33,9 +33,9 @@ FILE_SAFETENSOR_INDEX, KEY_EXPERT_PARALLEL, KEY_REPLICATE, - KEY_SCATTERMOE_ROUTER, KEY_SCATTERMOE_LORA_A_ROUTER, KEY_SCATTERMOE_LORA_B_ROUTER, + KEY_SCATTERMOE_ROUTER, get_scattermoe_conv_spec_from_archs, ) from .scattermoe_state_dict import ( @@ -69,7 +69,11 @@ def _hook(grad): for weight_name, param in state_dict.items(): - if KEY_SCATTERMOE_ROUTER in weight_name or KEY_SCATTERMOE_LORA_A_ROUTER in weight_name or KEY_SCATTERMOE_LORA_B_ROUTER in weight_name: + if ( + KEY_SCATTERMOE_ROUTER in weight_name + or KEY_SCATTERMOE_LORA_A_ROUTER in weight_name + or KEY_SCATTERMOE_LORA_B_ROUTER in weight_name + ): # if its the router, replicate param = distribute_tensor(param, device_mesh, reps + [Replicate()]) elif param.shape[0] > num_experts_per_device: @@ -96,7 +100,7 @@ def _hook(grad): # install gradient scaling hook if KEY_SCATTERMOE_ROUTER not in weight_name: if param.requires_grad: - param.register_hook(_hook) + param.register_hook(_hook) # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) @@ -248,7 +252,7 @@ def prepare_scattermoe( module_name, router_name, "|".join(expert_name), - lora_start=lora + lora_start=lora, ) # the parent module @@ -347,7 +351,9 @@ def prepare_scattermoe( else: # - otherwise, we need to distribtue and will # replace the parameters - load_experts_onto_device(moe, sd, device_mesh, num_experts_per_device, lora) + load_experts_onto_device( + moe, sd, device_mesh, num_experts_per_device, lora + ) # module swap setattr(parent, module_name, moe) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 271050d1..844bc0bb 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -26,9 +26,9 @@ # Local from .scattermoe_constants import ( DIM_EXPERT, - KEY_SCATTERMOE_ROUTER, KEY_SCATTERMOE_LORA_A_ROUTER, KEY_SCATTERMOE_LORA_B_ROUTER, + KEY_SCATTERMOE_ROUTER, PARAM_NAME_WEIGHT_SCATTERMOE, ) @@ -88,8 +88,8 @@ def get_checkpoint_meta_from_sharded_safetensor( router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] - lora_start: bool = False, # if lora is detected in prepare_scattermoe.py - lora_utils: bool = False, # if lora is detected in checkpoint_utils.py + lora_start: bool = False, # if lora is detected in prepare_scattermoe.py + lora_utils: bool = False, # if lora is detected in checkpoint_utils.py ) -> Dict[str, List[Tuple]]: """ utilty function to 
infer the mapping of ScatterMoe parameters @@ -306,7 +306,11 @@ def get_state_dict_from_checkpoint_metadata( # go by one weight at a time. for scatter_key, vs in checkpoint_metadata.items(): - if KEY_SCATTERMOE_ROUTER in scatter_key or KEY_SCATTERMOE_LORA_A_ROUTER in scatter_key or KEY_SCATTERMOE_LORA_B_ROUTER in scatter_key: + if ( + KEY_SCATTERMOE_ROUTER in scatter_key + or KEY_SCATTERMOE_LORA_A_ROUTER in scatter_key + or KEY_SCATTERMOE_LORA_B_ROUTER in scatter_key + ): k, fi = vs[0] # only one item param = files[fi].get_tensor(k) From b185ebf8d8e683c2a355f873e5746b1c73802567 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 15:25:16 -0400 Subject: [PATCH 32/44] fix: trailing whitespacE Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 25c42acd..ca26719e 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -473,7 +473,7 @@ def _infer_prefixes_and_module_names( model_key_parts = model_key.split(".") layer_index = model_key_parts.index("layer") - # Replace the "layer.weight" part with "layer.lora_A.weight" or + # Replace the "layer.weight" part with "layer.lora_A.weight" or # "layer.lora_B.weight" if "lora_A" in lora_key: model_key_parts[layer_index + 1] = "lora_A.weight" From 6f3852fcbe2134a8c0661cdaf0f335f147535fb8 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 10 Apr 2025 15:48:36 -0400 Subject: [PATCH 33/44] fix: target modules dictate which scatterMoE layers are trained Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 32 +++++++++----- .../fms_acceleration_moe/utils/scattermoe.py | 19 ++++---- .../utils/scattermoe_prepare.py | 44 +++++++++++-------- .../utils/scattermoe_state_dict.py | 13 ++++-- 4 files changed, 65 insertions(+), 43 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index ca26719e..720fe2e2 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -469,20 +469,28 @@ def _infer_prefixes_and_module_names( if lora: for i, lora_key in enumerate(scatter_keys): + model_key_parts = model_key.split(".") + weight_index = model_key_parts.index("weight") + # Replace the "layer.weight" part with "layer.lora_A.weight" or + # "layer.lora_B.weight" + if "lora_A" in lora_key: + model_key_parts[weight_index] = "lora_A.weight" + elif "lora_B" in lora_key: + model_key_parts[weight_index] = "lora_B.weight" + # Rebuild the model_key and assign the corresponding scatter_param + new_model_key = ".".join(model_key_parts) if len(scatter_keys) == 2: - model_key_parts = model_key.split(".") - layer_index = model_key_parts.index("layer") - - # Replace the "layer.weight" part with "layer.lora_A.weight" or - # "layer.lora_B.weight" - if "lora_A" in lora_key: - model_key_parts[layer_index + 1] = "lora_A.weight" - elif "lora_B" in lora_key: - model_key_parts[layer_index + 1] = "lora_B.weight" - - # Rebuild the model_key and assign the corresponding scatter_param - new_model_key = ".".join(model_key_parts) sd[new_model_key] = scatter_params[lora_key] + 
else: + if "lora_A" in new_model_key: + filtered_keys = [k for k in scatter_keys if "lora_A" in k] + elif "lora_B" in new_model_key: + filtered_keys = [k for k in scatter_keys if "lora_B" in k] + else: + raise ValueError(f"Unexpected LoRA key type in {new_model_key}") + sd[new_model_key] = torch.cat( + [scatter_params[k] for k in filtered_keys], dim=1 + ) elif len(scatter_keys) == 1: sd[model_key] = scatter_params[scatter_keys[0]] diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 6fe1c5fc..bd20f8d5 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -237,10 +237,13 @@ def __init__( assert ( lora_config.bias == "none" ), "ScatterMoE currently unable to handle bias in the lora adapters" - assert ( - lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND - or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules - ), "ScatterMoe currently only handles lora adapters on all linears." + + required_modules = ["router", "layer", "all-linear"] + if "input_linear" in lora_config.target_modules or "output_linear" in lora_config.target_modules: + # Assert that the target modules also include at least one from required_modules + assert ( + any(module in lora_config.target_modules for module in required_modules) + ), f"If 'input_linear' or 'output_linear' is included as a target module, 'router' must also be included" assert lora_config.init_lora_weights in { True, @@ -278,9 +281,7 @@ def __init__( # - w1: the up_projection. # - w2: the down_projection. # - w3 (optional): the gate projection. - # TODO: Custom non-linear layers not supported in vLLM, - # must be investigated further before enabling - if lora_config is None: + if "input_linear" in lora_config.target_modules: self.w1 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, @@ -291,6 +292,7 @@ def __init__( device=device, lora_config=lora_config, ) + if "output_linear" in lora_config.target_modules: self.w2 = ScatteredExperts( in_features=self.intermediate_size, out_features=self.hidden_size, @@ -301,6 +303,7 @@ def __init__( device=device, lora_config=lora_config, ) + if "input_linear" in lora_config.target_modules: if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: self.w3 = ScatteredExperts( in_features=self.hidden_size, @@ -502,4 +505,4 @@ def forward(self, hidden_states: torch.Tensor): ) # return hidden states and router logits - return (hidden_states.view(original_shape), router_logits) + return (hidden_states.view(original_shape), router_logits) \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index d0f0d445..8299f917 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -53,7 +53,6 @@ def load_experts_onto_device( state_dict: OrderedDict, device_mesh: DeviceMesh, num_experts_per_device: int, - lora: bool, ): # hook for scaling the gradient @@ -98,7 +97,11 @@ def _hook(grad): ) # install gradient scaling hook - if KEY_SCATTERMOE_ROUTER not in weight_name: + if ( + KEY_SCATTERMOE_ROUTER not in weight_name + and KEY_SCATTERMOE_LORA_A_ROUTER not in weight_name + and KEY_SCATTERMOE_LORA_B_ROUTER not in weight_name + ): if 
param.requires_grad: param.register_hook(_hook) @@ -253,6 +256,7 @@ def prepare_scattermoe( router_name, "|".join(expert_name), lora_start=lora, + target_modules=lora_config.target_modules, ) # the parent module @@ -342,23 +346,25 @@ def prepare_scattermoe( elif "lora_B" in name: torch.nn.init.normal_(sd[name]) - if device_mesh is None: - # - if not on meta, just load the state dict - # - and then put on the device - if not is_fsdp_enabled() or is_local_dist_rank_0(): - moe.load_state_dict(sd) - moe = moe.to(device) - else: - # - otherwise, we need to distribtue and will - # replace the parameters - load_experts_onto_device( - moe, sd, device_mesh, num_experts_per_device, lora - ) - # module swap - setattr(parent, module_name, moe) - - # - keep track of the name for returning - moe_module_names.add(module_name) + possible_target_modules = ["all_linear", "router", "layer", "input_linear", "output_linear"] + if any(module in lora_config.target_modules for module in possible_target_modules): + if device_mesh is None: + # - if not on meta, just load the state dict + # - and then put on the device + if not is_fsdp_enabled() or is_local_dist_rank_0(): + moe.load_state_dict(sd) + moe = moe.to(device) + else: + # - otherwise, we need to distribtue and will + # replace the parameters + load_experts_onto_device( + moe, sd, device_mesh, num_experts_per_device + ) + # module swap + setattr(parent, module_name, moe) + + # - keep track of the name for returning + moe_module_names.add(module_name) except ValueError as e: raise ValueError( diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 844bc0bb..4d50b7e5 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -90,6 +90,7 @@ def get_checkpoint_meta_from_sharded_safetensor( expert_map: Dict = None, # map -> [w1,w2,w3] lora_start: bool = False, # if lora is detected in prepare_scattermoe.py lora_utils: bool = False, # if lora is detected in checkpoint_utils.py + target_modules: dict = {}, ) -> Dict[str, List[Tuple]]: """ utilty function to infer the mapping of ScatterMoe parameters @@ -174,13 +175,17 @@ def _insert(L: List, i: int, v): else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: - # Custom w1, w2, w3 are not supported for lora - if not lora_start and not lora_utils: + if ("input_linear" in target_modules and "output_linear" in target_modules) or lora_utils: index = m.group(2) index = 0 if index is None else int(index) mod = None - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.weight"], index, (k, stfile)) + if not lora_utils: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + else: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) + _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) assert mod is not None, f"cannot map '{rel_k}'" From 8162d99aff0db57375ba23da3fdbb8d63ed79d54 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 10 Apr 2025 15:50:51 -0400 Subject: [PATCH 34/44] lint + fmt Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 4 +++- .../src/fms_acceleration_moe/utils/scattermoe.py | 15 +++++++++------ .../utils/scattermoe_prepare.py | 15 
++++++++++++--- .../utils/scattermoe_state_dict.py | 6 ++++-- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 720fe2e2..f4887b6d 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -487,7 +487,9 @@ def _infer_prefixes_and_module_names( elif "lora_B" in new_model_key: filtered_keys = [k for k in scatter_keys if "lora_B" in k] else: - raise ValueError(f"Unexpected LoRA key type in {new_model_key}") + raise ValueError( + f"Unexpected LoRA key type in {new_model_key}" + ) sd[new_model_key] = torch.cat( [scatter_params[k] for k in filtered_keys], dim=1 ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index bd20f8d5..e11e5147 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -17,7 +17,6 @@ # Third Party from peft import LoraConfig -from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND from torch.distributed._tensor import DTensor # pylint: disable=import-error @@ -239,11 +238,15 @@ def __init__( ), "ScatterMoE currently unable to handle bias in the lora adapters" required_modules = ["router", "layer", "all-linear"] - if "input_linear" in lora_config.target_modules or "output_linear" in lora_config.target_modules: + if ( + "input_linear" in lora_config.target_modules + or "output_linear" in lora_config.target_modules + ): # Assert that the target modules also include at least one from required_modules - assert ( - any(module in lora_config.target_modules for module in required_modules) - ), f"If 'input_linear' or 'output_linear' is included as a target module, 'router' must also be included" + assert any( + module in lora_config.target_modules for module in required_modules + ), "If 'input_linear' or 'output_linear' is included as a target module,\ + 'router' must also be included" assert lora_config.init_lora_weights in { True, @@ -505,4 +508,4 @@ def forward(self, hidden_states: torch.Tensor): ) # return hidden states and router logits - return (hidden_states.view(original_shape), router_logits) \ No newline at end of file + return (hidden_states.view(original_shape), router_logits) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 8299f917..750672c6 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -99,7 +99,7 @@ def _hook(grad): # install gradient scaling hook if ( KEY_SCATTERMOE_ROUTER not in weight_name - and KEY_SCATTERMOE_LORA_A_ROUTER not in weight_name + and KEY_SCATTERMOE_LORA_A_ROUTER not in weight_name and KEY_SCATTERMOE_LORA_B_ROUTER not in weight_name ): if param.requires_grad: @@ -346,8 +346,17 @@ def prepare_scattermoe( elif "lora_B" in name: torch.nn.init.normal_(sd[name]) - possible_target_modules = ["all_linear", "router", "layer", "input_linear", "output_linear"] - if any(module in lora_config.target_modules for module in possible_target_modules): + possible_target_modules = [ + "all_linear", + "router", + "layer", + 
"input_linear", + "output_linear", + ] + if any( + module in lora_config.target_modules + for module in possible_target_modules + ): if device_mesh is None: # - if not on meta, just load the state dict # - and then put on the device diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 4d50b7e5..0ca08f7c 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -90,7 +90,7 @@ def get_checkpoint_meta_from_sharded_safetensor( expert_map: Dict = None, # map -> [w1,w2,w3] lora_start: bool = False, # if lora is detected in prepare_scattermoe.py lora_utils: bool = False, # if lora is detected in checkpoint_utils.py - target_modules: dict = {}, + target_modules: dict = None, ) -> Dict[str, List[Tuple]]: """ utilty function to infer the mapping of ScatterMoe parameters @@ -175,7 +175,9 @@ def _insert(L: List, i: int, v): else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: - if ("input_linear" in target_modules and "output_linear" in target_modules) or lora_utils: + if ( + "input_linear" in target_modules and "output_linear" in target_modules + ) or lora_utils: index = m.group(2) index = 0 if index is None else int(index) mod = None From 8ce09d8d6655e0c1d887dddb3cb5c6af06a5ee08 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 10 Apr 2025 15:55:47 -0400 Subject: [PATCH 35/44] fix: cleanup Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- .../src/fms_acceleration_moe/utils/scattermoe_prepare.py | 1 - .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index f4887b6d..88a3cb48 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -618,7 +618,7 @@ def recover_safetensors_from_dcp( state_dict = loader(checkpoint_dir) lora = False - new_state_dict = {} # To store the modified state_dict + new_state_dict = {} for name, param in state_dict.items(): if "lora_A" in name or "lora_B" in name: lora = True diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 750672c6..e2eef859 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -255,7 +255,6 @@ def prepare_scattermoe( module_name, router_name, "|".join(expert_name), - lora_start=lora, target_modules=lora_config.target_modules, ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 0ca08f7c..4c56f2e0 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -88,9 +88,8 @@ def get_checkpoint_meta_from_sharded_safetensor( router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: 
str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] - lora_start: bool = False, # if lora is detected in prepare_scattermoe.py lora_utils: bool = False, # if lora is detected in checkpoint_utils.py - target_modules: dict = None, + target_modules: dict = None, # target modules from prepare_scattermoe.py ) -> Dict[str, List[Tuple]]: """ utilty function to infer the mapping of ScatterMoe parameters From 4d051355b8e5b948a6f068ecaf0af2e728628c20 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 10 Apr 2025 19:51:42 -0400 Subject: [PATCH 36/44] lint Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_prepare.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index e2eef859..9791b320 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -128,10 +128,6 @@ def prepare_scattermoe( # pylint: disable=import-outside-toplevel from .scattermoe import ScatterMoE - lora = False - if lora_config is not None: - lora = True - if disable_distributed and ep_degree > 1: raise ValueError( "expert sharding can not be deferred to top level sharding" From 2d959fc448335f5fc7d7ce75e73563439501357f Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 09:15:19 -0400 Subject: [PATCH 37/44] fmt Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 4c56f2e0..27335fb3 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -89,7 +89,7 @@ def get_checkpoint_meta_from_sharded_safetensor( expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] lora_utils: bool = False, # if lora is detected in checkpoint_utils.py - target_modules: dict = None, # target modules from prepare_scattermoe.py + target_modules: dict = None, # target modules from prepare_scattermoe.py ) -> Dict[str, List[Tuple]]: """ utilty function to infer the mapping of ScatterMoe parameters From 1eea1652b3ae2fabaf1bedd9d889bdef0a9f96b3 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 10:45:17 -0400 Subject: [PATCH 38/44] fix: type Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 27335fb3..37ed0c43 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -89,7 +89,7 @@ def get_checkpoint_meta_from_sharded_safetensor( expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] lora_utils: bool = False, # if lora is detected in checkpoint_utils.py - 
target_modules: dict = None, # target modules from prepare_scattermoe.py + target_modules: Dict = None, # target modules from prepare_scattermoe.py ) -> Dict[str, List[Tuple]]: """ utilty function to infer the mapping of ScatterMoe parameters From bff9b0466fcf6d71af92ff1ce6418e9ff64504ea Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 10:54:28 -0400 Subject: [PATCH 39/44] fix: default target modules Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 37ed0c43..cc833cea 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -152,6 +152,8 @@ def _insert(L: List, i: int, v): _map = defaultdict(list) prefix = f"{prefix}.{instance_name}." + target_modules = target_modules or {} + for k, stfile in weight_map.items(): if not k.startswith(prefix): continue From 07d38b0e3285f30581d06e7ddb12bc1b3f06be60 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 11:10:20 -0400 Subject: [PATCH 40/44] fix: ft logic Signed-off-by: Will Johnson --- .../utils/scattermoe_prepare.py | 5 +++++ .../utils/scattermoe_state_dict.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 9791b320..afc8cc76 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -128,6 +128,10 @@ def prepare_scattermoe( # pylint: disable=import-outside-toplevel from .scattermoe import ScatterMoE + lora = False + if lora_config: + lora = True + if disable_distributed and ep_degree > 1: raise ValueError( "expert sharding can not be deferred to top level sharding" @@ -251,6 +255,7 @@ def prepare_scattermoe( module_name, router_name, "|".join(expert_name), + lora_start=lora target_modules=lora_config.target_modules, ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index cc833cea..11748012 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -88,6 +88,7 @@ def get_checkpoint_meta_from_sharded_safetensor( router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] + lora_start: bool = False, # if lora is detected in prepare_scattermoe.py lora_utils: bool = False, # if lora is detected in checkpoint_utils.py target_modules: Dict = None, # target modules from prepare_scattermoe.py ) -> Dict[str, List[Tuple]]: @@ -176,12 +177,14 @@ def _insert(L: List, i: int, v): else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + + # LoRA case if ( "input_linear" in target_modules and "output_linear" in target_modules ) or lora_utils: - index = 
m.group(2) - index = 0 if index is None else int(index) - mod = None if not lora_utils: for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): _insert(_map[f"{mod}.weight"], index, (k, stfile)) @@ -190,7 +193,14 @@ def _insert(L: List, i: int, v): _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) - assert mod is not None, f"cannot map '{rel_k}'" + # Fine-tuning case + elif not lora_utils and not lora_start: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + + assert mod is not None, f"cannot map '{rel_k}'" + + if len(_map) == 0: raise ValueError( From f9176c52fb1bb37a057f60a024c2450d80b94da3 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 11:17:39 -0400 Subject: [PATCH 41/44] fix: fmt + lint Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_prepare.py | 2 +- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index afc8cc76..7712803f 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -255,7 +255,7 @@ def prepare_scattermoe( module_name, router_name, "|".join(expert_name), - lora_start=lora + lora_start=lora, target_modules=lora_config.target_modules, ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 11748012..5d2d9cca 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -88,7 +88,7 @@ def get_checkpoint_meta_from_sharded_safetensor( router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] - lora_start: bool = False, # if lora is detected in prepare_scattermoe.py + lora_start: bool = False, # if lora is detected in prepare_scattermoe.py lora_utils: bool = False, # if lora is detected in checkpoint_utils.py target_modules: Dict = None, # target modules from prepare_scattermoe.py ) -> Dict[str, List[Tuple]]: @@ -197,10 +197,8 @@ def _insert(L: List, i: int, v): elif not lora_utils and not lora_start: for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): _insert(_map[f"{mod}.weight"], index, (k, stfile)) - - assert mod is not None, f"cannot map '{rel_k}'" - + assert mod is not None, f"cannot map '{rel_k}'" if len(_map) == 0: raise ValueError( From d98b2c9e824dc9a5f8af235e216f62884e700903 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 14:30:24 -0400 Subject: [PATCH 42/44] fix: logic for lora Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 16 ++++++++++- .../utils/scattermoe_prepare.py | 5 ---- .../utils/scattermoe_state_dict.py | 28 ++++++++----------- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 88a3cb48..12a62ca4 100644 --- 
a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -343,6 +343,19 @@ def recover_original_state_dict_from_checkpoint( # config config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path) + # if lora, check for input/output layers + ip_op_layers = False + router_layer = False + if lora: + for name, _ in sd.items(): + if "w1" in name: + ip_op_layers = True + break + for name, _ in sd.items(): + if "router" in name: + router_layer = True + break + ( _, router_name, @@ -412,7 +425,8 @@ def _infer_prefixes_and_module_names( module_name, router_name, expert_name, - lora_utils=lora, + ip_op_layers=ip_op_layers, + router_layer=router_layer, ) model2scatter = defaultdict(dict) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 7712803f..9791b320 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -128,10 +128,6 @@ def prepare_scattermoe( # pylint: disable=import-outside-toplevel from .scattermoe import ScatterMoE - lora = False - if lora_config: - lora = True - if disable_distributed and ep_degree > 1: raise ValueError( "expert sharding can not be deferred to top level sharding" @@ -255,7 +251,6 @@ def prepare_scattermoe( module_name, router_name, "|".join(expert_name), - lora_start=lora, target_modules=lora_config.target_modules, ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index 5d2d9cca..a02236c3 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -88,8 +88,8 @@ def get_checkpoint_meta_from_sharded_safetensor( router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] - lora_start: bool = False, # if lora is detected in prepare_scattermoe.py - lora_utils: bool = False, # if lora is detected in checkpoint_utils.py + ip_op_layers: bool = False, # if input/output layers are detected in utils + router_layer: bool = False, # if router layer is detected in utils target_modules: Dict = None, # target modules from prepare_scattermoe.py ) -> Dict[str, List[Tuple]]: """ @@ -111,6 +111,8 @@ def get_checkpoint_meta_from_sharded_safetensor( e.g., input_linear|output_linear|input_linear expert_map (dict): This is used with pattern ii) described above in expert_name. 
If not specified, will be the identity map, e.g., w1 -> w1 + lora_start (bool): Boolean to determine if lora is detected in scattermoe_prepare.py + lora_utils (bool): """ # insert in order @@ -171,34 +173,26 @@ def _insert(L: List, i: int, v): f"'{router_name}' or expert_name '{expert_name}'" ) if m.group(1) == router_name: - if lora_utils: + if router_layer: _map[KEY_SCATTERMOE_LORA_A_ROUTER].append((k, stfile)) _map[KEY_SCATTERMOE_LORA_B_ROUTER].append((k, stfile)) else: _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: - index = m.group(2) - index = 0 if index is None else int(index) - mod = None - - # LoRA case if ( "input_linear" in target_modules and "output_linear" in target_modules - ) or lora_utils: - if not lora_utils: + ) or ip_op_layers: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + if not ip_op_layers: for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): _insert(_map[f"{mod}.weight"], index, (k, stfile)) else: for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) - - # Fine-tuning case - elif not lora_utils and not lora_start: - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.weight"], index, (k, stfile)) - - assert mod is not None, f"cannot map '{rel_k}'" + assert mod is not None, f"cannot map '{rel_k}'" if len(_map) == 0: raise ValueError( From 570bf347dcd9924d39bc9794352bf18f95f527a9 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 15:59:14 -0400 Subject: [PATCH 43/44] fix: logic + docs Signed-off-by: Will Johnson --- plugins/accelerated-moe/README.md | 7 ++++--- .../fms_acceleration_moe/utils/scattermoe.py | 20 ++++++++++--------- .../utils/scattermoe_prepare.py | 17 +++++++++++----- .../utils/scattermoe_state_dict.py | 17 ++++++++++++++-- 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index d05cdc82..634166ff 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -8,7 +8,7 @@ This library contains plugins to accelerate finetuning with the following optimi Plugin | Description | Depends | Loading | Augmentation | Callbacks --|--|--|--|--|-- -[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | ✅ | | ✅ +[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | | ✅ | ✅ ## Adding New Models @@ -33,6 +33,8 @@ python -m fms_acceleration_moe.utils.checkpoint_utils \ mistralai/Mixtral-8x7B-Instruct-v0.1 ``` +If running with fms-hf-tuning, this script runs automatically if the `fast_moe` parameter is set. + ## Code Extracted from Megablocks Notes on code extraction: @@ -81,9 +83,8 @@ Triton Kernels are copied into [scattermoe_utils](./src/fms_acceleration_moe/uti ### Known Issues These are currently some known issues not yet resolved: -- should eventually remove the dependency on an external `kernel-hyperdrive` repository. -- now support only loading *sharded* `safetensor` non-GGUF MoE checkpoints. 
This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents it being saved into a single checkpoint filed. - when used together with FSDP, the FSDP's `clip_grad_norm` will not properly compute for `ScatterMoE`, see [issue here](https://github.com/foundation-model-stack/fms-acceleration/issues/109). +- when used to lora train a model, if training experts on adapter model, the model will fail to run inference in vLLM/vanilla HF because of restrictions to parameter types. If running inference do not select `input_linear` and `output_linear` as target modules when lora training. diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index e11e5147..89186388 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -231,22 +231,24 @@ def __init__( not has_bias ), "ScatterMoE currently unable to handle bias in both gates and experts." + target_modules = None + if lora_config is not None: # since this is self implemented, we really only support basic lora funcs assert ( lora_config.bias == "none" ), "ScatterMoE currently unable to handle bias in the lora adapters" + if lora_config and hasattr(lora_config, "target_modules"): + target_modules = lora_config.target_modules + required_modules = ["router", "layer", "all-linear"] - if ( - "input_linear" in lora_config.target_modules - or "output_linear" in lora_config.target_modules - ): + if "input_linear" in target_modules or "output_linear" in target_modules: # Assert that the target modules also include at least one from required_modules assert any( - module in lora_config.target_modules for module in required_modules + module in target_modules for module in required_modules ), "If 'input_linear' or 'output_linear' is included as a target module,\ - 'router' must also be included" + 'router' must also be included" assert lora_config.init_lora_weights in { True, @@ -284,7 +286,7 @@ def __init__( # - w1: the up_projection. # - w2: the down_projection. # - w3 (optional): the gate projection. 
- if "input_linear" in lora_config.target_modules: + if not lora_config or ("input_linear" in target_modules): self.w1 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, @@ -295,7 +297,7 @@ def __init__( device=device, lora_config=lora_config, ) - if "output_linear" in lora_config.target_modules: + if not lora_config or ("input_linear" in target_modules): self.w2 = ScatteredExperts( in_features=self.intermediate_size, out_features=self.hidden_size, @@ -306,7 +308,7 @@ def __init__( device=device, lora_config=lora_config, ) - if "input_linear" in lora_config.target_modules: + if not lora_config or ("input_linear" in target_modules): if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: self.w3 = ScatteredExperts( in_features=self.hidden_size, diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 9791b320..eaf11aff 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -240,6 +240,10 @@ def prepare_scattermoe( weight_map = load_weight_map(loc, "model.safetensors", FILE_SAFETENSOR_INDEX) + target_modules = None + if lora_config and hasattr(lora_config, "target_modules"): + target_modules = lora_config.target_modules + # e.g., prefix: 'model.layers.0', # module_name: 'block_sparse_moe' for prefix, (module_name, _, has_bias) in tqdm( @@ -251,7 +255,7 @@ def prepare_scattermoe( module_name, router_name, "|".join(expert_name), - target_modules=lora_config.target_modules, + target_modules=target_modules, ) # the parent module @@ -342,15 +346,18 @@ def prepare_scattermoe( torch.nn.init.normal_(sd[name]) possible_target_modules = [ - "all_linear", + "all-linear", "router", "layer", "input_linear", "output_linear", ] - if any( - module in lora_config.target_modules - for module in possible_target_modules + if ( + any( + module in (target_modules or []) + for module in possible_target_modules + ) + or not lora_config ): if device_mesh is None: # - if not on meta, just load the state dict diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index a02236c3..a436b2a6 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -111,8 +111,14 @@ def get_checkpoint_meta_from_sharded_safetensor( e.g., input_linear|output_linear|input_linear expert_map (dict): This is used with pattern ii) described above in expert_name. If not specified, will be the identity map, e.g., w1 -> w1 - lora_start (bool): Boolean to determine if lora is detected in scattermoe_prepare.py - lora_utils (bool): + ip_op_layers (bool): Boolean to determine if input/output layers are detected + in checkpoint utils. + router_layer (bool): Boolean to determine if router layer is detected + in checkpoint utils. + target_modules (dict): Target modules from scattermoe_prepare.py lora_config. + Used to check for input and output layers while preparing scattermoe. 
+ + Returns: Map of used ScatterMoE weights to their files """ # insert in order @@ -193,6 +199,13 @@ def _insert(L: List, i: int, v): _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) assert mod is not None, f"cannot map '{rel_k}'" + elif not ip_op_layers and not router_layer and not target_modules: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + assert mod is not None, f"cannot map '{rel_k}'" if len(_map) == 0: raise ValueError( From 4cd4288a091ef6759835a22a55db810fcfb8ffdf Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 16:01:17 -0400 Subject: [PATCH 44/44] fix: mistype Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 89186388..2bc98e46 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -297,7 +297,7 @@ def __init__( device=device, lora_config=lora_config, ) - if not lora_config or ("input_linear" in target_modules): + if not lora_config or ("output_linear" in target_modules): self.w2 = ScatteredExperts( in_features=self.intermediate_size, out_features=self.hidden_size,
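
Taken together, the LoRA assertions added to scattermoe.py in these patches (bias must be "none", init_lora_weights=True satisfies the init check, and "input_linear"/"output_linear" may only be targeted alongside "router", "layer", or "all-linear") plus the README known-issues note added in PATCH 43/44 suggest two practical adapter configurations. The sketch below is illustrative only and not part of the patch series: the rank, alpha and dropout values are placeholders, and the module names `router`, `input_linear`, `output_linear` refer to the ScatterMoE submodules named in the diffs above.

```python
# Illustrative sketch only (not part of the patches): r / lora_alpha / dropout
# are placeholder values; "router", "input_linear", "output_linear" are the
# ScatterMoE submodule names used throughout the diffs above.
from peft import LoraConfig

# Router-only LoRA: avoids the vLLM / vanilla-HF inference limitation noted in
# the README known issues, since no expert linears are adapted.
router_only = LoraConfig(
    r=8,                     # placeholder rank
    lora_alpha=16,           # placeholder scaling
    lora_dropout=0.0,
    bias="none",             # ScatterMoE asserts lora_config.bias == "none"
    init_lora_weights=True,  # satisfies the init_lora_weights assertion
    target_modules=["router"],
)

# Expert-linear LoRA: "input_linear"/"output_linear" must be accompanied by
# "router" (or "layer" / "all-linear"), otherwise the required_modules
# assertion in scattermoe.py fails; per the README, such adapters are not
# currently servable in vLLM / vanilla HF.
with_experts = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
    init_lora_weights=True,
    target_modules=["router", "input_linear", "output_linear"],
)
```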