From 423e4bfa76a081c66a6374de68e9b8ffbe0d4547 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:07:01 -0400 Subject: [PATCH 01/15] limit changes Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 13 +++ .../fms_acceleration_moe/utils/scattermoe.py | 107 ++++++++++-------- 2 files changed, 74 insertions(+), 46 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index e6fe1ba6..134eca99 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -540,6 +540,19 @@ def recover_safetensors_from_dcp( # get the state_dict state_dict = loader(checkpoint_dir) + new_state_dict = {} + for name, param in state_dict.items(): + if "base_model.model." in name: + name = name.replace("base_model.model.", "", 1) + if "default." in name: + name = name.replace("default.", "", 1) + new_state_dict[name] = param + + # recover the original state dict + state_dict = recover_original_state_dict_from_checkpoint( + new_state_dict, _name_or_path + ) + # recover the original state dict state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index e16943b1..7558ab52 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -232,15 +232,24 @@ def __init__( not has_bias ), "ScatterMoE currently unable to handle bias in both gates and experts." + target_modules = None + if lora_config is not None: # since this is self implemented, we really only support basic lora funcs assert ( lora_config.bias == "none" ), "ScatterMoE currently unable to handle bias in the lora adapters" - assert ( - lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND - or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules - ), "ScatterMoe currently only handles lora adapters on all linears." + + if lora_config and hasattr(lora_config, "target_modules"): + target_modules = lora_config.target_modules + + required_modules = ["router", "layer", "all-linear"] + if "input_linear" in target_modules or "output_linear" in target_modules: + # Assert that the target modules also include at least one from required_modules + assert any( + module in target_modules for module in required_modules + ), "If 'input_linear' or 'output_linear' is included as a target module,\ + 'router' must also be included" assert lora_config.init_lora_weights in { True, @@ -278,28 +287,8 @@ def __init__( # - w1: the up_projection. # - w2: the down_projection. # - w3 (optional): the gate projection. 
- self.w1 = ScatteredExperts( - in_features=self.hidden_size, - out_features=self.intermediate_size, - num_experts=self.num_experts, - fan_out=self.top_k if not self.all_to_all else 1, - grouped_out=True, - dtype=dtype, - device=device, - lora_config=lora_config, - ) - self.w2 = ScatteredExperts( - in_features=self.intermediate_size, - out_features=self.hidden_size, - num_experts=self.num_experts, - fan_out=1, - grouped_in=True, - dtype=dtype, - device=device, - lora_config=lora_config, - ) - if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: - self.w3 = ScatteredExperts( + if not lora_config or ("input_linear" in target_modules): + self.w1 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, num_experts=self.num_experts, @@ -309,6 +298,29 @@ def __init__( device=device, lora_config=lora_config, ) + if not lora_config or ("output_linear" in target_modules): + self.w2 = ScatteredExperts( + in_features=self.intermediate_size, + out_features=self.hidden_size, + num_experts=self.num_experts, + fan_out=1, + grouped_in=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) + if not lora_config or ("input_linear" in target_modules): + if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: + self.w3 = ScatteredExperts( + in_features=self.hidden_size, + out_features=self.intermediate_size, + num_experts=self.num_experts, + fan_out=self.top_k if not self.all_to_all else 1, + grouped_out=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) # referenced from dolomite-engine def _compute_routing_weights(self, hidden_states: torch.Tensor): @@ -457,36 +469,39 @@ def forward(self, hidden_states: torch.Tensor): ) # compute the up projection - out = self.w1( - hidden_states, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - ) - out = self.activation(out) - - # - if the arch has a seperate gate projection - if self.w3: - out *= self.w3( + if hasattr(self, "w1"): + out = self.w1( hidden_states, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets, ) + out = self.activation(out) + + # - if the arch has a seperate gate projection + if hasattr(self, "w3"): + if self.w3: + out *= self.w3( + hidden_states, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + ) # compute the down projection # - if no all-to-all processing, then depend on # scattermoe kernel to perform the final scattering - hidden_states = self.w2( - out, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates=(None if self.all_to_all else routing_weights), - ) + if hasattr(self, "w2"): + hidden_states = self.w2( + out, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=(None if self.all_to_all else routing_weights), + ) # maybe scatter hidden_states = self._maybe_scatter( From eebe34087755e60697387772a0727d7433177ecb Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:13:14 -0400 Subject: [PATCH 02/15] lora filtering Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 134eca99..cfc351a6 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -110,16 
+110,30 @@ def save_fsdp_optimizer( # get the state dicts for model and optimize (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer) + # filter out lora state dict + lora_state_dict = { + k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k + } + # - save model - ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") - os.makedirs(ckpt_model, exist_ok=True) - logger.info(f"Saving model to {ckpt_model}") - dcp.save( - state_dict={KEY_MODEL: model_state_dict}, - storage_writer=dcp.FileSystemWriter(ckpt_model), - planner=DefaultSavePlanner(), - ) - logger.info(f"Model saved to {ckpt_model}") + if lora_state_dict: + ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") + os.makedirs(ckpt_model, exist_ok=True) + logger.info(f"Saving lora model to {ckpt_model}") + dcp.save( + state_dict={KEY_MODEL: lora_state_dict}, + storage_writer=dcp.FileSystemWriter(ckpt_model), + planner=DefaultSavePlanner(), + ) + else: + ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") + os.makedirs(ckpt_model, exist_ok=True) + logger.info(f"Saving ft model to {ckpt_model}") + dcp.save( + state_dict={KEY_MODEL: model_state_dict}, + storage_writer=dcp.FileSystemWriter(ckpt_model), + planner=DefaultSavePlanner(), + ) # - save optimizer ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") From 30ecfa83d22b009473396f4c70b471677147b42e Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:20:19 -0400 Subject: [PATCH 03/15] naming Signed-off-by: Will Johnson --- .../utils/checkpoint_utils.py | 79 ++++++++++++------- 1 file changed, 52 insertions(+), 27 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index cfc351a6..b2ca205c 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -53,6 +53,8 @@ KEY_MODEL = "model" KEY_OPTIMIZER = "optimizer" +ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" + # Below are rewrite of HF FSDP model saving functions to be able to handle # that the parameters are now a mixture of regular and Dtensors. 
# - these functions are found in accelerate.utils.fsdp_utils.py @@ -481,30 +483,54 @@ def save_sharded_safetensors( save_directory: str, metadata: Dict, max_shard_size: Union[int, str] = "5GB", + lora: bool = False, ): - filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( - ".safetensors", "{suffix}.safetensors" - ) - state_dict_split = split_torch_state_dict_into_shards( - input_state_dict, - filename_pattern=filename_pattern, - max_shard_size=max_shard_size, - ) - index = { - "metadata": state_dict_split.metadata, - "weight_map": state_dict_split.tensor_to_filename, - } - # Save the index - with open( - os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8" - ) as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) + if not lora: + filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + input_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size, + ) - filename_to_tensors = state_dict_split.filename_to_tensors.items() - for shard_file, tensors in filename_to_tensors: - shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors} - save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Save the index + with open( + os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8" + ) as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = { + tensor: input_state_dict[tensor].contiguous() for tensor in tensors + } + save_file( + shard, os.path.join(save_directory, shard_file), metadata=metadata + ) + else: + filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace( + ".bin", "{suffix}.bin" + ).replace(".safetensors", "{suffix}.safetensors") + state_dict_split = split_torch_state_dict_into_shards( + input_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size, + ) + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = { + tensor: input_state_dict[tensor].contiguous() for tensor in tensors + } + save_file( + shard, os.path.join(save_directory, shard_file), metadata=metadata + ) # --------------------------- SCRIPT ------------------------- @@ -555,18 +581,16 @@ def recover_safetensors_from_dcp( state_dict = loader(checkpoint_dir) new_state_dict = {} + lora = False for name, param in state_dict.items(): + if "lora_A" in name or "lora_B" in name: + lora = True if "base_model.model." in name: name = name.replace("base_model.model.", "", 1) if "default." 
in name: name = name.replace("default.", "", 1) new_state_dict[name] = param - # recover the original state dict - state_dict = recover_original_state_dict_from_checkpoint( - new_state_dict, _name_or_path - ) - # recover the original state dict state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path) @@ -575,6 +599,7 @@ def recover_safetensors_from_dcp( {k: v.contiguous() for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}, + lora=lora, ) From 2b1d26b9c593f5009d10ee664e55236e045e8d17 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:26:34 -0400 Subject: [PATCH 04/15] fmt + lint Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 7558ab52..2bc98e46 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -17,7 +17,6 @@ # Third Party from peft import LoraConfig -from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND from torch.distributed._tensor import DTensor # pylint: disable=import-error @@ -239,7 +238,7 @@ def __init__( assert ( lora_config.bias == "none" ), "ScatterMoE currently unable to handle bias in the lora adapters" - + if lora_config and hasattr(lora_config, "target_modules"): target_modules = lora_config.target_modules From 540af3e85ada5ef6462aab62e361a8fe0210bad7 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:37:37 -0400 Subject: [PATCH 05/15] fix: hardcodes Signed-off-by: Will Johnson --- .../fms_acceleration_moe/utils/scattermoe.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 2bc98e46..5b3d542d 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -239,17 +239,6 @@ def __init__( lora_config.bias == "none" ), "ScatterMoE currently unable to handle bias in the lora adapters" - if lora_config and hasattr(lora_config, "target_modules"): - target_modules = lora_config.target_modules - - required_modules = ["router", "layer", "all-linear"] - if "input_linear" in target_modules or "output_linear" in target_modules: - # Assert that the target modules also include at least one from required_modules - assert any( - module in target_modules for module in required_modules - ), "If 'input_linear' or 'output_linear' is included as a target module,\ - 'router' must also be included" - assert lora_config.init_lora_weights in { True, "gaussian", @@ -286,7 +275,7 @@ def __init__( # - w1: the up_projection. # - w2: the down_projection. # - w3 (optional): the gate projection. 
- if not lora_config or ("input_linear" in target_modules): + if not lora_config: self.w1 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, @@ -297,7 +286,7 @@ def __init__( device=device, lora_config=lora_config, ) - if not lora_config or ("output_linear" in target_modules): + if not lora_config self.w2 = ScatteredExperts( in_features=self.intermediate_size, out_features=self.hidden_size, @@ -308,7 +297,7 @@ def __init__( device=device, lora_config=lora_config, ) - if not lora_config or ("input_linear" in target_modules): + if not lora_config: if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: self.w3 = ScatteredExperts( in_features=self.hidden_size, From 85e62dfa646d14e86e79ee439b0bb32cf18b45b0 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:38:13 -0400 Subject: [PATCH 06/15] fix: target modules Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 5b3d542d..fd46ab68 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -231,8 +231,6 @@ def __init__( not has_bias ), "ScatterMoE currently unable to handle bias in both gates and experts." - target_modules = None - if lora_config is not None: # since this is self implemented, we really only support basic lora funcs assert ( From 67cdfb80f9c728fb6b10b5463515a357162d5a0f Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:41:42 -0400 Subject: [PATCH 07/15] fix Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index b2ca205c..09b6c1c1 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -592,7 +592,7 @@ def recover_safetensors_from_dcp( new_state_dict[name] = param # recover the original state dict - state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path) + state_dict = recover_original_state_dict_from_checkpoint(new_state_dict, _name_or_path) # save it as a safetensors file save_sharded_safetensors( From b60eb66cbe99d9e297a37d460b7480efb3a36a46 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 07:43:19 -0400 Subject: [PATCH 08/15] fmt + lint Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 4 +++- .../src/fms_acceleration_moe/utils/scattermoe.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 09b6c1c1..e0d2870a 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -592,7 +592,9 @@ def recover_safetensors_from_dcp( new_state_dict[name] = param # recover the original state dict - state_dict = recover_original_state_dict_from_checkpoint(new_state_dict, _name_or_path) + state_dict = 
recover_original_state_dict_from_checkpoint( + new_state_dict, _name_or_path + ) # save it as a safetensors file save_sharded_safetensors( diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index fd46ab68..543f5012 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -284,7 +284,7 @@ def __init__( device=device, lora_config=lora_config, ) - if not lora_config + if not lora_config: self.w2 = ScatteredExperts( in_features=self.intermediate_size, out_features=self.hidden_size, From a061780ab54856cff3e4cad87a9abc45afb795bd Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 20:03:42 -0400 Subject: [PATCH 09/15] remove lora config from expert weights Signed-off-by: Will Johnson --- .../fms_acceleration_moe/utils/scattermoe.py | 93 +++++++++---------- 1 file changed, 44 insertions(+), 49 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 543f5012..71de7601 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -273,8 +273,26 @@ def __init__( # - w1: the up_projection. # - w2: the down_projection. # - w3 (optional): the gate projection. - if not lora_config: - self.w1 = ScatteredExperts( + self.w1 = ScatteredExperts( + in_features=self.hidden_size, + out_features=self.intermediate_size, + num_experts=self.num_experts, + fan_out=self.top_k if not self.all_to_all else 1, + grouped_out=True, + dtype=dtype, + device=device, + ) + self.w2 = ScatteredExperts( + in_features=self.intermediate_size, + out_features=self.hidden_size, + num_experts=self.num_experts, + fan_out=1, + grouped_in=True, + dtype=dtype, + device=device, + ) + if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: + self.w3 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, num_experts=self.num_experts, @@ -282,31 +300,11 @@ def __init__( grouped_out=True, dtype=dtype, device=device, - lora_config=lora_config, - ) - if not lora_config: - self.w2 = ScatteredExperts( - in_features=self.intermediate_size, - out_features=self.hidden_size, - num_experts=self.num_experts, - fan_out=1, - grouped_in=True, - dtype=dtype, - device=device, - lora_config=lora_config, ) - if not lora_config: - if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: - self.w3 = ScatteredExperts( - in_features=self.hidden_size, - out_features=self.intermediate_size, - num_experts=self.num_experts, - fan_out=self.top_k if not self.all_to_all else 1, - grouped_out=True, - dtype=dtype, - device=device, - lora_config=lora_config, - ) + if lora_config: + self.w1.requires_grad = False + self.w2.requires_grad = False + self.w3.requires_grad = False # referenced from dolomite-engine def _compute_routing_weights(self, hidden_states: torch.Tensor): @@ -455,39 +453,36 @@ def forward(self, hidden_states: torch.Tensor): ) # compute the up projection - if hasattr(self, "w1"): - out = self.w1( + out = self.w1( + hidden_states, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + ) + out = self.activation(out) + + # - if the arch has a seperate gate projection + if self.w3: + out *= self.w3( hidden_states, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets, ) - out = 
self.activation(out) - - # - if the arch has a seperate gate projection - if hasattr(self, "w3"): - if self.w3: - out *= self.w3( - hidden_states, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - ) # compute the down projection # - if no all-to-all processing, then depend on # scattermoe kernel to perform the final scattering - if hasattr(self, "w2"): - hidden_states = self.w2( - out, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates=(None if self.all_to_all else routing_weights), - ) + hidden_states = self.w2( + out, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=(None if self.all_to_all else routing_weights), + ) # maybe scatter hidden_states = self._maybe_scatter( From a9176a6897bd879f90e40b48533bd02bd2594eba Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 20:23:38 -0400 Subject: [PATCH 10/15] pylint Signed-off-by: Will Johnson --- plugins/accelerated-moe/.pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc index 4ddccea1..66dfa353 100644 --- a/plugins/accelerated-moe/.pylintrc +++ b/plugins/accelerated-moe/.pylintrc @@ -281,7 +281,7 @@ ignored-parents= max-args=5 # Maximum number of attributes for a class (see R0902). -max-attributes=7 +max-attributes=8 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 From 0d4a32d19f2e7aa5ea23c936720ea745662d02fc Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 20:37:34 -0400 Subject: [PATCH 11/15] fix: pass in lora config Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/framework_plugin_scattermoe.py | 2 ++ .../src/fms_acceleration_moe/utils/scattermoe.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py index 48351cfd..52040e0e 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py @@ -77,6 +77,7 @@ def augmentation( modifiable_args: Tuple[LoraConfig], ): rank, world_size = 0, 1 + (peft_config,) = modifiable_args if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() # we do not need to use the fallback as this is wrapped in an `is_initialized` block @@ -97,6 +98,7 @@ def augmentation( ep_degree=self._ep_degree, disable_distributed=self._disable_distributed, mixed_precision=False, # Currently this is hardcoded to OFF + lora_config=peft_config, ) return model, modifiable_args diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 71de7601..160c2db4 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -302,9 +302,9 @@ def __init__( device=device, ) if lora_config: - self.w1.requires_grad = False - self.w2.requires_grad = False - self.w3.requires_grad = False + self.w1.weight.requires_grad = False + self.w2.weight.requires_grad = False + self.w3.weight.requires_grad = False # referenced from dolomite-engine def _compute_routing_weights(self, hidden_states: torch.Tensor): From 
7729f7e661032e23bb940f09a28c9daf2b063449 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 20:39:59 -0400 Subject: [PATCH 12/15] fix: requires grad Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe_prepare.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 9adcc47b..c7162894 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -92,7 +92,8 @@ def _hook(grad): # install gradient scaling hook if KEY_SCATTERMOE_ROUTER not in weight_name: - param.register_hook(_hook) + if param.requires_grad: + param.register_hook(_hook) # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) From 855a9f07886b9625fd5ff254a037c87cb6926b6d Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 17 Apr 2025 20:51:11 -0400 Subject: [PATCH 13/15] comments Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index e0d2870a..7bd4a390 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -580,11 +580,15 @@ def recover_safetensors_from_dcp( # get the state_dict state_dict = loader(checkpoint_dir) + # filter out additional names created by lora tuning + # create switch based on state dict for future use new_state_dict = {} lora = False for name, param in state_dict.items(): + # if lora weight, set lora switch to true if "lora_A" in name or "lora_B" in name: lora = True + # if lora naming convention, convert to traditional if "base_model.model." in name: name = name.replace("base_model.model.", "", 1) if "default." 
in name: From 1a0b04944a5591fd90606b6d8de02467a9ef9735 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 18 Apr 2025 10:54:24 -0400 Subject: [PATCH 14/15] note temp fix Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 160c2db4..52d04ba8 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -301,6 +301,9 @@ def __init__( dtype=dtype, device=device, ) + + # Temporary fix: in future will want to require grad when lora tuning + # on self.wx.weight.lora_A/lora_B if lora_config: self.w1.weight.requires_grad = False self.w2.weight.requires_grad = False From a4c4f2281b7f234503658e0a451177f85ed84d12 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 18 Apr 2025 11:44:13 -0400 Subject: [PATCH 15/15] fix: don't turn off requires grad Signed-off-by: Will Johnson --- .../src/fms_acceleration_moe/utils/scattermoe.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 52d04ba8..44125acd 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -302,13 +302,6 @@ def __init__( device=device, ) - # Temporary fix: in future will want to require grad when lora tuning - # on self.wx.weight.lora_A/lora_B - if lora_config: - self.w1.weight.requires_grad = False - self.w2.weight.requires_grad = False - self.w3.weight.requires_grad = False - # referenced from dolomite-engine def _compute_routing_weights(self, hidden_states: torch.Tensor):
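
---
Editor's note (not part of the patch series): the sketch below is a minimal, standalone restatement of the PEFT key-handling step that PATCH 03 adds to recover_safetensors_from_dcp — stripping the "base_model.model." and "default." prefixes and flagging LoRA checkpoints via lora_A/lora_B keys — so the intent of the series is easier to follow. The function body mirrors the patched code; the sample state-dict keys and the helper name strip_peft_prefixes are hypothetical and used only for illustration.

# sketch.py — illustration only, assuming a plain dict of name -> tensor
def strip_peft_prefixes(state_dict):
    """Drop PEFT naming ("base_model.model.", "default.") and flag LoRA keys."""
    new_state_dict = {}
    lora = False
    for name, param in state_dict.items():
        # any lora_A / lora_B tensor means the checkpoint is a LoRA adapter
        if "lora_A" in name or "lora_B" in name:
            lora = True
        # convert the PEFT naming convention back to the base-model convention
        if "base_model.model." in name:
            name = name.replace("base_model.model.", "", 1)
        if "default." in name:
            name = name.replace("default.", "", 1)
        new_state_dict[name] = param
    return new_state_dict, lora


if __name__ == "__main__":
    # hypothetical keys, for illustration only
    sample = {
        "base_model.model.layers.0.mlp.w1.lora_A.default.weight": None,
        "base_model.model.layers.0.mlp.router.weight": None,
    }
    renamed, is_lora = strip_peft_prefixes(sample)
    print(is_lora)          # True
    print(sorted(renamed))  # ['layers.0.mlp.router.weight', 'layers.0.mlp.w1.lora_A.weight']

The lora flag returned here corresponds to the switch that PATCH 03 threads into save_sharded_safetensors to pick the adapter filename pattern (ADAPTER_SAFE_WEIGHTS_NAME) instead of the full-model one.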