diff --git a/docs/training.md b/docs/training.md
index 40f3c90b9..03b0f3998 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -254,4 +254,201 @@ datasets:
           remove_columns: all
           fn_kwargs:
             conversation_column_name: "messages"
-```
\ No newline at end of file
+```
+
+## Long Context Training
+
+Long context training, for instance training on a 128k sequence length, can be performed using context parallelism.
+
+### Model Architectures Supported
+
+1. Hybrid attention dense models, e.g. granite-4.0-h-1b
+1. Hybrid attention MoE models, e.g. ibm-granite/granite-4.0-h-small
+1. SDPA attention dense models, e.g. granite-4.0-1b
+1. SDPA attention MoE models, e.g. ibm-research/moe-7b-1b-active-shared-experts, Mixtral, etc.
+
+### Parallelisms Supported with Context Parallel
+
+1. Context Parallel + FSDP sharding
+1. Context Parallel + FSDP sharding + Expert Parallel
+1. Context Parallel + FSDP sharding + DP
+1. Context Parallel + FSDP sharding + DP + Expert Parallel
+
+### Usage
+
+#### Enabling Context Parallel
+
+FSDPv2 is required in order to use context parallel. It can be activated using the following accelerate config:
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+```
+
+Context parallel can then be activated using the accelerate config below:
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+use_parallelism_config: "true" # required to turn on the parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+When using any model with mamba attention, it is required to set the `--mcp` flag with the context parallel degree. Further, hybrid models that use a combination of mamba and SDPA attention should set both `--mcp` and `parallelism_config_cp_size`, with both carrying the same context parallel degree.
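+
+A launch command then combines the accelerate config above with the usual `sft_trainer` arguments. The sketch below is illustrative only: the config file name (`cp_fsdp_config.yaml`), model, data path, output directory and sequence length are placeholders, and the flags should be adapted to the `sft_trainer` arguments used in your setup.
+
+```
+accelerate launch --config_file cp_fsdp_config.yaml \
+  tuning/sft_trainer.py \
+  --model_name_or_path ibm-granite/granite-4.0-h-tiny \
+  --training_data_path train.jsonl \
+  --output_dir ./output \
+  --max_seq_length 131072 \
+  --mcp 2
+```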
+
+#### Enabling Context Parallel with Data Parallel
+
+Context parallel can be combined with data parallel using the `parallelism_config_dp_shard_size` parameter.
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+use_parallelism_config: "true" # required to turn on the parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+parallelism_config_dp_shard_size: 4 # data parallel degree (cp degree x dp degree = num_processes)
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+Note that the context parallel degree multiplied by the data parallel degree must equal the total number of GPUs being used.
+
+#### Enabling Mixed Precision
+
+Mixed precision must be provided using the `fsdp_mixed_precision_policy` parameter only. Do not use direct flags like `--bf16` or the `mixed_precision` accelerate config parameter.
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
+use_parallelism_config: "true" # required to turn on the parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+parallelism_config_dp_shard_size: 4 # data parallel degree
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+#### Gradient Checkpointing
+
+The optimal way to enable gradient checkpointing is through the accelerate config parameter `fsdp_activation_checkpointing`, as shown below:
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
+  fsdp_activation_checkpointing: true
+use_parallelism_config: "true" # required to turn on the parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+parallelism_config_dp_shard_size: 4 # data parallel degree
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+#### Enabling Context Parallel with Data Parallel and Expert Parallel
+
+For MoE models, expert parallel with MoE kernels can be enabled using the `--fast_moe` flag alongside context and data parallelism. The expert parallel degree is independent of the context parallel degree, so `--fast_moe` can be used as described [here](./tuning-techniques.md#fms-acceleration).
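+
+As an illustrative sketch (the config file name, data path and degree values are placeholders, and the exact `--fast_moe` semantics are described in the tuning techniques document linked above), a combined CP + DP + EP launch for a hybrid MoE model could look like:
+
+```
+accelerate launch --config_file cp_dp_fsdp_config.yaml \
+  tuning/sft_trainer.py \
+  --model_name_or_path ibm-granite/granite-4.0-h-tiny \
+  --training_data_path train.jsonl \
+  --output_dir ./output \
+  --mcp 2 \
+  --fast_moe 8
+```
+
+Here the `--mcp` value matches `parallelism_config_cp_size` in the accelerate config (required for hybrid mamba models), while the value passed to `--fast_moe` sets the expert parallel degree.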
+
+### Recommendations
+
+1. Keeping context parallelism within a node is usually optimal, unless there is a need for extremely long sequences such as 256k. Given that, it is best to choose a context parallel degree in multiples of 2, starting from 2 and going up to 8.
+2. The data parallel degree multiplied by the context parallel degree should equal the total number of GPUs being used.
+3. The context parallel degree determines the number of chunks a sequence is divided into and distributed across GPUs, so it should be chosen as the minimum needed to accommodate the sequence length. For example, a context parallel degree of 8 splits a 131072-token (128k) sequence into 16384-token chunks, one per GPU.
+
+The ablations below can be used as reference configurations.
+
+#### Ablations
+
+##### Parity Experiments
+
+| model | experiment setting | loss | tps per GPU |
+| -------- | -------- | ------- | ------- |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1 | 0.8059140625 | 973.6 |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1-ep8 | 0.80224609375 | 2367.6 |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas2 | 0.8059765625 | NA |
+| ibm-granite/granite-4.0-h-tiny | cp4-dp2-ebs4-s8192-gas1 | 0.802953125 | 953.4 |
+| ibm-granite/granite-4.0-h-tiny | cp1-dp4-ep4-ebs4-s8192-gas1 | 0.7967056884765625 | 2576 |
+
+##### Long Context (sequence length 131072, i.e. 128k)
+
+| model | experiment setting | tps per GPU | GPU memory util ratio |
+| -------- | -------- | ------- | ------- |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs1-s131072-gas1-ep8 | 1462.8 | 0.5140136719 |
+| ibm-granite/granite-4.0-h-small | cp8-ebs1-s131072-gas1-ep8 | 682.7 | 0.9887207031 |
+
+### Known Limitations
+
+1. Load balancing is removed, given the limited support in the mamba context parallel implementation. This could lead to throughput drops for trainings that use a causal mask.
+2. Padding free and flash attention are not supported.
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 598152317..38c996158 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -412,6 +412,7 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -438,6 +439,7 @@ def test_parse_arguments_defaults(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_defaults)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert model_args.use_flash_attn is False
@@ -461,6 +463,7 @@ def test_parse_arguments_peft_method(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_pt)
     assert isinstance(tune_config, peft_config.PromptTuningConfig)

@@ -480,6 +483,7 @@ def test_parse_arguments_peft_method(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_lora)
     assert isinstance(tune_config, peft_config.LoraConfig)
     assert not tune_config.target_modules
diff --git a/tuning/config/acceleration_configs/__init__.py b/tuning/config/acceleration_configs/__init__.py
index 3adf2171c..14d18f173 100644
--- a/tuning/config/acceleration_configs/__init__.py
+++ b/tuning/config/acceleration_configs/__init__.py
@@ -18,5 +18,6 @@
 from .callbacks import get_additional_accel_framework_callbacks
 from .fast_moe import FastMoeConfig
 from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
+from .mcp import MCP, MCPConfig
 from .odm import ODM, ODMConfig
 from .quantized_lora_config import QuantizedLoraConfig
diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py
index 09309cbfa..482570bf9 100644
--- a/tuning/config/acceleration_configs/acceleration_framework_config.py
+++ b/tuning/config/acceleration_configs/acceleration_framework_config.py
@@ -24,6 +24,7 @@
 from .attention_and_distributed_packing import MultiPack, PaddingFree
 from .fast_moe import FastMoe
 from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig
+from .mcp import MCP
 from .odm import ODM
 from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig
 from tuning.utils.import_utils import is_fms_accelerate_available
@@ -133,6 +134,17 @@ class AccelerationFrameworkConfig:
         ),
     ] = None

+    mcp: Annotated[
+        MCP,
+        ConfigAnnotation(
+            path="training.mamba",
+            key="cp",
+            standalone=True,
+            experimental=True,
+            required_packages=["mcp"],
+        ),
+    ] = None
+
     multipack: Annotated[
         MultiPack,
         ConfigAnnotation(
diff --git a/tuning/config/acceleration_configs/mcp.py b/tuning/config/acceleration_configs/mcp.py
new file mode 100644
index 000000000..bcdc182a7
--- /dev/null
+++ b/tuning/config/acceleration_configs/mcp.py
@@ -0,0 +1,37 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from dataclasses import dataclass
+
+# Local
+from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
+
+
+@parsable_dataclass
+@dataclass
+class MCP:
+    degree: int = None  # context parallel degree applied to mamba layers
+    mamba_impl: str = None
+    mamba_recompute: bool = None
+
+
+@dataclass
+class MCPConfig:
+
+    mcp: MCP = None
+
+    def __post_init__(self):
+        # ensure nested dataclasses initialized
+        ensure_nested_dataclasses_initialized(self)
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index dc6a74174..6beb90791 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -47,6 +47,7 @@
     AttentionAndDistributedPackingConfig,
     FastMoeConfig,
     FusedOpsAndKernelsConfig,
+    MCPConfig,
     ODMConfig,
     QuantizedLoraConfig,
     get_additional_accel_framework_callbacks,
@@ -87,6 +88,7 @@ def train(
         AttentionAndDistributedPackingConfig
     ] = None,
     fast_moe_config: Optional[FastMoeConfig] = None,
+    mcp_config: Optional[MCPConfig] = None,
     additional_data_handlers: Optional[Dict[str, DataHandler]] = None,
 ) -> tuple[SFTTrainer, dict]:
     """Call the SFTTrainer
@@ -198,6 +200,8 @@ def train(
         )
     if fast_moe_config is not None and fast_moe_config.fast_moe is None:
         fast_moe_config = None
+    if mcp_config is not None and mcp_config.mcp is None:
+        mcp_config = None
     if fast_moe_config is not None:
         # If LoRA with ScatterMoE detected, raise warning
         accepted_layers = ["all-linear"]
@@ -261,6 +265,7 @@ def train(
         quantized_lora_config,
         fusedops_kernels_config,
         odm_config,
+        mcp_config,
     ).get_framework()

     # option to set multimodal var here
@@ -594,6 +599,7 @@ def get_parser():
            FusedOpsAndKernelsConfig,
            AttentionAndDistributedPackingConfig,
            FastMoeConfig,
+            MCPConfig,
            TrackerConfigs,
        )
    )
@@ -675,6 +681,7 @@ def parse_arguments(parser, json_config=None):
            fusedops_kernels_config,
            attention_and_distributed_packing_config,
            fast_moe_config,
+            mcp_config,
            tracker_configs,
        ) = parser.parse_dict(json_config, allow_extra_keys=True)
        peft_method = json_config.get("peft_method")
@@ -694,6 +701,7 @@ def parse_arguments(parser, json_config=None):
            fusedops_kernels_config,
            attention_and_distributed_packing_config,
            fast_moe_config,
+            mcp_config,
            tracker_configs,
            additional,
            _,
@@ -730,6 +738,7 @@ def parse_arguments(parser, json_config=None):
        fusedops_kernels_config,
        attention_and_distributed_packing_config,
        fast_moe_config,
+        mcp_config,
        tracker_configs,
        exp_metadata,
    )
@@ -752,6 +761,7 @@ def main():
        fusedops_kernels_config,
        attention_and_distributed_packing_config,
        fast_moe_config,
+        mcp_config,
        tracker_configs,
        exp_metadata,
    ) = parse_arguments(parser, job_config)
@@ -773,6 +783,7 @@ def main():
        "AADP (fms-acceleration) Config": attention_and_distributed_packing_config,
        "Fused Ops Kernels Config": fusedops_kernels_config,
        "Fast MoE Config": fast_moe_config,
+        "MCP Config": mcp_config,
        "Tracker Config": tracker_configs,
        "Extra Metadata": exp_metadata,
        "Trainer Controller Config": trainer_controller_args,
@@ -816,6 +827,7 @@ def main():
            quantized_lora_config=quantized_lora_config,
            fusedops_kernels_config=fusedops_kernels_config,
            attention_and_distributed_packing_config=attention_and_distributed_packing_config,
+            mcp_config=mcp_config,
            fast_moe_config=fast_moe_config,
        )
    except (MemoryError, OutOfMemoryError) as e: