From ef4936c774da7055183488278303876969d3164b Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 13 Nov 2025 23:43:38 +0530 Subject: [PATCH 1/7] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- .../config/acceleration_configs/__init__.py | 1 + .../acceleration_framework_config.py | 12 ++++++ tuning/config/acceleration_configs/mcp.py | 39 +++++++++++++++++++ tuning/sft_trainer.py | 12 ++++++ 4 files changed, 64 insertions(+) create mode 100644 tuning/config/acceleration_configs/mcp.py diff --git a/tuning/config/acceleration_configs/__init__.py b/tuning/config/acceleration_configs/__init__.py index 3adf2171c..14d18f173 100644 --- a/tuning/config/acceleration_configs/__init__.py +++ b/tuning/config/acceleration_configs/__init__.py @@ -18,5 +18,6 @@ from .callbacks import get_additional_accel_framework_callbacks from .fast_moe import FastMoeConfig from .fused_ops_and_kernels import FusedOpsAndKernelsConfig +from .mcp import MCP, MCPConfig from .odm import ODM, ODMConfig from .quantized_lora_config import QuantizedLoraConfig diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index 09309cbfa..482570bf9 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -24,6 +24,7 @@ from .attention_and_distributed_packing import MultiPack, PaddingFree from .fast_moe import FastMoe from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig +from .mcp import MCP from .odm import ODM from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig from tuning.utils.import_utils import is_fms_accelerate_available @@ -133,6 +134,17 @@ class AccelerationFrameworkConfig: ), ] = None + mcp: Annotated[ + MCP, + ConfigAnnotation( + path="training.mamba", + key="cp", + standalone=True, + experimental=True, + required_packages=["mcp"], + ), + ] = None + multipack: Annotated[ MultiPack, ConfigAnnotation( diff --git a/tuning/config/acceleration_configs/mcp.py b/tuning/config/acceleration_configs/mcp.py new file mode 100644 index 000000000..558a72809 --- /dev/null +++ b/tuning/config/acceleration_configs/mcp.py @@ -0,0 +1,39 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard +from dataclasses import dataclass +from typing import Union + +# Local +from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass + + +@parsable_dataclass +@dataclass +class MCP: + degree: int = None + mamba_impl: str = None + attn_impl: str = None + mamba_recompute: bool = None + + +@dataclass +class MCPConfig: + + cp: MCP = None + + def __post_init__(self): + # ensure nested dataclasses initialized + ensure_nested_dataclasses_initialized(self) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index dc6a74174..776a08541 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -47,6 +47,7 @@ AttentionAndDistributedPackingConfig, FastMoeConfig, FusedOpsAndKernelsConfig, + MCPConfig, ODMConfig, QuantizedLoraConfig, get_additional_accel_framework_callbacks, @@ -87,6 +88,7 @@ def train( AttentionAndDistributedPackingConfig ] = None, fast_moe_config: Optional[FastMoeConfig] = None, + mcp_config: Optional[MCPConfig] = None, additional_data_handlers: Optional[Dict[str, DataHandler]] = None, ) -> tuple[SFTTrainer, dict]: """Call the SFTTrainer @@ -198,6 +200,8 @@ def train( ) if fast_moe_config is not None and fast_moe_config.fast_moe is None: fast_moe_config = None + if mcp_config is not None and mcp_config.cp is None: + mcp_config = None if fast_moe_config is not None: # If LoRA with ScatterMoE detected, raise warning accepted_layers = ["all-linear"] @@ -261,6 +265,7 @@ def train( quantized_lora_config, fusedops_kernels_config, odm_config, + mcp_config, ).get_framework() # option to set multimodal var here @@ -594,6 +599,7 @@ def get_parser(): FusedOpsAndKernelsConfig, AttentionAndDistributedPackingConfig, FastMoeConfig, + MCPConfig, TrackerConfigs, ) ) @@ -675,6 +681,7 @@ def parse_arguments(parser, json_config=None): fusedops_kernels_config, attention_and_distributed_packing_config, fast_moe_config, + mcp_config, tracker_configs, ) = parser.parse_dict(json_config, allow_extra_keys=True) peft_method = json_config.get("peft_method") @@ -694,6 +701,7 @@ def parse_arguments(parser, json_config=None): fusedops_kernels_config, attention_and_distributed_packing_config, fast_moe_config, + mcp_config, tracker_configs, additional, _, @@ -730,6 +738,7 @@ def parse_arguments(parser, json_config=None): fusedops_kernels_config, attention_and_distributed_packing_config, fast_moe_config, + mcp_config, tracker_configs, exp_metadata, ) @@ -752,6 +761,7 @@ def main(): fusedops_kernels_config, attention_and_distributed_packing_config, fast_moe_config, + mcp_config, tracker_configs, exp_metadata, ) = parse_arguments(parser, job_config) @@ -773,6 +783,7 @@ def main(): "AADP (fms-acceleration) Config": attention_and_distributed_packing_config, "Fused Ops Kernels Config": fusedops_kernels_config, "Fast MoE Config": fast_moe_config, + "MCP Config": mcp_config, "Tracker Config": tracker_configs, "Extra Metadata": exp_metadata, "Trainer Controller Config": trainer_controller_args, @@ -816,6 +827,7 @@ def main(): quantized_lora_config=quantized_lora_config, fusedops_kernels_config=fusedops_kernels_config, attention_and_distributed_packing_config=attention_and_distributed_packing_config, + mcp_config=mcp_config, fast_moe_config=fast_moe_config, ) except (MemoryError, OutOfMemoryError) as e: From 17e01b29e3eabe9a9995d66568c9e57fd169d484 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:27:36 +0530 Subject: [PATCH 2/7] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- tuning/config/acceleration_configs/mcp.py | 2 +- 
tuning/sft_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tuning/config/acceleration_configs/mcp.py b/tuning/config/acceleration_configs/mcp.py index 558a72809..0e2ec5f2a 100644 --- a/tuning/config/acceleration_configs/mcp.py +++ b/tuning/config/acceleration_configs/mcp.py @@ -32,7 +32,7 @@ class MCP: @dataclass class MCPConfig: - cp: MCP = None + mcp: MCP = None def __post_init__(self): # ensure nested dataclasses initialized diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 776a08541..6beb90791 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -200,7 +200,7 @@ def train( ) if fast_moe_config is not None and fast_moe_config.fast_moe is None: fast_moe_config = None - if mcp_config is not None and mcp_config.cp is None: + if mcp_config is not None and mcp_config.mcp is None: mcp_config = None if fast_moe_config is not None: # If LoRA with ScatterMoE detected, raise warning From a66053e6fe1334075b4ec99685a198184dc5337f Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:55:02 +0530 Subject: [PATCH 3/7] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- tuning/config/acceleration_configs/mcp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tuning/config/acceleration_configs/mcp.py b/tuning/config/acceleration_configs/mcp.py index 0e2ec5f2a..7b6ee8172 100644 --- a/tuning/config/acceleration_configs/mcp.py +++ b/tuning/config/acceleration_configs/mcp.py @@ -25,7 +25,6 @@ class MCP: degree: int = None mamba_impl: str = None - attn_impl: str = None mamba_recompute: bool = None From 2b9c679bc49255d5d9fdba8dd04e270288f76a0d Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Mon, 24 Nov 2025 12:47:58 +0530 Subject: [PATCH 4/7] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- tuning/config/acceleration_configs/mcp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tuning/config/acceleration_configs/mcp.py b/tuning/config/acceleration_configs/mcp.py index 7b6ee8172..bcdc182a7 100644 --- a/tuning/config/acceleration_configs/mcp.py +++ b/tuning/config/acceleration_configs/mcp.py @@ -14,7 +14,6 @@ # Standard from dataclasses import dataclass -from typing import Union # Local from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass From 75d1cf1e445e61a1293b1a4096bc73124cf3e866 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Mon, 24 Nov 2025 13:28:14 +0530 Subject: [PATCH 5/7] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- tests/test_sft_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 598152317..38c996158 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -412,6 +412,7 @@ def test_parse_arguments(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -438,6 +439,7 @@ def test_parse_arguments_defaults(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_defaults) assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn is False @@ -461,6 +463,7 @@ def test_parse_arguments_peft_method(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_pt) assert isinstance(tune_config, peft_config.PromptTuningConfig) @@ -480,6 +483,7 @@ def test_parse_arguments_peft_method(job_config): _, _, _, + _, ) = 
sft_trainer.parse_arguments(parser, job_config_lora)
     assert isinstance(tune_config, peft_config.LoraConfig)
     assert not tune_config.target_modules

From c9fcdfedd06976802be8078e055d40ebac5080e2 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Sat, 29 Nov 2025 23:07:54 +0530
Subject: [PATCH 6/7] docs: add documentation

Signed-off-by: Mehant Kammakomati
---
 docs/training.md | 178 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 177 insertions(+), 1 deletion(-)

diff --git a/docs/training.md b/docs/training.md
index 40f3c90b9..539acfdf7 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -254,4 +254,180 @@ datasets:
           remove_columns: all
           fn_kwargs:
             conversation_column_name: "messages"
-```
\ No newline at end of file
+```
+
+## Long Context Training
+
+Long-context training, for instance training on a 128k sequence length, can be performed using context parallelism.
+
+### Model Architectures Supported
+
+1. Hybrid attention dense models, e.g. granite-4.0-h-1b
+1. Hybrid attention MoE models, e.g. ibm-granite/granite-4.0-h-small
+1. SDPA attention dense models, e.g. granite-4.0-1b
+1. SDPA attention MoE models, e.g. ibm-research/moe-7b-1b-active-shared-experts, Mixtral, etc.
+
+### Parallelisms Supported with Context Parallel
+
+1. Context Parallel + FSDP sharding
+1. Context Parallel + FSDP sharding + Expert Parallel
+1. Context Parallel + FSDP sharding + DP
+1. Context Parallel + FSDP sharding + DP + Expert Parallel
+
+### Usage
+
+#### Enabling Context Parallel
+
+FSDPv2 is required in order to use context parallelism. FSDPv2 can be activated using the following accelerate config:
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+```
+
+Context parallelism is then activated using the accelerate config below:
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+use_parallelism_config: "true" # required to turn on parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+When using any model with mamba attention, it is required to set the `--mcp` flag to the context parallel degree. Furthermore, hybrid models that combine mamba and SDPA attention should set both `--mcp` and `parallelism_config_cp_size`, using the same context parallel degree for each.
+
+#### Enabling Context Parallel with Data Parallel
+
+Context parallel can be combined with data parallel using the `parallelism_config_dp_shard_size` parameter, as shown in the config below.
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+use_parallelism_config: "true" # required to turn on parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+parallelism_config_dp_shard_size: 8 # data parallel degree
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+Note that the context parallel degree multiplied by the data parallel degree should equal the total number of GPUs being used.
+
+#### Enabling Mixed Precision
+
+Mixed precision must be provided using the `fsdp_mixed_precision_policy` parameter only. Do not use direct flags such as `--bf16` or the `mixed_precision` accelerate config parameter.
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
+use_parallelism_config: "true" # required to turn on parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+parallelism_config_dp_shard_size: 8 # data parallel degree
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+#### Gradient Checkpointing
+
+The optimal way to enable gradient checkpointing is via the accelerate config parameter `fsdp_activation_checkpointing`, as shown below:
+
+```
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: "2" # turn on v2 of FSDP
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
+  fsdp_activation_checkpointing: true
+use_parallelism_config: "true" # required to turn on parallelism feature
+parallelism_config_cp_size: 2 # context parallel degree
+parallelism_config_dp_shard_size: 8 # data parallel degree
+machine_rank: 0
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+```
+
+#### Enabling Context Parallel with Data Parallel and Expert Parallel
+
+For MoE models, expert parallel with MoE kernels can be enabled using the `--fast_moe` flag along with context and data parallelism. The expert parallel degree is independent of the context parallel degree, so it can be used as described [here](./tuning-techniques.md#fms-acceleration).
+
+### Recommendations
+
+1. Keeping context parallelism within a node is usually optimal unless there is a need for extremely long sequences such as 256k. Given that, it is best to choose a context parallel degree that is a multiple of 2, from 2 up to 8.
+2. The data parallel degree multiplied by the context parallel degree should equal the total number of GPUs being used.
+3. The context parallel degree determines the number of chunks a sequence is divided into and distributed across GPUs; therefore, choose the smallest degree needed to accommodate the target sequence length.
+
+### Known Limitations
+
+1. Load balancing is removed due to limited support in the mamba context parallel implementation. This could lead to throughput drops for training runs that use a causal mask.
+2. Padding-free and flash attention are not supported.

From 80d25b9374ab237891db15b2fdbf174275bc8131 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Sat, 29 Nov 2025 23:15:26 +0530
Subject: [PATCH 7/7] docs: add documentation

Signed-off-by: Mehant Kammakomati
---
 docs/training.md | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/docs/training.md b/docs/training.md
index 539acfdf7..03b0f3998 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -427,6 +427,27 @@ For MoE models, expert parallel with MoE kernels can be enabled using the `--fas
 2. The data parallel degree multiplied by the context parallel degree should equal the total number of GPUs being used.
 3. The context parallel degree determines the number of chunks a sequence is divided into and distributed across GPUs; therefore, choose the smallest degree needed to accommodate the target sequence length.
 
+Further, the ablations below can be used as reference configurations.
+
+#### Ablations
+
+##### Parity Experiments
+
+| model | experiment setting | loss | tps per gpu |
+| -------- | -------- | ------- | ------- |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1 | 0.8059140625 | 973.6 |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1-ep8 | 0.80224609375 | 2367.6 |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas2 | 0.8059765625 | NA |
+| ibm-granite/granite-4.0-h-tiny | cp4-dp2-ebs4-s8192-gas1 | 0.802953125 | 953.4 |
+| ibm-granite/granite-4.0-h-tiny | cp1-dp4-ep4-ebs4-s8192-gas1 | 0.7967056884765625 | 2576 |
+
+##### Long Context (sequence length of 131072 (128k))
+
+| model | experiment setting | tps per gpu | GPU memory util ratio |
+| -------- | -------- | ------- | ------- |
+| ibm-granite/granite-4.0-h-tiny | cp8-ebs1-s131072-gas1-ep8 | 1462.8 | 0.5140136719 |
+| ibm-granite/granite-4.0-h-small | cp8-ebs1-s131072-gas1-ep8 | 682.7 | 0.9887207031 |
+
 ### Known Limitations
 
 1. Load balancing is removed due to limited support in the mamba context parallel implementation. This could lead to throughput drops for training runs that use a causal mask.
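
For reference, below is a minimal sketch (not part of the patches above) of how the `MCP`/`MCPConfig` dataclasses introduced in this series could be constructed programmatically and passed to `tuning.sft_trainer.train()`. The field value is an illustrative assumption: `degree` should match `parallelism_config_cp_size` from the accelerate config, and the remaining `train()` arguments are elided.

```python
# Illustrative sketch only: construct the new mamba context-parallel config
# added by this patch series and hand it to sft_trainer.train().
from tuning.config.acceleration_configs import MCP, MCPConfig

# Context parallel degree for mamba layers; keep it equal to
# parallelism_config_cp_size in the accelerate config (2 in the docs above).
mcp_config = MCPConfig(mcp=MCP(degree=2))

# The other MCP fields from the patch (mamba_impl, mamba_recompute) are left
# at their defaults here, since their accepted values are not documented above.

# The config is forwarded into the acceleration framework by train(), alongside
# the usual model/data/training arguments (omitted here):
#   sft_trainer.train(..., mcp_config=mcp_config)
```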