diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc
index 4ddccea1..4141cba3 100644
--- a/plugins/accelerated-moe/.pylintrc
+++ b/plugins/accelerated-moe/.pylintrc
@@ -476,7 +476,7 @@ notes-rgx=
 [REFACTORING]
 
 # Maximum number of nested blocks for function / method body
-max-nested-blocks=5
+max-nested-blocks=6
 
 # Complete name of functions that never returns. When checking for
 # inconsistent-return-statements if a never returning function is called then
diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md
index d05cdc82..634166ff 100644
--- a/plugins/accelerated-moe/README.md
+++ b/plugins/accelerated-moe/README.md
@@ -8,7 +8,7 @@ This library contains plugins to accelerate finetuning with the following optimi
 
 Plugin | Description | Depends | Loading | Augmentation | Callbacks
 --|--|--|--|--|--
-[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | ✅ | | ✅
+[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | | ✅ | ✅
 
 ## Adding New Models
 
@@ -33,6 +33,8 @@ python -m fms_acceleration_moe.utils.checkpoint_utils \
     mistralai/Mixtral-8x7B-Instruct-v0.1
 ```
 
+If running with fms-hf-tuning, this script is run automatically when the `fast_moe` parameter is set.
+
 ## Code Extracted from Megablocks
 
 Notes on code extraction:
@@ -81,9 +83,8 @@ Triton Kernels are copied into [scattermoe_utils](./src/fms_acceleration_moe/uti
 ### Known Issues
 
 These are currently some known issues not yet resolved:
-- should eventually remove the dependency on an external `kernel-hyperdrive` repository.
-- now support only loading *sharded* `safetensor` non-GGUF MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents it being saved into a single checkpoint filed.
 - when used together with FSDP, the FSDP's `clip_grad_norm` will not properly compute for `ScatterMoE`, see [issue here](https://github.com/foundation-model-stack/fms-acceleration/issues/109).
+- when LoRA training a model with the expert layers as adapter target modules, the resulting model will fail to run inference in vLLM / vanilla HF because of restrictions on parameter types. If you intend to run inference, do not select `input_linear` and `output_linear` as target modules when LoRA training (see the example config sketched below).
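To make the known-issue guidance above concrete, here is a minimal sketch of a LoRA configuration that leaves out the expert projections. It is an illustration, not part of the patch: it assumes `peft`'s `LoraConfig`, and the attention-projection module names are hypothetical placeholders you would adjust for your model.

```python
# Minimal sketch (not part of this patch): a LoRA config that avoids the
# ScatterMoE expert projections (`input_linear` / `output_linear`) so the
# saved adapter remains loadable for inference in vLLM / vanilla HF.
# The attention projection names below are hypothetical; adjust to your model.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",  # ScatterMoE currently cannot handle bias in the adapters
    # "router" is optional; the key point is that `input_linear` and
    # `output_linear` are deliberately omitted.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "router"],
)
```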
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 48351cfd..52040e0e 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -77,6 +77,7 @@ def augmentation(
         modifiable_args: Tuple[LoraConfig],
     ):
         rank, world_size = 0, 1
+        (peft_config,) = modifiable_args
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             # we do not need to use the fallback as this is wrapped in an `is_initialized` block
@@ -97,6 +98,7 @@ def augmentation(
             ep_degree=self._ep_degree,
             disable_distributed=self._disable_distributed,
             mixed_precision=False,  # Currently this is hardcoded to OFF
+            lora_config=peft_config,
         )
 
         return model, modifiable_args
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index e6fe1ba6..12a62ca4 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -53,6 +53,10 @@
 KEY_MODEL = "model"
 KEY_OPTIMIZER = "optimizer"
 
+ADAPTER_CONFIG_NAME = "adapter_config.json"
+ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
+ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
+
 # Below are rewrite of HF FSDP model saving functions to be able to handle
 # that the parameters are now a mixture of regular and Dtensors.
 # - these functions are found in accelerate.utils.fsdp_utils.py
@@ -110,15 +114,30 @@ def save_fsdp_optimizer(
     # get the state dicts for model and optimize
     (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
 
-    # - save model
-    ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
-    os.makedirs(ckpt_model, exist_ok=True)
-    logger.info(f"Saving model to {ckpt_model}")
-    dcp.save(
-        state_dict={KEY_MODEL: model_state_dict},
-        storage_writer=dcp.FileSystemWriter(ckpt_model),
-        planner=DefaultSavePlanner(),
-    )
+    # filter out lora state dict
+    lora_state_dict = {
+        k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
+    }
+
+    # - save model
+    if lora_state_dict:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving lora model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: lora_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
+    else:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving ft model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: model_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
     logger.info(f"Model saved to {ckpt_model}")
 
     # - save optimizer
@@ -303,6 +322,7 @@ def get_state_dict_from_safe_checkpoint(safe_checkpoint_dir: str):
 # can restore the checkpoint to be loaded by the original architecture.
def recover_original_state_dict_from_checkpoint( sd: Dict, + lora: bool, pretrained_model_name_or_path: str = None, ): """ @@ -323,6 +343,19 @@ def recover_original_state_dict_from_checkpoint( # config config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path) + # if lora, check for input/output layers + ip_op_layers = False + router_layer = False + if lora: + for name, _ in sd.items(): + if "w1" in name: + ip_op_layers = True + break + for name, _ in sd.items(): + if "router" in name: + router_layer = True + break + ( _, router_name, @@ -340,11 +373,13 @@ def recover_original_state_dict_from_checkpoint( def _infer_prefixes_and_module_names( sd_keys: List[str], - min_count: int = 3, + lora: bool, ): + min_count = 2 if lora else 3 + _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE]) # pylint: disable=anomalous-backslash-in-string - _reg = re.compile(f"(.*)\.({_name})\.weight") + _reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A|lora_B)") found = {} for k in sd_keys: @@ -364,7 +399,7 @@ def _infer_prefixes_and_module_names( return results - for prefix in _infer_prefixes_and_module_names(sd.keys()): + for prefix in _infer_prefixes_and_module_names(sd.keys(), lora): prefix = prefix.split(".") prefix, module_name = ".".join(prefix[:-1]), prefix[-1] @@ -390,6 +425,8 @@ def _infer_prefixes_and_module_names( module_name, router_name, expert_name, + ip_op_layers=ip_op_layers, + router_layer=router_layer, ) model2scatter = defaultdict(dict) @@ -398,6 +435,7 @@ def _infer_prefixes_and_module_names( # model param and they need to be cat for scatter_key, list_of_params in checkpoint_metadata.items(): scatter_key_fqdn = ".".join([prefix, module_name, scatter_key]) + scatter_param = sd[scatter_key_fqdn] # remove from state dict @@ -443,8 +481,36 @@ def _infer_prefixes_and_module_names( len(scatter_keys) > 0 ), f"Obtained zero scatter keys for model_key '{model_key}'" - if len(scatter_keys) == 1: + if lora: + for i, lora_key in enumerate(scatter_keys): + model_key_parts = model_key.split(".") + weight_index = model_key_parts.index("weight") + # Replace the "layer.weight" part with "layer.lora_A.weight" or + # "layer.lora_B.weight" + if "lora_A" in lora_key: + model_key_parts[weight_index] = "lora_A.weight" + elif "lora_B" in lora_key: + model_key_parts[weight_index] = "lora_B.weight" + # Rebuild the model_key and assign the corresponding scatter_param + new_model_key = ".".join(model_key_parts) + if len(scatter_keys) == 2: + sd[new_model_key] = scatter_params[lora_key] + else: + if "lora_A" in new_model_key: + filtered_keys = [k for k in scatter_keys if "lora_A" in k] + elif "lora_B" in new_model_key: + filtered_keys = [k for k in scatter_keys if "lora_B" in k] + else: + raise ValueError( + f"Unexpected LoRA key type in {new_model_key}" + ) + sd[new_model_key] = torch.cat( + [scatter_params[k] for k in filtered_keys], dim=1 + ) + + elif len(scatter_keys) == 1: sd[model_key] = scatter_params[scatter_keys[0]] + else: # unfortunately, there this is a in # scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights @@ -466,31 +532,56 @@ def save_sharded_safetensors( input_state_dict: Dict, save_directory: str, metadata: Dict, + lora: bool, max_shard_size: Union[int, str] = "5GB", ): - filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( - ".safetensors", "{suffix}.safetensors" - ) - state_dict_split = split_torch_state_dict_into_shards( - input_state_dict, - filename_pattern=filename_pattern, - max_shard_size=max_shard_size, - ) - 
index = { - "metadata": state_dict_split.metadata, - "weight_map": state_dict_split.tensor_to_filename, - } - # Save the index - with open( - os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8" - ) as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - filename_to_tensors = state_dict_split.filename_to_tensors.items() - for shard_file, tensors in filename_to_tensors: - shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors} - save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + if not lora: + filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + input_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size, + ) + + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Save the index + with open( + os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8" + ) as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = { + tensor: input_state_dict[tensor].contiguous() for tensor in tensors + } + save_file( + shard, os.path.join(save_directory, shard_file), metadata=metadata + ) + else: + filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace( + ".bin", "{suffix}.bin" + ).replace(".safetensors", "{suffix}.safetensors") + state_dict_split = split_torch_state_dict_into_shards( + input_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size, + ) + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = { + tensor: input_state_dict[tensor].contiguous() for tensor in tensors + } + save_file( + shard, os.path.join(save_directory, shard_file), metadata=metadata + ) # --------------------------- SCRIPT ------------------------- @@ -540,14 +631,28 @@ def recover_safetensors_from_dcp( # get the state_dict state_dict = loader(checkpoint_dir) + lora = False + new_state_dict = {} + for name, param in state_dict.items(): + if "lora_A" in name or "lora_B" in name: + lora = True + if "base_model.model." in name: + name = name.replace("base_model.model.", "", 1) + if "default." 
in name: + name = name.replace("default.", "", 1) + new_state_dict[name] = param + # recover the original state dict - state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path) + state_dict = recover_original_state_dict_from_checkpoint( + new_state_dict, lora, _name_or_path + ) # save it as a safetensors file save_sharded_safetensors( {k: v.contiguous() for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}, + lora=lora, ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index e16943b1..2bc98e46 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -17,7 +17,6 @@ # Third Party from peft import LoraConfig -from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND from torch.distributed._tensor import DTensor # pylint: disable=import-error @@ -232,15 +231,24 @@ def __init__( not has_bias ), "ScatterMoE currently unable to handle bias in both gates and experts." + target_modules = None + if lora_config is not None: # since this is self implemented, we really only support basic lora funcs assert ( lora_config.bias == "none" ), "ScatterMoE currently unable to handle bias in the lora adapters" - assert ( - lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND - or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules - ), "ScatterMoe currently only handles lora adapters on all linears." + + if lora_config and hasattr(lora_config, "target_modules"): + target_modules = lora_config.target_modules + + required_modules = ["router", "layer", "all-linear"] + if "input_linear" in target_modules or "output_linear" in target_modules: + # Assert that the target modules also include at least one from required_modules + assert any( + module in target_modules for module in required_modules + ), "If 'input_linear' or 'output_linear' is included as a target module,\ + 'router' must also be included" assert lora_config.init_lora_weights in { True, @@ -278,28 +286,8 @@ def __init__( # - w1: the up_projection. # - w2: the down_projection. # - w3 (optional): the gate projection. 
- self.w1 = ScatteredExperts( - in_features=self.hidden_size, - out_features=self.intermediate_size, - num_experts=self.num_experts, - fan_out=self.top_k if not self.all_to_all else 1, - grouped_out=True, - dtype=dtype, - device=device, - lora_config=lora_config, - ) - self.w2 = ScatteredExperts( - in_features=self.intermediate_size, - out_features=self.hidden_size, - num_experts=self.num_experts, - fan_out=1, - grouped_in=True, - dtype=dtype, - device=device, - lora_config=lora_config, - ) - if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: - self.w3 = ScatteredExperts( + if not lora_config or ("input_linear" in target_modules): + self.w1 = ScatteredExperts( in_features=self.hidden_size, out_features=self.intermediate_size, num_experts=self.num_experts, @@ -309,6 +297,29 @@ def __init__( device=device, lora_config=lora_config, ) + if not lora_config or ("output_linear" in target_modules): + self.w2 = ScatteredExperts( + in_features=self.intermediate_size, + out_features=self.hidden_size, + num_experts=self.num_experts, + fan_out=1, + grouped_in=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) + if not lora_config or ("input_linear" in target_modules): + if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: + self.w3 = ScatteredExperts( + in_features=self.hidden_size, + out_features=self.intermediate_size, + num_experts=self.num_experts, + fan_out=self.top_k if not self.all_to_all else 1, + grouped_out=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) # referenced from dolomite-engine def _compute_routing_weights(self, hidden_states: torch.Tensor): @@ -457,36 +468,39 @@ def forward(self, hidden_states: torch.Tensor): ) # compute the up projection - out = self.w1( - hidden_states, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - ) - out = self.activation(out) - - # - if the arch has a seperate gate projection - if self.w3: - out *= self.w3( + if hasattr(self, "w1"): + out = self.w1( hidden_states, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets, ) + out = self.activation(out) + + # - if the arch has a seperate gate projection + if hasattr(self, "w3"): + if self.w3: + out *= self.w3( + hidden_states, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + ) # compute the down projection # - if no all-to-all processing, then depend on # scattermoe kernel to perform the final scattering - hidden_states = self.w2( - out, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates=(None if self.all_to_all else routing_weights), - ) + if hasattr(self, "w2"): + hidden_states = self.w2( + out, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=(None if self.all_to_all else routing_weights), + ) # maybe scatter hidden_states = self._maybe_scatter( diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py index 2a6847be..a6b95541 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py @@ -24,7 +24,9 @@ KEY_EXPERT_PARALLEL = "expert_parallel" DIM_EXPERT = 0 -KEY_SCATTERMOE_ROUTER = PARAM_NAME_ROUTER_SCATTERMOE + ".weight" +KEY_SCATTERMOE_ROUTER = "router.weight" +KEY_SCATTERMOE_LORA_A_ROUTER = "router.lora_A.weight" +KEY_SCATTERMOE_LORA_B_ROUTER = "router.lora_B.weight" # 
Currently out ScatterMoE drop supports an up/down proj, and # and optional gate_proj. diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py index 9adcc47b..eaf11aff 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -33,6 +33,8 @@ FILE_SAFETENSOR_INDEX, KEY_EXPERT_PARALLEL, KEY_REPLICATE, + KEY_SCATTERMOE_LORA_A_ROUTER, + KEY_SCATTERMOE_LORA_B_ROUTER, KEY_SCATTERMOE_ROUTER, get_scattermoe_conv_spec_from_archs, ) @@ -66,7 +68,11 @@ def _hook(grad): for weight_name, param in state_dict.items(): - if KEY_SCATTERMOE_ROUTER in weight_name: + if ( + KEY_SCATTERMOE_ROUTER in weight_name + or KEY_SCATTERMOE_LORA_A_ROUTER in weight_name + or KEY_SCATTERMOE_LORA_B_ROUTER in weight_name + ): # if its the router, replicate param = distribute_tensor(param, device_mesh, reps + [Replicate()]) elif param.shape[0] > num_experts_per_device: @@ -91,8 +97,13 @@ def _hook(grad): ) # install gradient scaling hook - if KEY_SCATTERMOE_ROUTER not in weight_name: - param.register_hook(_hook) + if ( + KEY_SCATTERMOE_ROUTER not in weight_name + and KEY_SCATTERMOE_LORA_A_ROUTER not in weight_name + and KEY_SCATTERMOE_LORA_B_ROUTER not in weight_name + ): + if param.requires_grad: + param.register_hook(_hook) # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) @@ -229,6 +240,10 @@ def prepare_scattermoe( weight_map = load_weight_map(loc, "model.safetensors", FILE_SAFETENSOR_INDEX) + target_modules = None + if lora_config and hasattr(lora_config, "target_modules"): + target_modules = lora_config.target_modules + # e.g., prefix: 'model.layers.0', # module_name: 'block_sparse_moe' for prefix, (module_name, _, has_bias) in tqdm( @@ -240,6 +255,7 @@ def prepare_scattermoe( module_name, router_name, "|".join(expert_name), + target_modules=target_modules, ) # the parent module @@ -329,21 +345,37 @@ def prepare_scattermoe( elif "lora_B" in name: torch.nn.init.normal_(sd[name]) - if device_mesh is None: - # - if not on meta, just load the state dict - # - and then put on the device - if not is_fsdp_enabled() or is_local_dist_rank_0(): - moe.load_state_dict(sd) - moe = moe.to(device) - else: - # - otherwise, we need to distribtue and will - # replace the parameters - load_experts_onto_device(moe, sd, device_mesh, num_experts_per_device) - # module swap - setattr(parent, module_name, moe) - - # - keep track of the name for returning - moe_module_names.add(module_name) + possible_target_modules = [ + "all-linear", + "router", + "layer", + "input_linear", + "output_linear", + ] + if ( + any( + module in (target_modules or []) + for module in possible_target_modules + ) + or not lora_config + ): + if device_mesh is None: + # - if not on meta, just load the state dict + # - and then put on the device + if not is_fsdp_enabled() or is_local_dist_rank_0(): + moe.load_state_dict(sd) + moe = moe.to(device) + else: + # - otherwise, we need to distribtue and will + # replace the parameters + load_experts_onto_device( + moe, sd, device_mesh, num_experts_per_device + ) + # module swap + setattr(parent, module_name, moe) + + # - keep track of the name for returning + moe_module_names.add(module_name) except ValueError as e: raise ValueError( diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py 
b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index e13f6ba5..a436b2a6 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -26,6 +26,8 @@ # Local from .scattermoe_constants import ( DIM_EXPERT, + KEY_SCATTERMOE_LORA_A_ROUTER, + KEY_SCATTERMOE_LORA_B_ROUTER, KEY_SCATTERMOE_ROUTER, PARAM_NAME_WEIGHT_SCATTERMOE, ) @@ -86,6 +88,9 @@ def get_checkpoint_meta_from_sharded_safetensor( router_name: str = "gate", # e.g., named "gate" within block_sparse_moe expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe expert_map: Dict = None, # map -> [w1,w2,w3] + ip_op_layers: bool = False, # if input/output layers are detected in utils + router_layer: bool = False, # if router layer is detected in utils + target_modules: Dict = None, # target modules from prepare_scattermoe.py ) -> Dict[str, List[Tuple]]: """ utilty function to infer the mapping of ScatterMoe parameters @@ -106,6 +111,14 @@ def get_checkpoint_meta_from_sharded_safetensor( e.g., input_linear|output_linear|input_linear expert_map (dict): This is used with pattern ii) described above in expert_name. If not specified, will be the identity map, e.g., w1 -> w1 + ip_op_layers (bool): Boolean to determine if input/output layers are detected + in checkpoint utils. + router_layer (bool): Boolean to determine if router layer is detected + in checkpoint utils. + target_modules (dict): Target modules from scattermoe_prepare.py lora_config. + Used to check for input and output layers while preparing scattermoe. + + Returns: Map of used ScatterMoE weights to their files """ # insert in order @@ -147,6 +160,9 @@ def _insert(L: List, i: int, v): # `w1.weight`: [...] _map = defaultdict(list) prefix = f"{prefix}.{instance_name}." 
+ + target_modules = target_modules or {} + for k, stfile in weight_map.items(): if not k.startswith(prefix): continue @@ -163,15 +179,33 @@ def _insert(L: List, i: int, v): f"'{router_name}' or expert_name '{expert_name}'" ) if m.group(1) == router_name: - _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) + if router_layer: + _map[KEY_SCATTERMOE_LORA_A_ROUTER].append((k, stfile)) + _map[KEY_SCATTERMOE_LORA_B_ROUTER].append((k, stfile)) + else: + _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) elif m.group(1) in expert_name: - index = m.group(2) - index = 0 if index is None else int(index) - mod = None - for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): - _insert(_map[f"{mod}.weight"], index, (k, stfile)) - - assert mod is not None, f"cannot map '{rel_k}'" + if ( + "input_linear" in target_modules and "output_linear" in target_modules + ) or ip_op_layers: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + if not ip_op_layers: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + else: + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.lora_A"], index, (k, stfile)) + _insert(_map[f"{mod}.lora_B"], index, (k, stfile)) + assert mod is not None, f"cannot map '{rel_k}'" + elif not ip_op_layers and not router_layer and not target_modules: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + assert mod is not None, f"cannot map '{rel_k}'" if len(_map) == 0: raise ValueError( @@ -295,7 +329,11 @@ def get_state_dict_from_checkpoint_metadata( # go by one weight at a time. for scatter_key, vs in checkpoint_metadata.items(): - if KEY_SCATTERMOE_ROUTER in scatter_key: + if ( + KEY_SCATTERMOE_ROUTER in scatter_key + or KEY_SCATTERMOE_LORA_A_ROUTER in scatter_key + or KEY_SCATTERMOE_LORA_B_ROUTER in scatter_key + ): k, fi = vs[0] # only one item param = files[fi].get_tensor(k)
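To make the LoRA key handling in `recover_original_state_dict_from_checkpoint` easier to follow, the standalone sketch below replays the key-rewrite step on its own: the trailing `weight` component of a model key is replaced with `lora_A.weight` or `lora_B.weight` depending on the ScatterMoE key being mapped. It is an illustration only, not part of the patch, and the key names used are hypothetical.

```python
# Standalone sketch (not part of the patch) of the LoRA key-rewrite step used in
# recover_original_state_dict_from_checkpoint: a model key's trailing "weight"
# component is replaced with "lora_A.weight" / "lora_B.weight" depending on the
# ScatterMoE key being mapped. Key names below are hypothetical examples.

def rewrite_model_key(model_key: str, scatter_key: str) -> str:
    parts = model_key.split(".")
    weight_index = parts.index("weight")
    if "lora_A" in scatter_key:
        parts[weight_index] = "lora_A.weight"
    elif "lora_B" in scatter_key:
        parts[weight_index] = "lora_B.weight"
    return ".".join(parts)

# e.g. (hypothetical key names):
print(
    rewrite_model_key(
        "model.layers.0.block_sparse_moe.experts.0.w1.weight",
        "w1.lora_A",
    )
)
# -> model.layers.0.block_sparse_moe.experts.0.w1.lora_A.weight
```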