diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc
index 4ddccea1..66dfa353 100644
--- a/plugins/accelerated-moe/.pylintrc
+++ b/plugins/accelerated-moe/.pylintrc
@@ -281,7 +281,7 @@ ignored-parents=
 max-args=5
 
 # Maximum number of attributes for a class (see R0902).
-max-attributes=7
+max-attributes=8
 
 # Maximum number of boolean expressions in an if statement (see R0916).
 max-bool-expr=5
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 48351cfd..52040e0e 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -77,6 +77,7 @@ def augmentation(
         modifiable_args: Tuple[LoraConfig],
     ):
         rank, world_size = 0, 1
+        (peft_config,) = modifiable_args
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             # we do not need to use the fallback as this is wrapped in an `is_initialized` block
@@ -97,6 +98,7 @@ def augmentation(
             ep_degree=self._ep_degree,
             disable_distributed=self._disable_distributed,
             mixed_precision=False,  # Currently this is hardcoded to OFF
+            lora_config=peft_config,
         )
 
         return model, modifiable_args
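Note on the `augmentation` change above: `modifiable_args` is treated as a one-element tuple whose only entry is the `LoraConfig` handed in by the framework, and the added line simply destructures it. A minimal standalone sketch of that pattern, not part of the patch; the config values below are illustrative only.

# illustrative sketch, not part of the patch
from peft import LoraConfig

modifiable_args = (LoraConfig(r=8, lora_alpha=16),)

# single-element tuple unpacking; raises ValueError if the tuple length differs
(peft_config,) = modifiable_args
assert peft_config is modifiable_args[0]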
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index e6fe1ba6..7bd4a390 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -53,6 +53,8 @@
 KEY_MODEL = "model"
 KEY_OPTIMIZER = "optimizer"
 
+ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
+
 # Below are rewrite of HF FSDP model saving functions to be able to handle
 # that the parameters are now a mixture of regular and Dtensors.
 # - these functions are found in accelerate.utils.fsdp_utils.py
@@ -110,16 +112,30 @@ def save_fsdp_optimizer(
     # get the state dicts for model and optimize
     (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
 
+    # filter out lora state dict
+    lora_state_dict = {
+        k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
+    }
+
     # - save model
-    ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
-    os.makedirs(ckpt_model, exist_ok=True)
-    logger.info(f"Saving model to {ckpt_model}")
-    dcp.save(
-        state_dict={KEY_MODEL: model_state_dict},
-        storage_writer=dcp.FileSystemWriter(ckpt_model),
-        planner=DefaultSavePlanner(),
-    )
-    logger.info(f"Model saved to {ckpt_model}")
+    if lora_state_dict:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving lora model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: lora_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
+    else:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving ft model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: model_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
 
     # - save optimizer
     ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
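The `save_fsdp_optimizer` hunk above keys everything off a simple substring filter: if any parameter name contains `lora_A` or `lora_B`, only those adapter weights are written to the DCP checkpoint, otherwise the full model state dict is saved as before. A minimal sketch of that filter on a toy state dict, not part of the patch; keys and shapes are illustrative only.

# illustrative sketch, not part of the patch
import torch

model_state_dict = {
    "model.layers.0.mlp.w1.lora_A.default.weight": torch.zeros(8, 16),
    "model.layers.0.mlp.w1.lora_B.default.weight": torch.zeros(16, 8),
    "model.layers.0.mlp.router.weight": torch.zeros(4, 16),
}

lora_state_dict = {
    k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
}

# truthy when adapters are present, so only the adapter weights get saved
print(bool(lora_state_dict), sorted(lora_state_dict))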
@@ -467,30 +483,54 @@ def save_sharded_safetensors(
     save_directory: str,
     metadata: Dict,
     max_shard_size: Union[int, str] = "5GB",
+    lora: bool = False,
 ):
-    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
-        ".safetensors", "{suffix}.safetensors"
-    )
-    state_dict_split = split_torch_state_dict_into_shards(
-        input_state_dict,
-        filename_pattern=filename_pattern,
-        max_shard_size=max_shard_size,
-    )
-    index = {
-        "metadata": state_dict_split.metadata,
-        "weight_map": state_dict_split.tensor_to_filename,
-    }
-    # Save the index
-    with open(
-        os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
-    ) as f:
-        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-        f.write(content)
+    if not lora:
+        filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
+        state_dict_split = split_torch_state_dict_into_shards(
+            input_state_dict,
+            filename_pattern=filename_pattern,
+            max_shard_size=max_shard_size,
+        )
 
-    filename_to_tensors = state_dict_split.filename_to_tensors.items()
-    for shard_file, tensors in filename_to_tensors:
-        shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
-        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        # Save the index
+        with open(
+            os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
+        ) as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        for shard_file, tensors in filename_to_tensors:
+            shard = {
+                tensor: input_state_dict[tensor].contiguous() for tensor in tensors
+            }
+            save_file(
+                shard, os.path.join(save_directory, shard_file), metadata=metadata
+            )
+    else:
+        filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(
+            ".bin", "{suffix}.bin"
+        ).replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(
+            input_state_dict,
+            filename_pattern=filename_pattern,
+            max_shard_size=max_shard_size,
+        )
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        for shard_file, tensors in filename_to_tensors:
+            shard = {
+                tensor: input_state_dict[tensor].contiguous() for tensor in tensors
+            }
+            save_file(
+                shard, os.path.join(save_directory, shard_file), metadata=metadata
+            )
 
 
 # --------------------------- SCRIPT -------------------------
@@ -540,14 +580,32 @@ def recover_safetensors_from_dcp(
 
     # get the state_dict
     state_dict = loader(checkpoint_dir)
 
+    # filter out additional names created by lora tuning
+    # create switch based on state dict for future use
+    new_state_dict = {}
+    lora = False
+    for name, param in state_dict.items():
+        # if lora weight, set lora switch to true
+        if "lora_A" in name or "lora_B" in name:
+            lora = True
+        # if lora naming convention, convert to traditional
+        if "base_model.model." in name:
+            name = name.replace("base_model.model.", "", 1)
+        if "default." in name:
+            name = name.replace("default.", "", 1)
+        new_state_dict[name] = param
+
     # recover the original state dict
-    state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
+    state_dict = recover_original_state_dict_from_checkpoint(
+        new_state_dict, _name_or_path
+    )
 
     # save it as a safetensors file
     save_sharded_safetensors(
         {k: v.contiguous() for k, v in state_dict.items()},
         output_dir,
         metadata={"format": "pt"},
+        lora=lora,
     )
 
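The `recover_safetensors_from_dcp` hunk above also normalizes PEFT-style parameter names before recovery: the `base_model.model.` wrapper prefix and the `default.` adapter segment are each stripped once per key. A minimal sketch of that renaming, not part of the patch; the example key is illustrative only.

# illustrative sketch, not part of the patch
def normalize_lora_key(name: str) -> str:
    # strip the PEFT wrapper prefix and the adapter name inserted by lora tuning
    if "base_model.model." in name:
        name = name.replace("base_model.model.", "", 1)
    if "default." in name:
        name = name.replace("default.", "", 1)
    return name

print(normalize_lora_key("base_model.model.model.layers.0.mlp.w1.lora_A.default.weight"))
# -> model.layers.0.mlp.w1.lora_A.weight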
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
index e16943b1..44125acd 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
@@ -17,7 +17,6 @@
 
 # Third Party
 from peft import LoraConfig
-from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
 from torch.distributed._tensor import DTensor
 
 # pylint: disable=import-error
@@ -237,10 +236,6 @@ def __init__(
         assert (
             lora_config.bias == "none"
         ), "ScatterMoE currently unable to handle bias in the lora adapters"
-        assert (
-            lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND
-            or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules
-        ), "ScatterMoe currently only handles lora adapters on all linears."
 
         assert lora_config.init_lora_weights in {
             True,
@@ -286,7 +281,6 @@ def __init__(
             grouped_out=True,
             dtype=dtype,
             device=device,
-            lora_config=lora_config,
         )
         self.w2 = ScatteredExperts(
             in_features=self.intermediate_size,
@@ -296,7 +290,6 @@ def __init__(
             grouped_in=True,
             dtype=dtype,
             device=device,
-            lora_config=lora_config,
         )
         if mlp_arch == SCATTERMOE_SPEC_HAS_GATE:
             self.w3 = ScatteredExperts(
@@ -307,7 +300,6 @@ def __init__(
                 grouped_out=True,
                 dtype=dtype,
                 device=device,
-                lora_config=lora_config,
             )
 
     # referenced from dolomite-engine
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py
index 9adcc47b..c7162894 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py
@@ -92,7 +92,8 @@ def _hook(grad):
 
             # install gradient scaling hook
             if KEY_SCATTERMOE_ROUTER not in weight_name:
-                param.register_hook(_hook)
+                if param.requires_grad:
+                    param.register_hook(_hook)
 
             # register the sharded parameter onto the megablocks.dmoe
             mod.register_parameter(name, param)
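Finally, the `scattermoe_prepare.py` guard above matters because `Tensor.register_hook` raises a RuntimeError on tensors that do not require gradients, which is exactly the state of the frozen base weights during LoRA tuning. A minimal sketch of the guarded registration, not part of the patch; shapes and the scaling function are illustrative only.

# illustrative sketch, not part of the patch
import torch

def _hook(grad):
    return grad * 0.5  # stand-in for the gradient scaling applied by the plugin

frozen = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)  # e.g. a frozen base weight
trainable = torch.nn.Parameter(torch.randn(4, 4))

for param in (frozen, trainable):
    if param.requires_grad:
        param.register_hook(_hook)  # only trainable params get the hook

trainable.sum().backward()
print(trainable.grad[0, 0].item())  # reflects the 0.5 scaling from the hook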