199 changes: 198 additions & 1 deletion docs/training.md
@@ -254,4 +254,201 @@ datasets:
remove_columns: all
fn_kwargs:
  conversation_column_name: "messages"
```

## Long Context Training

Long context training, for instance training on a 128k sequence length, can be performed using context parallelism.

### Model Architectures Supported

1. Hybrid attention dense models, e.g. `granite-4.0-h-1b`
1. Hybrid attention MoE models, e.g. `ibm-granite/granite-4.0-h-small`
1. SDPA attention dense models, e.g. `granite-4.0-1b`
1. SDPA attention MoE models, e.g. `ibm-research/moe-7b-1b-active-shared-experts`, Mixtral, etc.

### Parallelisms Supported with Context Parallel

1. Context Parallel + FSDP sharding
1. Context Parallel + FSDP sharding + Expert Parallel
1. Context Parallel + FSDP sharding + DP
1. Context Parallel + FSDP sharding + DP + Expert Parallel

### Usage

#### Enabling Context Parallel

FSDPv2 is required in order to use context parallel. FSDPv2 can be activated using the following accelerate config:

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
```

Context parallel can then be activated using the accelerate config below:

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
use_parallelism_config: "true" # required to turn on parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

When using any model with Mamba attention, it is required to set the `--mcp` flag to the context parallel degree. Further, hybrid models that use a combination of Mamba and SDPA attention should set both the `--mcp` flag and the `parallelism_config_cp_size` option to the same context parallel degree.
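
For illustration, a minimal launch sketch is shown below. The accelerate config file name (`cp_fsdp_config.yaml`), model name, data path, and output directory are placeholders; only the `--mcp` flag and the context parallel entries in the accelerate config are specific to this feature.

```
# Sketch only: cp_fsdp_config.yaml is assumed to contain the accelerate config
# shown above with parallelism_config_cp_size: 2; model and data paths are placeholders.
accelerate launch \
  --config_file cp_fsdp_config.yaml \
  tuning/sft_trainer.py \
  --model_name_or_path ibm-granite/granite-4.0-h-tiny \
  --training_data_path train.jsonl \
  --output_dir ./output \
  --mcp 2  # Mamba context parallel degree, matching parallelism_config_cp_size
```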

#### Enabling Context Parallel with Data Parallel

Context parallel can be combined with data parallel using the `parallelism_config_dp_shard_size` parameter.

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
use_parallelism_config: "true" # required to turn on parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
parallelism_config_dp_shard_size: 4 # data parallel degree (cp_size x dp_shard_size = num_processes)
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

Note that the context parallel degree multiplied by the data parallel degree must equal the total number of GPUs being used (in the example above, 2 × 4 = 8).

#### Enabling Mixed Precision

Mixed precision must be configured using the `fsdp_mixed_precision_policy` parameter only. Do not use direct flags such as `--bf16` or the `mixed_precision` accelerate config parameter.

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
use_parallelism_config: "true" # required to turn on parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
parallelism_config_dp_shard_size: 4 # data parallel degree (cp_size x dp_shard_size = num_processes)
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

#### Gradient Checkpointing

The recommended way to enable gradient checkpointing is via the accelerate config parameter `fsdp_activation_checkpointing`, as shown below:

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
  fsdp_activation_checkpointing: true
use_parallelism_config: "true" # required to turn on parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
parallelism_config_dp_shard_size: 4 # data parallel degree (cp_size x dp_shard_size = num_processes)
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

#### Enabling Context Parallel with Data Parallel and Expert Parallel

For MoE models, expert parallelism with MoE kernels can be enabled using the `--fast_moe` flag along with context and data parallelism. The expert parallel degree is independent of the context parallel degree, so it can be set as described [here](./tuning-techniques.md#fms-acceleration).
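
As an illustrative sketch, such a run could combine the flags as shown below. The accelerate config file name (`cp_dp_fsdp_config.yaml`), model name, data path, output directory, and the expert parallel degree passed to `--fast_moe` are placeholders to adapt to your setup.

```
# Sketch only: cp_dp_fsdp_config.yaml is assumed to be the CP + DP accelerate
# config shown earlier; model, data paths, and degrees are placeholders.
accelerate launch \
  --config_file cp_dp_fsdp_config.yaml \
  tuning/sft_trainer.py \
  --model_name_or_path ibm-granite/granite-4.0-h-small \
  --training_data_path train.jsonl \
  --output_dir ./output \
  --fast_moe 8 \
  --mcp 2
```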

### Recommendations

1. Keeping context parallelism within a node is usually optimal unless extremely long sequences (e.g. 256k) are needed. Given that, choose a context parallel degree that is a multiple of 2, from 2 up to 8.
2. The data parallel degree multiplied by the context parallel degree should equal the total number of GPUs being used.
3. The context parallel degree determines the number of chunks a sequence is split into and distributed across GPUs, so choose the smallest degree that accommodates the sequence length.

Further, the ablations below can be used as reference configurations. In the experiment settings, `cp`, `dp`, and `ep` denote the context, data, and expert parallel degrees, `s` the sequence length, `ebs` the effective batch size, and `gas` the number of gradient accumulation steps.

#### Ablations

##### Parity Experiments

| Model | Experiment setting | Loss | Tokens/sec per GPU |
| -------- | -------- | ------- | ------- |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1 | 0.8059140625 | 973.6 |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1-ep8 | 0.80224609375 | 2367.6 |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas2 | 0.8059765625 | NA |
| ibm-granite/granite-4.0-h-tiny | cp4-dp2-ebs4-s8192-gas1 | 0.802953125 | 953.4 |
| ibm-granite/granite-4.0-h-tiny | cp1-dp4-ep4-ebs4-s8192-gas1 | 0.7967056884765625 | 2576 |

##### Long Context (sequence length 131072, i.e. 128k)

| Model | Experiment setting | Tokens/sec per GPU | GPU memory utilization ratio |
| -------- | -------- | ------- | ------- |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs1-s131072-gas1-ep8 | 1462.8 | 0.5140136719 |
| ibm-granite/granite-4.0-h-small | cp8-ebs1-s131072-gas1-ep8 | 682.7 | 0.9887207031 |

### Known Limitations

1. Load balancing is removed due to limited support in the Mamba context parallel implementation. This could lead to throughput drops for training runs that use a causal mask.
2. Padding-free and flash attention are not supported.
4 changes: 4 additions & 0 deletions tests/test_sft_trainer.py
@@ -412,6 +412,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -438,6 +439,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
@@ -461,6 +463,7 @@ def test_parse_arguments_peft_method(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_pt)

assert isinstance(tune_config, peft_config.PromptTuningConfig)
@@ -480,6 +483,7 @@ def test_parse_arguments_peft_method(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_lora)
assert isinstance(tune_config, peft_config.LoraConfig)
assert not tune_config.target_modules
1 change: 1 addition & 0 deletions tuning/config/acceleration_configs/__init__.py
@@ -18,5 +18,6 @@
from .callbacks import get_additional_accel_framework_callbacks
from .fast_moe import FastMoeConfig
from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
from .mcp import MCP, MCPConfig
from .odm import ODM, ODMConfig
from .quantized_lora_config import QuantizedLoraConfig
tuning/config/acceleration_configs/acceleration_framework_config.py
@@ -24,6 +24,7 @@
from .attention_and_distributed_packing import MultiPack, PaddingFree
from .fast_moe import FastMoe
from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig
from .mcp import MCP
from .odm import ODM
from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig
from tuning.utils.import_utils import is_fms_accelerate_available
@@ -133,6 +134,17 @@ class AccelerationFrameworkConfig:
),
] = None

    mcp: Annotated[
        MCP,
        ConfigAnnotation(
            path="training.mamba",
            key="cp",
            standalone=True,
            experimental=True,
            required_packages=["mcp"],
        ),
    ] = None

multipack: Annotated[
MultiPack,
ConfigAnnotation(
37 changes: 37 additions & 0 deletions tuning/config/acceleration_configs/mcp.py
@@ -0,0 +1,37 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass

# Local
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@dataclass
class MCP:
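    """Mamba context parallel (MCP) settings; `degree` is the context parallel degree applied to Mamba layers."""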
    degree: int = None
    mamba_impl: str = None
    mamba_recompute: bool = None


@dataclass
class MCPConfig:

    mcp: MCP = None

    def __post_init__(self):
        # ensure nested dataclasses initialized
        ensure_nested_dataclasses_initialized(self)
12 changes: 12 additions & 0 deletions tuning/sft_trainer.py
@@ -47,6 +47,7 @@
AttentionAndDistributedPackingConfig,
FastMoeConfig,
FusedOpsAndKernelsConfig,
MCPConfig,
ODMConfig,
QuantizedLoraConfig,
get_additional_accel_framework_callbacks,
@@ -87,6 +88,7 @@ def train(
AttentionAndDistributedPackingConfig
] = None,
fast_moe_config: Optional[FastMoeConfig] = None,
mcp_config: Optional[MCPConfig] = None,
additional_data_handlers: Optional[Dict[str, DataHandler]] = None,
) -> tuple[SFTTrainer, dict]:
"""Call the SFTTrainer
@@ -198,6 +200,8 @@
)
if fast_moe_config is not None and fast_moe_config.fast_moe is None:
fast_moe_config = None
if mcp_config is not None and mcp_config.mcp is None:
mcp_config = None
if fast_moe_config is not None:
# If LoRA with ScatterMoE detected, raise warning
accepted_layers = ["all-linear"]
@@ -261,6 +265,7 @@
quantized_lora_config,
fusedops_kernels_config,
odm_config,
mcp_config,
).get_framework()

# option to set multimodal var here
@@ -594,6 +599,7 @@ def get_parser():
FusedOpsAndKernelsConfig,
AttentionAndDistributedPackingConfig,
FastMoeConfig,
MCPConfig,
TrackerConfigs,
)
)
@@ -675,6 +681,7 @@ def parse_arguments(parser, json_config=None):
fusedops_kernels_config,
attention_and_distributed_packing_config,
fast_moe_config,
mcp_config,
tracker_configs,
) = parser.parse_dict(json_config, allow_extra_keys=True)
peft_method = json_config.get("peft_method")
@@ -694,6 +701,7 @@ def parse_arguments(parser, json_config=None):
fusedops_kernels_config,
attention_and_distributed_packing_config,
fast_moe_config,
mcp_config,
tracker_configs,
additional,
_,
@@ -730,6 +738,7 @@ def parse_arguments(parser, json_config=None):
fusedops_kernels_config,
attention_and_distributed_packing_config,
fast_moe_config,
mcp_config,
tracker_configs,
exp_metadata,
)
@@ -752,6 +761,7 @@ def main():
fusedops_kernels_config,
attention_and_distributed_packing_config,
fast_moe_config,
mcp_config,
tracker_configs,
exp_metadata,
) = parse_arguments(parser, job_config)
@@ -773,6 +783,7 @@ def main():
"AADP (fms-acceleration) Config": attention_and_distributed_packing_config,
"Fused Ops Kernels Config": fusedops_kernels_config,
"Fast MoE Config": fast_moe_config,
"MCP Config": mcp_config,
"Tracker Config": tracker_configs,
"Extra Metadata": exp_metadata,
"Trainer Controller Config": trainer_controller_args,
@@ -816,6 +827,7 @@ def main():
quantized_lora_config=quantized_lora_config,
fusedops_kernels_config=fusedops_kernels_config,
attention_and_distributed_packing_config=attention_and_distributed_packing_config,
mcp_config=mcp_config,
fast_moe_config=fast_moe_config,
)
except (MemoryError, OutOfMemoryError) as e: