@@ -4,6 +4,7 @@ array-record
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
drjax
flax
gcsfs
google-api-python-client
@@ -40,6 +40,7 @@ dill>=0.4.0
distlib>=0.4.0
dm-tree>=0.1.9
docstring-parser>=0.17.0
drjax>=0.1.4
editdistance>=0.8.1
einops>=0.8.1
einshape>=1.0
@@ -41,6 +41,7 @@ dill>=0.4.0
distlib>=0.4.0
dm-tree>=0.1.9
docstring-parser>=0.17.0
drjax>=0.1.4
editdistance>=0.8.1
einops>=0.8.1
einshape>=1.0
1 change: 1 addition & 0 deletions dependencies/requirements/requirements.txt
@@ -4,6 +4,7 @@ array-record
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
drjax>=0.1.4
flax
gcsfs
google-api-python-client
58 changes: 58 additions & 0 deletions src/MaxText/integration/vllm/maxtext_vllm_adapter/config.json
@@ -0,0 +1,58 @@
{
"architectures": [
"MaxTextForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_deepseek.DeepseekV3Config",
"AutoModel": "modeling_deepseek.DeepseekV3Model",
"AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
},
"bos_token_id": 0,
"eos_token_id": 1,
"ep_size": 1,
"first_k_dense_replace": 3,
"hidden_act": "silu",
"hidden_size": 7168,
"initializer_range": 0.02,
"intermediate_size": 18432,
"kv_lora_rank": 512,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"moe_intermediate_size": 2048,
"moe_layer_freq": 1,
"n_group": 8,
"n_routed_experts": 256,
"n_shared_experts": 1,
"norm_topk_prob": true,
"num_attention_heads": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 61,
"num_key_value_heads": 128,
"num_nextn_predict_layers": 1,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"rope_theta": 10000,
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"tie_word_embeddings": false,
"topk_group": 4,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"transformers_version": "4.33.1",
"use_cache": true,
"v_head_dim": 128,
"vocab_size": 129280
}
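The adapter ships a stock DeepSeek-V3 geometry. As a quick, illustrative sanity check (a sketch, not part of this PR; it assumes the file is read from the repository path shown in the header above), the MLA head dimensions and MoE routing fan-out can be read straight from the JSON:

import json

# Path taken from the file header above; adjust if the repository layout differs.
with open("src/MaxText/integration/vllm/maxtext_vllm_adapter/config.json") as f:
  cfg = json.load(f)

# MLA query/key head dim = non-positional part + RoPE part = 128 + 64 = 192.
assert cfg["qk_nope_head_dim"] + cfg["qk_rope_head_dim"] == 192
# Each token is routed to 8 of 256 routed experts, plus 1 shared expert.
assert cfg["num_experts_per_tok"] == 8 and cfg["n_routed_experts"] == 256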
8 changes: 7 additions & 1 deletion src/MaxText/sharding.py
@@ -36,7 +36,13 @@

def get_input_data_sharding(config, mesh):
"""Get the input data sharding for the model"""
return create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
if config.enable_diloco:
data_sharding = create_sharding(
mesh, ["diloco"] + config.input_data_sharding_logical_axes, rules=config.logical_axis_rules
)
else:
data_sharding = create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
return data_sharding


def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0):
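Prepending "diloco" to the input-data logical axes means the batch carries an extra leading dimension that is sharded across the new "diloco" mesh axis. Below is a minimal self-contained sketch of the resulting sharding (not MaxText code; the 2x4 device layout and axis names are illustrative assumptions, and it needs at least 8 JAX devices to run):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assume 8 devices arranged as 2 DiLoCo replicas x 4-way data parallelism.
devices = np.array(jax.devices()[:8]).reshape(2, 4)
mesh = Mesh(devices, axis_names=("diloco", "data"))

# Without DiLoCo the (global_batch, seq) input is sharded over "data" only.
plain = NamedSharding(mesh, P("data", None))

# With DiLoCo the batch arrives as (num_replicas, per_replica_batch, seq), and the extra
# leading logical axis maps to the "diloco" mesh axis.
with_diloco = NamedSharding(mesh, P("diloco", "data", None))

batch = np.zeros((2, 8, 16), dtype=np.int32)
sharded = jax.device_put(batch, with_diloco)  # each device holds a (1, 2, 16) slice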
35 changes: 28 additions & 7 deletions src/MaxText/train_compile.py
@@ -24,6 +24,7 @@
from typing import Sequence
import os
import pickle
import functools

from absl import app

@@ -45,6 +46,7 @@
from maxtext.utils import gcs_utils
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.trainers.diloco import diloco

# pylint: disable=too-many-positional-arguments

@@ -235,13 +237,32 @@ def main(argv: Sequence[str]) -> None:

# Get data sharding
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)

# Get function to compile and shardings
func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
maxtext_utils.get_functional_train_with_signature(
train.train_step, data_sharding, state_mesh_shardings, model, config
)
)
if config.enable_diloco:
# Build abstract DiLoCo state and shardings for AOT compilation
abstract_state = shaped_train_args[0]
diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state(
config, abstract_state, state_mesh_shardings, topology_mesh
)
shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2])

# Wrap train_step with diloco
train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None)
train_step_fn = diloco.build_diloco_train_step(config, train_step_partial)

# For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng)
func_to_compile = train_step_fn
func_to_compile.__name__ = "train_step"
in_shard = (state_mesh_shardings, data_sharding, None) # State, batch, rng
out_shard = (state_mesh_shardings, None) # State, metrics
static_argnums = ()
donate_argnums = 0
else:
# Get function to compile and shardings
func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
maxtext_utils.get_functional_train_with_signature(
train.train_step, data_sharding, state_mesh_shardings, model, config
)
)

# print weights sharding info under debug sharding mode
if config.debug_sharding:
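The DiLoCo branch hands jax.jit a pre-wrapped callable plus explicit (state, batch, rng) shardings and donates the state. Below is a generic, runnable sketch of how such a (func, in_shard, out_shard, static_argnums, donate_argnums) tuple is consumed for ahead-of-time compilation; this is not MaxText's compile path, and the toy train step, shapes, and single "data" mesh axis are assumptions:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())
data_sharding = NamedSharding(mesh, P("data"))

def train_step(state, batch, rng):
  # Stand-in for the wrapped DiLoCo train step: returns (new_state, metrics).
  del rng
  loss = jnp.mean(batch)
  return state - 0.1 * loss, {"loss": loss}

# Abstract inputs play the role of shaped_train_args in train_compile.py.
num_devices = len(jax.devices())
abstract_state = jax.ShapeDtypeStruct((), jnp.float32)
abstract_batch = jax.ShapeDtypeStruct((2 * num_devices, 16), jnp.float32)
abstract_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)

compiled = (
    jax.jit(
        train_step,
        in_shardings=(replicated, data_sharding, replicated),  # mirrors (state_mesh_shardings, data_sharding, None)
        out_shardings=(replicated, replicated),                 # mirrors (state_mesh_shardings, None)
        donate_argnums=0,                                       # donate the state, as in the DiLoCo branch
    )
    .lower(abstract_state, abstract_batch, abstract_rng)
    .compile()
)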
6 changes: 5 additions & 1 deletion src/maxtext/common/data_loader.py
@@ -25,6 +25,7 @@
maybe_record_goodput,
)
from maxtext.utils import exceptions
from maxtext.trainers.diloco import diloco


class DataLoader:
@@ -70,10 +71,13 @@ def load_next_batch_pre_sharding(self):

def load_next_batch(self, *args, **kwargs):
"""Loads the next batch with sharding hint"""
return jax.device_put(
example_batch = jax.device_put(
self.load_next_batch_pre_sharding(),
self.input_data_shardings,
)
if self.config.enable_diloco:
example_batch = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, example_batch)
return example_batch

def check_example_batch(self):
if self.config.max_checkify:
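The loader now splits the global batch across DiLoCo replicas before returning it. The PR does not show reshape_first_axis_with_diloco itself; the sketch below is a hypothetical re-implementation inferred from the call site and the helper's name (fold the leading batch axis into [num_diloco_replicas, per_replica_batch, ...] for every array in the batch pytree):

import jax
import jax.numpy as jnp

def reshape_first_axis_with_diloco(num_diloco_replicas, batch):
  """Hypothetical sketch: folds the global batch axis into (replicas, per-replica batch)."""
  def _reshape(x):
    assert x.shape[0] % num_diloco_replicas == 0, "global batch must divide evenly across replicas"
    return x.reshape(num_diloco_replicas, x.shape[0] // num_diloco_replicas, *x.shape[1:])
  return jax.tree_util.tree_map(_reshape, batch)

example_batch = {"inputs": jnp.zeros((16, 2048), jnp.int32), "targets": jnp.zeros((16, 2048), jnp.int32)}
reshaped = reshape_first_axis_with_diloco(4, example_batch)  # each leaf becomes (4, 4, 2048)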
11 changes: 10 additions & 1 deletion src/maxtext/configs/base.yml
@@ -400,7 +400,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'

# Parallelism
shard_mode: "auto" # can be either auto or explicit
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -483,6 +483,7 @@ logical_axis_rules: [
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['diloco', 'diloco'],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
@@ -495,6 +496,7 @@ sharding_tolerance: 0.02
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_diloco_parallelism: 1
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
@@ -507,6 +509,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
ici_diloco_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
@@ -738,6 +741,12 @@ enable_data_shuffling: True
data_shuffle_seed: 0
init_weights_seed: 0

# DiLoCo params.
enable_diloco: False
diloco_sync_period: 36
diloco_outer_lr: 0.3
diloco_outer_momentum: 0.9

# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0

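For context on the four new keys: DiLoCo runs diloco_sync_period ordinary inner optimizer steps per replica, then folds each replica's drift from the global parameters into an outer momentum-SGD update (Nesterov momentum in the original DiLoCo recipe) governed by diloco_outer_lr and diloco_outer_momentum. A minimal optax sketch of that outer step follows; it illustrates the standard DiLoCo formulation, not necessarily MaxText's exact implementation:

import jax
import jax.numpy as jnp
import optax

diloco_outer_lr, diloco_outer_momentum = 0.3, 0.9
outer_opt = optax.sgd(diloco_outer_lr, momentum=diloco_outer_momentum, nesterov=True)

def outer_sync(global_params, replica_params, outer_opt_state):
  """replica_params: pytree whose leaves carry a leading num_diloco_replicas axis."""
  # The "outer gradient" is the replica-averaged drift (global - local) accumulated over the
  # sync period; optax.sgd negates it, so the global params move toward the replica average.
  outer_grad = jax.tree_util.tree_map(lambda g, l: g - jnp.mean(l, axis=0), global_params, replica_params)
  updates, outer_opt_state = outer_opt.update(outer_grad, outer_opt_state, global_params)
  return optax.apply_updates(global_params, updates), outer_opt_state

# Example wiring: two replicas that drifted away from the global params by different amounts.
global_params = {"w": jnp.ones((3,))}
replica_params = {"w": jnp.stack([jnp.ones((3,)) * 0.8, jnp.ones((3,)) * 1.1])}
opt_state = outer_opt.init(global_params)
new_global, opt_state = outer_sync(global_params, replica_params, opt_state)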
24 changes: 24 additions & 0 deletions src/maxtext/configs/types.py
@@ -784,6 +784,7 @@ class LayoutAndSharding(BaseModel):
class DcnParallelism(BaseModel):
"""Parallelism dimensions across the DCN (Data Center Network)."""

dcn_diloco_parallelism: int = Field(1, description="DCN axis for Diloco parallelism.")
dcn_data_parallelism: int = Field(-1, description="DCN axis for data parallelism.")
dcn_fsdp_parallelism: int = Field(1, description="DCN axis for FSDP.")
dcn_fsdp_transpose_parallelism: int = Field(1, description="DCN axis for FSDP transpose.")
@@ -803,6 +804,7 @@ class DcnParallelism(BaseModel):
class IciParallelism(BaseModel):
"""Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""

ici_diloco_parallelism: int = Field(1, description="ICI axis for Diloco parallelism.")
ici_data_parallelism: int = Field(1, description="ICI axis for data parallelism.")
ici_fsdp_parallelism: int = Field(-1, description="ICI axis for FSDP.")
ici_fsdp_transpose_parallelism: int = Field(1, description="ICI axis for FSDP transpose.")
@@ -1082,6 +1084,15 @@ class ManifoldConstrainedHyperConnections(BaseModel):
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")


class DilocoParams(BaseModel):
"""Diloco Hyperparameters"""

enable_diloco: bool = Field(False, description="Enable Diloco parallelism")
diloco_sync_period: int = Field(36, description="Diloco sync period.")
diloco_outer_lr: float = Field(0.3, description="learning rate for outer optimizer.")
diloco_outer_momentum: float = Field(0.9, description="momentum for outer optimizer.")


class Optimizer(BaseModel):
"""Configuration for the optimizer and learning rate schedule."""

@@ -1632,6 +1643,11 @@ class DerivedValues(BaseModel):
description="Effective number of query heads, scaled by `global_parameter_scale`.",
)

num_diloco_replicas: None | int = Field(
None,
description="The number of diloco replicas, derived from ICI and DCN values.",
)

ici_parallelism: None | list[int] = Field(
None,
description="Aggregated list of all ICI parallelism values for legacy compatibility.",
@@ -1779,6 +1795,7 @@ class MaxTextConfig(
RematAndOffload,
TrainingLoop,
ManifoldConstrainedHyperConnections,
DilocoParams,
Optimizer,
AdamW,
Muon,
@@ -2375,6 +2392,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
self.ici_parallelism = [
self.ici_diloco_parallelism,
self.ici_pipeline_parallelism,
self.ici_data_parallelism,
self.ici_fsdp_parallelism,
@@ -2389,6 +2407,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.ici_autoregressive_parallelism,
]
self.dcn_parallelism = [
self.dcn_diloco_parallelism,
self.dcn_pipeline_parallelism,
self.dcn_data_parallelism,
self.dcn_fsdp_parallelism,
@@ -2404,6 +2423,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
]
else:
ici_map = {
"diloco": self.ici_diloco_parallelism,
"data": self.ici_data_parallelism,
"stage": self.ici_pipeline_parallelism,
"fsdp": self.ici_fsdp_parallelism,
@@ -2422,6 +2442,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]

dcn_map = {
"diloco": self.dcn_diloco_parallelism,
"data": self.dcn_data_parallelism,
"stage": self.dcn_pipeline_parallelism,
"fsdp": self.dcn_fsdp_parallelism,
@@ -2439,6 +2460,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
}
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]

# Diloco params
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)

# Final string-to-enum conversions if they haven't been coerced by pydantic yet.
if isinstance(self.decoder_block, str):
self.decoder_block = DecoderBlockType(self.decoder_block.lower())
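A small worked example of the new derived value (illustrative numbers, not from the PR): since "diloco" is a DCN-first mesh axis, a two-slice job with dcn_diloco_parallelism=2 and ici_diloco_parallelism=1 yields one DiLoCo replica per slice, while the remaining axes still have to multiply out to the per-slice device count.

import math

# Across-slice (DCN) and within-slice (ICI) factors for a hypothetical 2-slice job.
dcn = {"diloco": 2, "data": 1, "fsdp": 1}
ici = {"diloco": 1, "data": 1, "fsdp": 256}

num_diloco_replicas = int(ici["diloco"] * dcn["diloco"])  # -> 2, mirroring the derivation in types.py
assert math.prod(dcn.values()) == 2    # product of DCN axes == number of slices
assert math.prod(ici.values()) == 256  # product of ICI axes == devices per slice
assert num_diloco_replicas == 2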
13 changes: 13 additions & 0 deletions src/maxtext/trainers/diloco/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.