From 96721687cb74f4091b6c1ea754051bbd882be4b3 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 7 Jan 2026 20:41:53 +0000 Subject: [PATCH 1/7] add all gather insertion per repeat --- src/MaxText/layers/decoders.py | 9 +- src/MaxText/layers/pipeline.py | 339 +++++++++++++++++++++++---------- src/maxtext/configs/base.yml | 1 + src/maxtext/configs/types.py | 1 + 4 files changed, 248 insertions(+), 102 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index d82bc065ca..b27cef3d21 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -741,12 +741,9 @@ def __call__( model_mode, ) if cfg.using_pipeline_parallelism: - if cfg.pipeline_fsdp_ag_once: - logical_partition_spec = self.pipeline_module.get_weight_sharding( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode - ) - else: - logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. + logical_partition_spec = self.pipeline_module.get_weight_sharding( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." dense_layer = RemattedBlockLayers[0] diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index 8e12df3bea..094fd9625b 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -23,6 +23,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import jax import jax.ad_checkpoint +from jax._src.lax.parallel import all_gather_invariant from flax.core import meta from flax import linen as nn @@ -154,7 +155,9 @@ def init_states(self, inputs): state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None - loop_iteration: scalar set initially to 0. + loop_iteration: scalar set initially to 0 + bsw: pytree of identical structure as weights with leaf arrays leading dimension of num_repeats replaced by 2, e.g. + a leaf of shape [num_repeats, stages, mlp, embed] is mapped to [2, num_stages, mlp, embed]. 
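# A minimal, self-contained sketch of the two-slot "bsw" buffer described in the
# docstring above, assuming circular weights whose leaves are shaped
# [num_repeats, num_stages, ...]. The example pytree and names below are
# illustrative placeholders, not the PR's actual parameter structure.
import jax
import jax.numpy as jnp

def init_bsw(circular_weights):
  # Keep buffer space for exactly two repeats' worth of per-stage weights:
  # one slot for the repeat currently executing and one for the next.
  return jax.tree.map(lambda x: jnp.zeros_like(x[:2]), circular_weights)

example = {"mlp": {"wo": jnp.ones((4, 2, 8, 16))}}  # [repeats, stages, mlp, embed]
bsw = init_bsw(example)
assert bsw["mlp"]["wo"].shape == (2, 2, 8, 16)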
""" # Shift is used to rotate the output of each pipeline into the input of the next @@ -199,6 +202,15 @@ def init_states(self, inputs): else: circ_storage_mover = None + def _init_bsw_from_weights(variables): + """Buffer space for two copies of weights.""" + return jax.tree.map(lambda x: jnp.zeros_like(x[:2]), variables) + + if self.is_initializing(): + bsw = None + else: + bsw = _init_bsw_from_weights(self.layers.variables) + init_loop_state = { "state_io": state_io, "shift": shift, @@ -206,6 +218,7 @@ def init_states(self, inputs): "circ_storage_mover": circ_storage_mover, "loop_iteration": 0, "prev_outputs": prev_outputs, + "bsw": bsw, } return init_loop_state @@ -256,30 +269,6 @@ def select_state_or_input(first_stage_in, shift): stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical) return stages_in - def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False): - """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at - the specified dimension.""" - placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED - if physical_partition_spec is None: - dims_mapping = [placeholder] * x.ndim - else: - physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec) - dims_mapping = list(physical_partition_spec) - # If not a stage weight, we handle the repeat dimension offset - if not is_stage_weight: - dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:] # inflat one dimension for num_repeats - dims_mapping[dim] = "stage" - dims_mapping = tuple(dims_mapping) - # We add reduced rule only when pspec is given for a stage weight - if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT: - batch_mesh_axis = ["data", "fsdp"] - reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1] - pspec = P(*dims_mapping, reduced=set(reduced_mark)) - else: - pspec = P(*dims_mapping) - sharding = jax.sharding.NamedSharding(self.mesh, pspec) - return self._maybe_shard_with_name(x, sharding) - def get_microbatch_and_repeat_ids(self, loop_iteration): """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. 
Works for both circular and non-circular""" @@ -311,18 +300,9 @@ def _gather_one(x, repeat_id): return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) gathered_weights_stage_dim = 0 - repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) - # num_repeats x num_stages x *param_dim - weights = self.shard_dim_by_stages( - weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False - ) stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( weights, repeat_ids ) - # num_stages x *param_dim - stage_weights = self.shard_dim_by_stages( - stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True - ) return stage_weights def vmap_gather(self, xs, ids, ids_dim): @@ -346,9 +326,8 @@ def _gather_one(x, i): replicated_sharding = NamedSharding(self.mesh, P()) return x.at[idx].get(out_sharding=replicated_sharding) - ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) - return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None) + return outs def get_new_loop_state(self, output, loop_state): """ @@ -452,6 +431,7 @@ def _shift_left(arr, stage_size, output): mesh=self.mesh, in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), out_specs=self.state_io_spec, + check_vma=True, ) def _update_state_io(state_in, stream_slice, output, stream_buf_idx): # Shift the current slice to the left, then fill the last stage with the final output. @@ -469,6 +449,7 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx): "circ_storage_mover": new_circ_storage_mover, "loop_iteration": loop_iteration + 1, "prev_outputs": new_prev_outputs, + "bsw": loop_state["bsw"], # bsw is updated outside of this inner loop, only once per outer loop iteration } return new_loop_state @@ -483,7 +464,7 @@ def permute_output_micro_per_stage_dim(self, output): output = output[:, permutation] return output - def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None): + def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physical_partition_spec=None): """ Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g. {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. @@ -491,15 +472,19 @@ def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_p for circular pipelines each stage grabs only the weights corresponding to the current repeat. 
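# A rough standalone illustration of the per-stage scheduling arithmetic that the
# repeat selection above relies on, mirroring get_microbatch_and_repeat_ids. The
# free function and argument names here are illustrative, not the module's API.
import jax.numpy as jnp

def microbatch_and_repeat_ids(loop_iteration, num_stages, num_microbatches, forwarding_delay=1):
  # Stage s starts forwarding_delay iterations after stage s-1; clamp at 0
  # while the pipeline is still filling.
  processed = jnp.maximum(loop_iteration - forwarding_delay * jnp.arange(num_stages), 0)
  return processed % num_microbatches, processed // num_microbatches

# e.g. 4 stages, 4 microbatches: at iteration 5, stages 0-1 are already on
# repeat 1 while stages 2-3 are still finishing repeat 0.
mb_ids, rep_ids = microbatch_and_repeat_ids(5, num_stages=4, num_microbatches=4)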
""" if self.config.num_pipeline_repeats > 1: - return self.get_current_repeat_from_stages( - pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec - ) + return self.get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec=physical_partition_spec) else: return pipeline_weights - def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): - """get current repeat from stages""" - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec=None): + """Collect and gather weights from given bsw (buffer sliding window)""" + + def _get_bsw_idx(loop_iteration): + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + bsw_ids = (repeat_ids == repeat_ids[0]).astype( + jnp.int32 + ) # For early repeats this might return true when it should be false + return bsw_ids circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", @@ -509,7 +494,7 @@ def get_current_repeat_from_stages(self, weights, loop_iteration, physical_parti "optimizer_dims_mapping": None, } weights = meta.remove_axis( - weights, 0, circular_metadata_params + bsw, 0, circular_metadata_params ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular # entry per stage. weights = self._remove_logically_partition(weights) @@ -517,7 +502,7 @@ def get_current_repeat_from_stages(self, weights, loop_iteration, physical_parti def gather_weights_for_stages_in(w, spec=None): return self.vmap_parallel_gather( w, - repeat_ids=repeat_ids, + repeat_ids=_get_bsw_idx(loop_iteration), repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec, @@ -529,6 +514,71 @@ def gather_weights_for_stages_in(w, spec=None): weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) return weights + @staticmethod + def get_fsdp_index_pytree(physical_partition_spec): + """ + Finds the index of 'fsdp' within each PartitionSpec in a Pytree. + + Args: + physical_partition_spec: A Pytree where leaves are PartitionSpecs. + + Returns: + A Pytree of the same structure where leaves are the integer index + of 'fsdp' or -1 if not found. 
+ """ + + def find_fsdp(pspec): + # Ensure we are handling a PartitionSpec or a tuple/list of strings + if pspec is None: + return -1 + + # PartitionSpecs are essentially tuples (e.g., PartitionSpec('data', 'fsdp')) + for i, axis in enumerate(pspec): + # Handle cases where an axis might be a tuple itself (e.g., ('fsdp', 'tensor')) + if isinstance(axis, (list, tuple)): + if "fsdp" in axis: + return i + elif axis == "fsdp": + return i + return -1 + + return jax.tree.map(find_fsdp, physical_partition_spec) + + def bsw_all_gather_over_fsdp(self, bsw, physical_partition_spec, loop_iteration): + """All gather bsw over fsdp mesh axis using shardmap.""" + pps_no_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) + fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec) + + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1) + + @jax.shard_map( + mesh=self.mesh, + in_specs=(physical_partition_spec, pps_no_fsdp, None, None), + out_specs=pps_no_fsdp, + check_vma=True, + ) + def _all_gather_inner(variables, cur_bsw, repeat_idx, fsdp_idx): + new_variables = jax.tree.map( + lambda x: jax.lax.dynamic_slice_in_dim(x, repeat_idx, 1), + variables, + ) + + def _all_gather_invariant(x, i): + if i >= 0: + return all_gather_invariant(x, axis_name="fsdp", axis=i, tiled=True) + return x + + new_variables = jax.tree.map(_all_gather_invariant, new_variables, fsdp_idx) + + def shift_and_insert(bsw_leaf, new_leaf): + updated_bsw = bsw_leaf.at[0].set(bsw_leaf[1]) + updated_bsw = updated_bsw.at[1].set(jnp.squeeze(new_leaf, axis=0)) + return updated_bsw + + return jax.tree.map(shift_and_insert, cur_bsw, new_variables) + + return _all_gather_inner(self.layers.variables, bsw, repeat_ids[0], fsdp_idx) + def get_vmap_func_for_init(self): """This vmap func is used to initialize the weights only on init.""" @@ -661,7 +711,7 @@ def gather_weights_for_stages_in(w, spec=None): ) stage_weights = self.get_current_stage_weights( - pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec + pipeline_weights, loop_state["bsw"], loop_iteration, physical_partition_spec=physical_partition_spec ) stages_output = vmap_func( @@ -761,19 +811,6 @@ def _remove_fsdp_from_physical_partition_spec(pps): return P(*new_spec) return pps - def all_gather_over_fsdp(self, variables, logical_partition_spec): - physical_partition_spec = logical_to_mesh( - logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules - ) - physical_partition_spec_no_fsdp = jax.tree.map( - self._remove_fsdp_from_physical_partition_spec, physical_partition_spec - ) - return jax.tree.map( - lambda w, p: self._maybe_shard_with_name(w, NamedSharding(self.mesh, p)), - variables, - physical_partition_spec_no_fsdp, - ) - @nn.compact def __call__( self, @@ -825,6 +862,9 @@ def __call__( segment_idx = None loop_state = self.init_states(inputs) + physical_partition_spec = logical_to_mesh( + logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules + ) # Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats) # compute to perform @@ -836,8 +876,8 @@ def __call__( # Thus the total iterations is num_micro * repeat + num_stages - 1, & we may consider the num_stages - 1 as bubble. # The bubble doubles when we use forwarding delay. 
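# A worked instance of the iteration count described in the comment above, with
# illustrative values (4 stages, 8 microbatches, 2 repeats, forwarding delay 1);
# with forwarding_delay=2 the bubble would double to 6.
num_stages, num_microbatches, num_repeats, forwarding_delay = 4, 8, 2, 1
bubble_iterations = forwarding_delay * (num_stages - 1)   # 3
real_iterations = num_microbatches * num_repeats          # 16
total_iterations = real_iterations + bubble_iterations    # 19
assert (bubble_iterations, total_iterations) == (3, 19)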
bubble_iterations = self.forwarding_delay * (self.num_stages - 1) - real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats - total_iterations = real_iterations + bubble_iterations + # real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats + # total_iterations = real_iterations + bubble_iterations if self.is_initializing(): vmap_func = self.get_vmap_func_for_init() @@ -899,21 +939,15 @@ def __call__( out_sharding=self.output_sharding, ) - if self.config.pipeline_fsdp_ag_once: - variables = self._remove_logically_partition(self.layers.variables) - all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec) - else: - all_pipeline_weights = self.layers.variables - logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec) - def run_iteration_scannable(model, loop_state, xs): + def run_iteration_scannable(model, loop_state): # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we # explicitly wrap the run_one_iteration in this method - the 1st argument model (`self`) is a nn.module instance. return ( model.run_one_iteration( loop_state, - all_pipeline_weights, + model.layers.variables, positions, segment_ids, deterministic, @@ -927,39 +961,152 @@ def run_iteration_scannable(model, loop_state, xs): if self.config.set_remat_policy_on_pipeline_iterations: run_iteration_scannable = nn.remat( run_iteration_scannable, - prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan + prevent_cse=not self.config.scan_pipeline_iterations, policy=self.get_pipeline_remat_policy(), ) - # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized. - if self.config.scan_pipeline_iterations: - variable_carry = [] - variable_broadcast = [ - "params", - "_overwrite_with_gradient", - ] # All loop iterations need the weights for the full pipeline. - if self.is_mutable_collection("non_trainable"): - variable_carry.append("non_trainable") + def run_one_repeat_scannable(model, loop_state): + loop_state["bsw"] = model.bsw_all_gather_over_fsdp( + loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"] + ) + + if model.config.scan_pipeline_iterations: + run_one_repeat_scanned = nn.scan( + run_iteration_scannable, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + # Dropout/aqt keys will be split for each iteration. + split_rngs={"random": True}, + length=model.config.num_pipeline_microbatches, + ) + loop_state, _ = run_one_repeat_scanned(model, loop_state) else: - variable_broadcast.append("non_trainable") - run_all_iterations_scanned = nn.scan( - run_iteration_scannable, - variable_axes={ - "summaries": 0, - "aux_loss": 0, - "intermediates": 0, - "hyper_params": 0, - }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, - # Dropout/aqt keys will be split for each iteration. 
- split_rngs={"random": True}, - length=total_iterations, + for _ in range(model.config.num_pipeline_microbatches): + loop_state, _ = run_iteration_scannable(model, loop_state) + return loop_state, None + + run_one_repeat_scannable = nn.remat( + run_one_repeat_scannable, + prevent_cse=not self.config.scan_pipeline_iterations, + policy=self.get_pipeline_remat_policy(), + ) + + def run_real_repeats(model, loop_state): + if self.config.scan_pipeline_repeats: + run_repeats_scanned = nn.scan( + run_one_repeat_scannable, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + split_rngs={"random": True}, + length=model.config.num_pipeline_repeats, + ) + loop_state, _ = run_repeats_scanned(model, loop_state) + else: + for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop + loop_state, _ = run_one_repeat_scannable(model, loop_state) + return loop_state + + run_real_repeats = nn.remat( + run_real_repeats, + prevent_cse=not self.config.scan_pipeline_iterations, + policy=self.get_pipeline_remat_policy(), + ) + + def run_bubble_iterations_scannable(model, loop_state): + loop_state["bsw"] = model.bsw_all_gather_over_fsdp( + loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"] ) - loop_state, _ = run_all_iterations_scanned(self, loop_state, None) + + if model.config.scan_pipeline_iterations: + run_one_repeat_scanned = nn.scan( + run_iteration_scannable, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + # Dropout/aqt keys will be split for each iteration. + split_rngs={"random": True}, + length=bubble_iterations, + ) + loop_state, _ = run_one_repeat_scanned(model, loop_state) + else: + for _ in range(model.config.num_pipeline_microbatches): + loop_state, _ = run_iteration_scannable(model, loop_state) + return loop_state, None + + run_bubble_iterations_scannable = nn.remat( + run_bubble_iterations_scannable, + prevent_cse=not self.config.scan_pipeline_iterations, + policy=self.get_pipeline_remat_policy(), + ) + + def run_all_iterations(model, loop_state): + if self.config.scan_pipeline_repeats: + run_repeats_scanned = nn.scan( + run_one_repeat_scannable, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + split_rngs={"random": True}, + length=model.config.num_pipeline_repeats, + ) + + run_bubbles_scanned = nn.scan( + run_bubble_iterations_scannable, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + split_rngs={"random": True}, + length=model.config.num_pipeline_repeats, + ) + loop_state, _ = run_repeats_scanned(model, loop_state) + loop_state, _ = run_bubbles_scanned(model, loop_state) + else: + for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop + loop_state, _ = run_one_repeat_scannable(model, loop_state) + for _ in range(bubble_iterations): + loop_state, _ = run_iteration_scannable(model, loop_state) + return loop_state + + # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized. 
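# A compact jax.lax.scan analogue of the loop nesting being introduced here: an
# outer loop over pipeline repeats that refreshes the gathered-weight buffer
# once, then an inner loop over microbatch iterations, followed by the bubble
# (drain) iterations. `gather_next_repeat` and `run_iteration` are placeholders
# for the module methods above, not real APIs.
import jax

def run_pipeline(state, num_repeats, num_microbatches, bubble_iterations,
                 gather_next_repeat, run_iteration):
  def one_repeat(state, _):
    state = gather_next_repeat(state)  # all-gather weights over fsdp once per repeat
    state, _ = jax.lax.scan(lambda s, _: (run_iteration(s), None),
                            state, None, length=num_microbatches)
    return state, None

  state, _ = jax.lax.scan(one_repeat, state, None, length=num_repeats)
  state, _ = jax.lax.scan(lambda s, _: (run_iteration(s), None),
                          state, None, length=bubble_iterations)
  return state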
+ # if self.config.scan_pipeline_iterations: + variable_carry = [] + variable_broadcast = [ + "params", + "_overwrite_with_gradient", + ] # All loop iterations need the weights for the full pipeline. + if self.is_mutable_collection("non_trainable"): + variable_carry.append("non_trainable") else: - for _ in range(total_iterations): - loop_state, _ = run_iteration_scannable(self, loop_state, None) + variable_broadcast.append("non_trainable") + + loop_state = run_all_iterations(self, loop_state) # The final output is located in the input/output array, however the output microbatches may be permuted relative to # the input diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 632f3345b7..9e3f081bba 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -290,6 +290,7 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights # It may be useful to do the reverse when the layers_per_stage is very large. # The below settings only have effect when using pipeline parallelism. scan_pipeline_iterations: True +scan_pipeline_repeats: True scan_layers_per_stage: False set_remat_policy_on_pipeline_iterations: True set_remat_policy_on_layers_per_stage: False diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e651293a19..43d3eded20 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -842,6 +842,7 @@ class PipelineParallelism(BaseModel): ) pipeline_fsdp_ag_once: bool = Field(False, description="If True, all-gather FSDP weights once per pipeline repeat.") scan_pipeline_iterations: bool = Field(True, description="Use jax.lax.scan over pipeline iterations.") + scan_pipeline_repeats: bool = Field(True, description="Use jax.lax.scan over pipeline repeats.") scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.") set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.") set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.") From f12cd8956b03f55dcd27edb90a1b5405f9743bf8 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Sat, 10 Jan 2026 01:25:44 +0000 Subject: [PATCH 2/7] working all gather insertion --- src/MaxText/layers/pipeline.py | 248 +++++++++++++----------- src/maxtext/configs/base.yml | 8 +- tests/unit/pipeline_parallelism_test.py | 10 +- 3 files changed, 150 insertions(+), 116 deletions(-) diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index 094fd9625b..2f3c22f98b 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -30,6 +30,7 @@ from flax.linen.spmd import LogicallyPartitioned from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode +# from MaxText import maxtext_utils from MaxText.sharding import ( maybe_shard_with_logical, maybe_shard_with_name, @@ -204,12 +205,17 @@ def init_states(self, inputs): def _init_bsw_from_weights(variables): """Buffer space for two copies of weights.""" - return jax.tree.map(lambda x: jnp.zeros_like(x[:2]), variables) + # take idx 0 slice assuming num_layers_per_pipeline_stage=1 + return ( + jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables), + jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables), + ) if self.is_initializing(): bsw = None else: - bsw = _init_bsw_from_weights(self.layers.variables) + variables = self._remove_logically_partition(self.layers.variables) + bsw = 
_init_bsw_from_weights(variables) init_loop_state = { "state_io": state_io, @@ -269,6 +275,31 @@ def select_state_or_input(first_stage_in, shift): stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical) return stages_in + def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False): + """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at + the specified dimension.""" + # placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED + # if physical_partition_spec is None: + # dims_mapping = [placeholder] * x.ndim + # else: + # physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec) + # dims_mapping = list(physical_partition_spec) + # # If not a stage weight, we handle the repeat dimension offset + # if not is_stage_weight: + # dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:] # inflat one dimension for num_repeats + # dims_mapping[dim] = "stage" + # dims_mapping = tuple(dims_mapping) + # # We add reduced rule only when pspec is given for a stage weight + # if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT: + # batch_mesh_axis = ["data", "fsdp"] + # reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1] + # pspec = P(*dims_mapping, reduced=set(reduced_mark)) + # else: + # pspec = P(*dims_mapping) + # sharding = jax.sharding.NamedSharding(self.mesh, pspec) + # return self._maybe_shard_with_name(x, sharding) + return x + def get_microbatch_and_repeat_ids(self, loop_iteration): """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular""" @@ -278,6 +309,14 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches return microbatch_ids, repeat_ids + def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration): + """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. 
Works for both circular and + non-circular""" + raw_processed = loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages) + repeat_ids = raw_processed // self.config.num_pipeline_microbatches + microbatch_ids = jnp.maximum(raw_processed, 0) % self.config.num_pipeline_microbatches + return microbatch_ids, repeat_ids + def vmap_parallel_gather( self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights ): @@ -300,9 +339,18 @@ def _gather_one(x, repeat_id): return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) gathered_weights_stage_dim = 0 + repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) + # num_repeats x num_stages x *param_dim + weights = self.shard_dim_by_stages( + weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False + ) stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( weights, repeat_ids ) + # num_stages x *param_dim + stage_weights = self.shard_dim_by_stages( + stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True + ) return stage_weights def vmap_gather(self, xs, ids, ids_dim): @@ -326,8 +374,9 @@ def _gather_one(x, i): replicated_sharding = NamedSharding(self.mesh, P()) return x.at[idx].get(out_sharding=replicated_sharding) + ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) - return outs + return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None) def get_new_loop_state(self, output, loop_state): """ @@ -471,20 +520,53 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi For non-circular pipelines, this simply returns all weights - every weight is used in every iteraiton. However for circular pipelines each stage grabs only the weights corresponding to the current repeat. 
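# A simplified (non-shard_map) sketch of how each stage could pick between the
# two buffered weight copies: stages that have advanced to the same repeat as
# stage 0 read the newer slot, stages still finishing the previous repeat read
# the older slot. Leaves with a leading stage dimension are assumed; the actual
# implementation below does this per-shard under jax.shard_map.
import jax
import jax.numpy as jnp

def select_stage_weights(bsw, repeat_ids):
  older, newer = bsw                          # two pytrees, leaves [num_stages, ...]
  up_to_date = repeat_ids == repeat_ids[0]    # [num_stages] bool

  def pick(old, new):
    mask = up_to_date.reshape((-1,) + (1,) * (old.ndim - 1))
    return jnp.where(mask, new, old)

  return jax.tree.map(pick, older, newer)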
""" + pipeline_weights = self._remove_logically_partition(pipeline_weights) if self.config.num_pipeline_repeats > 1: - return self.get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec=physical_partition_spec) - else: - return pipeline_weights + pipeline_weights = self.get_current_weights_from_bsw( + bsw, loop_iteration, physical_partition_spec=physical_partition_spec + ) + return pipeline_weights - def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec=None): + def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec): """Collect and gather weights from given bsw (buffer sliding window)""" + bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + target_repeat_id = repeat_ids[0] - def _get_bsw_idx(loop_iteration): - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - bsw_ids = (repeat_ids == repeat_ids[0]).astype( - jnp.int32 - ) # For early repeats this might return true when it should be false - return bsw_ids + # path = ("params", "mlp", "wi_0", "kernel") + # path = ("params", "weights") + + # jax.debug.print( + # "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | " + # "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}", + # iter=loop_iteration, target=target_repeat_id, rids=repeat_ids, + # bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)), + # bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)), + # ) + + @jax.shard_map( + mesh=self.mesh, + in_specs=((bsw_pps, bsw_pps), P("stage")), + out_specs=(bsw_pps), + check_vma=True, + ) + def select_weights_from_bsw(bsw, repeat_id): + weights = jax.tree.map( + lambda x, y: jax.lax.select(repeat_id[0] == target_repeat_id, y, x), + bsw[0], + bsw[1], + ) + # jax.debug.print( + # "Iteration: {iter} | " + # "Selected weights mean for Stage {s} with repeat id {i}: {m}", + # iter=loop_iteration, + # s=jax.lax.axis_index("stage"), + # m=maxtext_utils.get_nested_value(weights, path).mean(), + # i=repeat_id[0], + # ) + return weights + + weights = select_weights_from_bsw(bsw, repeat_ids) circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", @@ -494,24 +576,10 @@ def _get_bsw_idx(loop_iteration): "optimizer_dims_mapping": None, } weights = meta.remove_axis( - bsw, 0, circular_metadata_params + weights, 0, circular_metadata_params ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular # entry per stage. 
- weights = self._remove_logically_partition(weights) - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, - repeat_ids=_get_bsw_idx(loop_iteration), - repeat_dim_in_weights=0, - stages_dim_in_weights=1, - physical_partition_spec=spec, - ) - - if physical_partition_spec is None: - weights = jax.tree.map(gather_weights_for_stages_in, weights) - else: - weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) return weights @staticmethod @@ -544,40 +612,50 @@ def find_fsdp(pspec): return jax.tree.map(find_fsdp, physical_partition_spec) - def bsw_all_gather_over_fsdp(self, bsw, physical_partition_spec, loop_iteration): + def bsw_all_gather_over_fsdp(self, weights, bsw, physical_partition_spec, loop_iteration): """All gather bsw over fsdp mesh axis using shardmap.""" - pps_no_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) + bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec) + repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec) fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec) _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1) + def gather_weights_for_stages_in(w, spec): + return self.vmap_parallel_gather( + w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + ) + + if physical_partition_spec is None: + repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights) + else: + repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) + + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params) + @jax.shard_map( mesh=self.mesh, - in_specs=(physical_partition_spec, pps_no_fsdp, None, None), - out_specs=pps_no_fsdp, + in_specs=(repeat_weights_pps, (bsw_pps, bsw_pps), None), + out_specs=(bsw_pps, bsw_pps), check_vma=True, ) - def _all_gather_inner(variables, cur_bsw, repeat_idx, fsdp_idx): - new_variables = jax.tree.map( - lambda x: jax.lax.dynamic_slice_in_dim(x, repeat_idx, 1), - variables, - ) - + def _all_gather_inner(sharded_weights, cur_bsw, fsdp_idx): def _all_gather_invariant(x, i): if i >= 0: - return all_gather_invariant(x, axis_name="fsdp", axis=i, tiled=True) + return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True) return x - new_variables = jax.tree.map(_all_gather_invariant, new_variables, fsdp_idx) - - def shift_and_insert(bsw_leaf, new_leaf): - updated_bsw = bsw_leaf.at[0].set(bsw_leaf[1]) - updated_bsw = updated_bsw.at[1].set(jnp.squeeze(new_leaf, axis=0)) - return updated_bsw + new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx) - return jax.tree.map(shift_and_insert, cur_bsw, new_variables) + return (cur_bsw[1], new_variables) - return _all_gather_inner(self.layers.variables, bsw, repeat_ids[0], fsdp_idx) + return _all_gather_inner(repeat_weights, bsw, fsdp_idx) def get_vmap_func_for_init(self): """This vmap func is used to initialize the weights only on init.""" @@ -648,7 +726,7 @@ def run_one_iteration( deterministic, model_mode, decoder_layer_instance, - logical_partition_spec=None, + logical_partition_spec, ): """Run one loop iteration - gets weights and inputs for each stage, run the 
stages in parallel, and update the loop state.""" @@ -811,6 +889,13 @@ def _remove_fsdp_from_physical_partition_spec(pps): return P(*new_spec) return pps + def _generate_bsw_pps_from_pps(self, physical_partition_spec): + """Create bsw physical partition spec from weight physical partition spec.""" + return jax.tree.map( + lambda pps: P(*self._remove_fsdp_from_physical_partition_spec(pps)[1:]), + physical_partition_spec, + ) + @nn.compact def __call__( self, @@ -966,8 +1051,9 @@ def run_iteration_scannable(model, loop_state): ) def run_one_repeat_scannable(model, loop_state): + weights = model._remove_logically_partition(model.layers.variables) # pylint: disable=protected-access loop_state["bsw"] = model.bsw_all_gather_over_fsdp( - loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"] + weights, loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"] ) if model.config.scan_pipeline_iterations: @@ -997,65 +1083,6 @@ def run_one_repeat_scannable(model, loop_state): policy=self.get_pipeline_remat_policy(), ) - def run_real_repeats(model, loop_state): - if self.config.scan_pipeline_repeats: - run_repeats_scanned = nn.scan( - run_one_repeat_scannable, - variable_axes={ - "summaries": 0, - "aux_loss": 0, - "intermediates": 0, - "hyper_params": 0, - }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, - split_rngs={"random": True}, - length=model.config.num_pipeline_repeats, - ) - loop_state, _ = run_repeats_scanned(model, loop_state) - else: - for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop - loop_state, _ = run_one_repeat_scannable(model, loop_state) - return loop_state - - run_real_repeats = nn.remat( - run_real_repeats, - prevent_cse=not self.config.scan_pipeline_iterations, - policy=self.get_pipeline_remat_policy(), - ) - - def run_bubble_iterations_scannable(model, loop_state): - loop_state["bsw"] = model.bsw_all_gather_over_fsdp( - loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"] - ) - - if model.config.scan_pipeline_iterations: - run_one_repeat_scanned = nn.scan( - run_iteration_scannable, - variable_axes={ - "summaries": 0, - "aux_loss": 0, - "intermediates": 0, - "hyper_params": 0, - }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, - # Dropout/aqt keys will be split for each iteration. 
- split_rngs={"random": True}, - length=bubble_iterations, - ) - loop_state, _ = run_one_repeat_scanned(model, loop_state) - else: - for _ in range(model.config.num_pipeline_microbatches): - loop_state, _ = run_iteration_scannable(model, loop_state) - return loop_state, None - - run_bubble_iterations_scannable = nn.remat( - run_bubble_iterations_scannable, - prevent_cse=not self.config.scan_pipeline_iterations, - policy=self.get_pipeline_remat_policy(), - ) - def run_all_iterations(model, loop_state): if self.config.scan_pipeline_repeats: run_repeats_scanned = nn.scan( @@ -1073,7 +1100,7 @@ def run_all_iterations(model, loop_state): ) run_bubbles_scanned = nn.scan( - run_bubble_iterations_scannable, + run_iteration_scannable, variable_axes={ "summaries": 0, "aux_loss": 0, @@ -1083,9 +1110,10 @@ def run_all_iterations(model, loop_state): variable_broadcast=variable_broadcast, variable_carry=variable_carry, split_rngs={"random": True}, - length=model.config.num_pipeline_repeats, + length=bubble_iterations, ) loop_state, _ = run_repeats_scanned(model, loop_state) + loop_state["bsw"] = (loop_state["bsw"][1], jax.tree.map(jnp.zeros_like, loop_state["bsw"][1])) loop_state, _ = run_bubbles_scanned(model, loop_state) else: for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 9e3f081bba..0674041620 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -901,7 +901,13 @@ xprof_e2e_enable_fw_throttle_event: False xprof_e2e_enable_fw_power_level_event: False xprof_e2e_enable_fw_thermal_event: False -log_config: True # Prints the config (after defaults have been set by pyconfig logic) +# TPU power trace level for xprof. 0:POWER_TRACE_NONE, 1:POWER_TRACE_NORMAL, or 2:POWER_TRACE_SPI +xprof_tpu_power_trace_level: 0 +xprof_e2e_enable_fw_throttle_event: False +xprof_e2e_enable_fw_power_level_event: False +xprof_e2e_enable_fw_thermal_event: False + +log_config: False # Prints the config (after defaults have been set by pyconfig logic) debug_sharding: False # Prints model weights sharding info # Checkpoint Structured logging diff --git a/tests/unit/pipeline_parallelism_test.py b/tests/unit/pipeline_parallelism_test.py index 98e49d5050..816eeb83e3 100644 --- a/tests/unit/pipeline_parallelism_test.py +++ b/tests/unit/pipeline_parallelism_test.py @@ -213,7 +213,7 @@ def test_circular_minimum_microbatches_same_output_and_grad(self): run_name="circular_minimum_microbatches", max_target_length=128, base_emb_dim=28, - ici_pipeline_parallelism=4, + ici_pipeline_parallelism=2, base_num_decoder_layers=8, num_pipeline_microbatches=4, per_device_batch_size=4, @@ -230,7 +230,7 @@ def test_circular_extra_microbatches_same_output_and_grad(self): run_name="circular_extra_microbatches", max_target_length=128, base_emb_dim=28, - ici_pipeline_parallelism=4, + ici_pipeline_parallelism=2, base_num_decoder_layers=8, num_pipeline_microbatches=8, per_device_batch_size=4, @@ -247,7 +247,7 @@ def test_circular_deepseek_megablox_same_output_and_grad(self): run_name="circular_moe", max_target_length=128, base_emb_dim=28, - ici_pipeline_parallelism=4, + ici_pipeline_parallelism=2, base_num_decoder_layers=8, num_pipeline_microbatches=8, per_device_batch_size=4, @@ -287,7 +287,7 @@ def test_non_circular_same_output_and_grad(self): run_name="non_circular", max_target_length=128, base_emb_dim=28, - ici_pipeline_parallelism=4, + ici_pipeline_parallelism=2, base_num_decoder_layers=4, num_pipeline_microbatches=4, 
per_device_batch_size=4, @@ -336,7 +336,7 @@ def test_delay_activation_forwarding_same_output_and_grad(self): run_name="activation_forwarding", max_target_length=128, base_emb_dim=28, - ici_pipeline_parallelism=4, + ici_pipeline_parallelism=2, base_num_decoder_layers=8, num_pipeline_microbatches=8, per_device_batch_size=4, From ceb7a948316a7e1c7cde0a806d668714c92c792f Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Tue, 13 Jan 2026 22:43:00 +0000 Subject: [PATCH 3/7] clean version fsdp+pp bug free --- src/MaxText/layers/pipeline.py | 66 ++-------------------------------- 1 file changed, 2 insertions(+), 64 deletions(-) diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index 2f3c22f98b..18ec18925b 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -275,31 +275,6 @@ def select_state_or_input(first_stage_in, shift): stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical) return stages_in - def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False): - """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at - the specified dimension.""" - # placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED - # if physical_partition_spec is None: - # dims_mapping = [placeholder] * x.ndim - # else: - # physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec) - # dims_mapping = list(physical_partition_spec) - # # If not a stage weight, we handle the repeat dimension offset - # if not is_stage_weight: - # dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:] # inflat one dimension for num_repeats - # dims_mapping[dim] = "stage" - # dims_mapping = tuple(dims_mapping) - # # We add reduced rule only when pspec is given for a stage weight - # if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT: - # batch_mesh_axis = ["data", "fsdp"] - # reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1] - # pspec = P(*dims_mapping, reduced=set(reduced_mark)) - # else: - # pspec = P(*dims_mapping) - # sharding = jax.sharding.NamedSharding(self.mesh, pspec) - # return self._maybe_shard_with_name(x, sharding) - return x - def get_microbatch_and_repeat_ids(self, loop_iteration): """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular""" @@ -309,14 +284,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches return microbatch_ids, repeat_ids - def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration): - """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. 
Works for both circular and - non-circular""" - raw_processed = loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages) - repeat_ids = raw_processed // self.config.num_pipeline_microbatches - microbatch_ids = jnp.maximum(raw_processed, 0) % self.config.num_pipeline_microbatches - return microbatch_ids, repeat_ids - def vmap_parallel_gather( self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights ): @@ -339,18 +306,9 @@ def _gather_one(x, repeat_id): return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) gathered_weights_stage_dim = 0 - repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) - # num_repeats x num_stages x *param_dim - weights = self.shard_dim_by_stages( - weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False - ) stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( weights, repeat_ids ) - # num_stages x *param_dim - stage_weights = self.shard_dim_by_stages( - stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True - ) return stage_weights def vmap_gather(self, xs, ids, ids_dim): @@ -374,9 +332,7 @@ def _gather_one(x, i): replicated_sharding = NamedSharding(self.mesh, P()) return x.at[idx].get(out_sharding=replicated_sharding) - ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) - outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) - return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None) + return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) def get_new_loop_state(self, output, loop_state): """ @@ -533,17 +489,6 @@ def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_s _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) target_repeat_id = repeat_ids[0] - # path = ("params", "mlp", "wi_0", "kernel") - # path = ("params", "weights") - - # jax.debug.print( - # "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | " - # "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}", - # iter=loop_iteration, target=target_repeat_id, rids=repeat_ids, - # bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)), - # bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)), - # ) - @jax.shard_map( mesh=self.mesh, in_specs=((bsw_pps, bsw_pps), P("stage")), @@ -556,14 +501,7 @@ def select_weights_from_bsw(bsw, repeat_id): bsw[0], bsw[1], ) - # jax.debug.print( - # "Iteration: {iter} | " - # "Selected weights mean for Stage {s} with repeat id {i}: {m}", - # iter=loop_iteration, - # s=jax.lax.axis_index("stage"), - # m=maxtext_utils.get_nested_value(weights, path).mean(), - # i=repeat_id[0], - # ) + return weights weights = select_weights_from_bsw(bsw, repeat_ids) From b64f02fcd54204cf3af071c80d7f70d5958ef2ca Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Fri, 16 Jan 2026 18:16:25 +0000 Subject: [PATCH 4/7] add bsw checkpoint --- src/MaxText/layers/pipeline.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index 18ec18925b..f22d213cbf 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -590,8 +590,9 @@ def _all_gather_invariant(x, i): return x new_variables = jax.tree.map(_all_gather_invariant, 
sharded_weights, fsdp_idx) + new_variables = jax.ad_checkpoint.checkpoint_name(new_variables, "bsw_gathered_weights") - return (cur_bsw[1], new_variables) + return jax.ad_checkpoint.checkpoint_name((cur_bsw[1], new_variables), "bsw") return _all_gather_inner(repeat_weights, bsw, fsdp_idx) @@ -984,7 +985,7 @@ def run_iteration_scannable(model, loop_state): if self.config.set_remat_policy_on_pipeline_iterations: run_iteration_scannable = nn.remat( run_iteration_scannable, - prevent_cse=not self.config.scan_pipeline_iterations, + prevent_cse=True, # not self.config.scan_pipeline_iterations, policy=self.get_pipeline_remat_policy(), ) @@ -1017,7 +1018,7 @@ def run_one_repeat_scannable(model, loop_state): run_one_repeat_scannable = nn.remat( run_one_repeat_scannable, - prevent_cse=not self.config.scan_pipeline_iterations, + prevent_cse=True, policy=self.get_pipeline_remat_policy(), ) From 33454f644bbb6a0af180c520e4937a2b0dc16d66 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Fri, 23 Jan 2026 17:20:21 +0000 Subject: [PATCH 5/7] split bsw all gather into two --- src/MaxText/layers/pipeline.py | 48 ++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index f22d213cbf..dccabcdbe5 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -550,24 +550,20 @@ def find_fsdp(pspec): return jax.tree.map(find_fsdp, physical_partition_spec) - def bsw_all_gather_over_fsdp(self, weights, bsw, physical_partition_spec, loop_iteration): - """All gather bsw over fsdp mesh axis using shardmap.""" - bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec) - repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec) - fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec) - - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1) + def from_all_variables_to_repeat_weights(self, loop_iteration, physical_partition_spec): + """Generate one single repeat weight from all variables.""" + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) def gather_weights_for_stages_in(w, spec): return self.vmap_parallel_gather( w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec ) + weights = self._remove_logically_partition(self.layers.variables) if physical_partition_spec is None: repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights) else: repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) - circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", "sub_weight_split_dims_mapping": (None,), @@ -576,25 +572,36 @@ def gather_weights_for_stages_in(w, spec): "optimizer_dims_mapping": None, } repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params) + return repeat_weights + + def from_all_variables_to_bsw(self, loop_iteration, physical_partition_spec): + """All gather one branch of bsw using shardmap.""" + repeat_weights = self.from_all_variables_to_repeat_weights(loop_iteration, physical_partition_spec) + bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec) + repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec) + fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec) @jax.shard_map( mesh=self.mesh, - in_specs=(repeat_weights_pps, (bsw_pps, bsw_pps), None), - out_specs=(bsw_pps, bsw_pps), + in_specs=(repeat_weights_pps, None), + out_specs=bsw_pps, 
check_vma=True, ) - def _all_gather_inner(sharded_weights, cur_bsw, fsdp_idx): + def _all_gather_inner(sharded_weights, fsdp_idx): def _all_gather_invariant(x, i): if i >= 0: return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True) return x - new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx) - new_variables = jax.ad_checkpoint.checkpoint_name(new_variables, "bsw_gathered_weights") + return jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx) - return jax.ad_checkpoint.checkpoint_name((cur_bsw[1], new_variables), "bsw") + return _all_gather_inner(repeat_weights, fsdp_idx) - return _all_gather_inner(repeat_weights, bsw, fsdp_idx) + def bsw_all_gather_over_fsdp(self, physical_partition_spec, loop_iteration): + """All gather all bsw over fsdp mesh axis using shardmap.""" + bsw_0 = self.from_all_variables_to_bsw(loop_iteration, physical_partition_spec) + bsw_1 = self.from_all_variables_to_bsw(loop_iteration + 1, physical_partition_spec) + return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw") def get_vmap_func_for_init(self): """This vmap func is used to initialize the weights only on init.""" @@ -985,15 +992,12 @@ def run_iteration_scannable(model, loop_state): if self.config.set_remat_policy_on_pipeline_iterations: run_iteration_scannable = nn.remat( run_iteration_scannable, - prevent_cse=True, # not self.config.scan_pipeline_iterations, + prevent_cse=not self.config.scan_pipeline_iterations, policy=self.get_pipeline_remat_policy(), ) def run_one_repeat_scannable(model, loop_state): - weights = model._remove_logically_partition(model.layers.variables) # pylint: disable=protected-access - loop_state["bsw"] = model.bsw_all_gather_over_fsdp( - weights, loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"] - ) + loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"]) if model.config.scan_pipeline_iterations: run_one_repeat_scanned = nn.scan( @@ -1018,7 +1022,7 @@ def run_one_repeat_scannable(model, loop_state): run_one_repeat_scannable = nn.remat( run_one_repeat_scannable, - prevent_cse=True, + prevent_cse=not self.config.scan_pipeline_iterations, policy=self.get_pipeline_remat_policy(), ) @@ -1052,7 +1056,7 @@ def run_all_iterations(model, loop_state): length=bubble_iterations, ) loop_state, _ = run_repeats_scanned(model, loop_state) - loop_state["bsw"] = (loop_state["bsw"][1], jax.tree.map(jnp.zeros_like, loop_state["bsw"][1])) + loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"]) loop_state, _ = run_bubbles_scanned(model, loop_state) else: for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop From 43aeedff83bb6dd2545f3ceee3cb502a4a5eadff Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Tue, 3 Feb 2026 18:04:53 +0000 Subject: [PATCH 6/7] add custom vjp --- src/MaxText/layers/decoders.py | 7 + src/MaxText/layers/pipeline.py | 308 ++++++++++++++++++++++++--------- src/maxtext/configs/base.yml | 6 - 3 files changed, 233 insertions(+), 88 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index b27cef3d21..07b1a1ff02 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -943,6 +943,13 @@ def __call__( else: logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + logits = sharding.maybe_shard_with_logical( + logits, + ("activation_embed_and_logits_batch", "activation_length_no_exp", 
"activation_vocab"), + mesh=self.mesh, + shard_mode=self.config.shard_mode, + debug_sharding=self.config.debug_sharding, + ) # The API of the Decoder is now a tuple, providing both the main output # and the raw hidden state needed for auxiliary tasks. diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index dccabcdbe5..0d57b9846d 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Pipeline layer wrapping a decoder layer(s). Supports circular pipelining """ +"""Pipeline layer wrapping a decoder layer(s). Supports circular pipelining""" -import functools +# import functools from typing import Any +import functools import numpy as np @@ -225,6 +226,7 @@ def _init_bsw_from_weights(variables): "loop_iteration": 0, "prev_outputs": prev_outputs, "bsw": bsw, + "weights": self.layers.variables, } return init_loop_state @@ -455,6 +457,7 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx): "loop_iteration": loop_iteration + 1, "prev_outputs": new_prev_outputs, "bsw": loop_state["bsw"], # bsw is updated outside of this inner loop, only once per outer loop iteration + "weights": loop_state["weights"], # Pass weights through } return new_loop_state @@ -469,7 +472,9 @@ def permute_output_micro_per_stage_dim(self, output): output = output[:, permutation] return output - def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physical_partition_spec=None): + def get_current_stage_weights( + self, pipeline_weights, bsw, loop_iteration, physical_partition_spec=None, is_initializing=None + ): """ Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g. {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. 
@@ -479,11 +484,11 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi pipeline_weights = self._remove_logically_partition(pipeline_weights) if self.config.num_pipeline_repeats > 1: pipeline_weights = self.get_current_weights_from_bsw( - bsw, loop_iteration, physical_partition_spec=physical_partition_spec + bsw, loop_iteration, physical_partition_spec=physical_partition_spec, is_initializing=is_initializing ) return pipeline_weights - def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec): + def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None): """Collect and gather weights from given bsw (buffer sliding window)""" bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) @@ -506,10 +511,13 @@ def select_weights_from_bsw(bsw, repeat_id): weights = select_weights_from_bsw(bsw, repeat_ids) + if is_initializing is None: + is_initializing = self.is_initializing() + circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), + "is_initializing": is_initializing, "x_times": self.config.num_pipeline_repeats, "optimizer_dims_mapping": None, } @@ -550,7 +558,7 @@ def find_fsdp(pspec): return jax.tree.map(find_fsdp, physical_partition_spec) - def from_all_variables_to_repeat_weights(self, loop_iteration, physical_partition_spec): + def from_all_variables_to_repeat_weights(self, weights, loop_iteration, physical_partition_spec): """Generate one single repeat weight from all variables.""" _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) @@ -559,11 +567,11 @@ def gather_weights_for_stages_in(w, spec): w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec ) - weights = self._remove_logically_partition(self.layers.variables) + weights = self._remove_logically_partition(weights) if physical_partition_spec is None: - repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights) + weights = jax.tree.map(gather_weights_for_stages_in, weights) else: - repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) + weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", "sub_weight_split_dims_mapping": (None,), @@ -571,12 +579,12 @@ def gather_weights_for_stages_in(w, spec): "x_times": self.config.num_pipeline_repeats, "optimizer_dims_mapping": None, } - repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params) + repeat_weights = meta.remove_axis(weights, 0, circular_metadata_params) return repeat_weights - def from_all_variables_to_bsw(self, loop_iteration, physical_partition_spec): + def from_all_variables_to_bsw(self, weights, loop_iteration, physical_partition_spec): """All gather one branch of bsw using shardmap.""" - repeat_weights = self.from_all_variables_to_repeat_weights(loop_iteration, physical_partition_spec) + repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec) bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec) repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec) fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec) @@ -597,10 +605,10 @@ def 
_all_gather_invariant(x, i): return _all_gather_inner(repeat_weights, fsdp_idx) - def bsw_all_gather_over_fsdp(self, physical_partition_spec, loop_iteration): + def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration): """All gather all bsw over fsdp mesh axis using shardmap.""" - bsw_0 = self.from_all_variables_to_bsw(loop_iteration, physical_partition_spec) - bsw_1 = self.from_all_variables_to_bsw(loop_iteration + 1, physical_partition_spec) + bsw_0 = self.from_all_variables_to_bsw(weights, loop_iteration, physical_partition_spec) + bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec) return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw") def get_vmap_func_for_init(self): @@ -666,13 +674,14 @@ def func_to_vmap( def run_one_iteration( self, loop_state, - pipeline_weights, positions, segment_ids, deterministic, model_mode, decoder_layer_instance, logical_partition_spec, + vmap_func=None, + is_initializing=None, ): """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the loop state.""" @@ -680,6 +689,7 @@ def run_one_iteration( shift = loop_state["shift"] circ_storage = loop_state["circ_storage"] loop_iteration = loop_state["loop_iteration"] + pipeline_weights = loop_state["weights"] microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) @@ -693,49 +703,15 @@ def run_one_iteration( stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None - vmap_func = self.get_main_vmap_func_for_iterations() - - if self.config.num_pipeline_repeats > 1: - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - - def prepare_vars_for_main_vmap(weights, physical_partition_spec=None): - - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - weights = meta.remove_axis( - weights, 0, circular_metadata_params - ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one - # circular entry per stage. 
- weights = self._remove_logically_partition(weights) - - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec - ) - - if physical_partition_spec is None: - weights = jax.tree.map(gather_weights_for_stages_in, weights) - else: - weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) - return weights - - prepare_vars_for_main_vmap_partial = functools.partial( - prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec - ) - vmap_func = nn.map_variables( - vmap_func, - mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"], - mutable=True, - trans_in_fn=prepare_vars_for_main_vmap_partial, - ) + if vmap_func is None: + vmap_func = self.get_main_vmap_func_for_iterations() stage_weights = self.get_current_stage_weights( - pipeline_weights, loop_state["bsw"], loop_iteration, physical_partition_spec=physical_partition_spec + pipeline_weights, + loop_state["bsw"], + loop_iteration, + physical_partition_spec=physical_partition_spec, + is_initializing=is_initializing, ) stages_output = vmap_func( @@ -978,7 +954,6 @@ def run_iteration_scannable(model, loop_state): return ( model.run_one_iteration( loop_state, - model.layers.variables, positions, segment_ids, deterministic, @@ -997,7 +972,9 @@ def run_iteration_scannable(model, loop_state): ) def run_one_repeat_scannable(model, loop_state): - loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"]) + loop_state["bsw"] = model.bsw_all_gather_over_fsdp( + loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"] + ) if model.config.scan_pipeline_iterations: run_one_repeat_scanned = nn.scan( @@ -1008,13 +985,89 @@ def run_one_repeat_scannable(model, loop_state): "intermediates": 0, "hyper_params": 0, }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, # Dropout/aqt keys will be split for each iteration. split_rngs={"random": True}, length=model.config.num_pipeline_microbatches, ) - loop_state, _ = run_one_repeat_scanned(model, loop_state) + + @functools.partial(jax.custom_vjp) + def run_one_repeat_scanned_custom(loop_state, positions, segment_ids): + final_state, _ = run_one_repeat_scanned(model, loop_state) + return final_state + + def run_one_repeat_scanned_custom_fwd(loop_state, positions, segment_ids): + final_state, _ = run_one_repeat_scanned(model, loop_state) + # We return loop_state as residual. model is passed to bwd as arg. 
+ return final_state, ( + loop_state, + positions, + segment_ids, + ) + + def run_one_repeat_scanned_custom_bwd(residuals, g_final_state): + init_loop_state, positions, segment_ids = residuals + + # Re-run forward pass to get saved states (checkpointing) + def scan_body_fwd(carry, _): + new_state = model.run_one_iteration( + carry, + positions, + segment_ids, + deterministic, + model_mode, + model.layers, + logical_partition_spec=logical_partition_spec, + ) + # Return lightweight state for saving (exclude bsw/weights) + saved = {k: v for k, v in carry.items() if k not in ["bsw", "weights"]} + return new_state, saved + + _, saved_states = jax.lax.scan( + scan_body_fwd, + init_loop_state, + None, + length=model.config.num_pipeline_microbatches, + ) + + # Backward scan to accumulate gradients + def scan_body_bwd(carry, saved_slice): + d_next_state = carry + + # Reconstruct current loop_state (input to step) + curr_loop_state = { + **saved_slice, + "bsw": init_loop_state["bsw"], + "weights": init_loop_state["weights"], + } + + # Define function to differentiate w.r.t loop_state + def step_fn(s): + out = model.run_one_iteration( + s, + positions, + segment_ids, + deterministic, + model_mode, + model.layers, + logical_partition_spec=logical_partition_spec, + ) + return out + + _, vjp_fun = jax.vjp(step_fn, curr_loop_state) + + # Backprop d_next_state + (d_curr_state,) = vjp_fun(d_next_state) + + return d_curr_state, None + + # Run backward scan + d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, saved_states, reverse=True) + + return (d_init_state, None, None) + + run_one_repeat_scanned_custom.defvjp(run_one_repeat_scanned_custom_fwd, run_one_repeat_scanned_custom_bwd) + + loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids) else: for _ in range(model.config.num_pipeline_microbatches): loop_state, _ = run_iteration_scannable(model, loop_state) @@ -1026,6 +1079,114 @@ def run_one_repeat_scannable(model, loop_state): policy=self.get_pipeline_remat_policy(), ) + def run_bubbles_scannable(model, loop_state): + loop_state["bsw"] = model.bsw_all_gather_over_fsdp( + loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"] + ) + + if model.config.scan_pipeline_iterations: + run_one_repeat_scanned = nn.scan( + run_iteration_scannable, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + # Dropout/aqt keys will be split for each iteration. + split_rngs={"random": True}, + length=bubble_iterations, + ) + + @functools.partial(jax.custom_vjp) + def run_one_repeat_scanned_custom(loop_state, positions, segment_ids): + final_state, _ = run_one_repeat_scanned(model, loop_state) + return final_state + + def run_one_repeat_scanned_custom_fwd(loop_state, positions, segment_ids): + final_state, _ = run_one_repeat_scanned(model, loop_state) + # We return loop_state as residual. model is passed to bwd as arg. 
+ return final_state, ( + loop_state, + positions, + segment_ids, + ) + + def run_one_repeat_scanned_custom_bwd(residuals, g_final_state): + init_loop_state, positions, segment_ids = residuals + + # Re-run forward pass to get saved states (checkpointing) + def scan_body_fwd(carry, _): + new_state = model.run_one_iteration( + carry, + positions, + segment_ids, + deterministic, + model_mode, + model.layers, + logical_partition_spec=logical_partition_spec, + ) + # Return lightweight state for saving (exclude bsw/weights) + saved = {k: v for k, v in carry.items() if k not in ["bsw", "weights"]} + return new_state, saved + + _, saved_states = jax.lax.scan( + scan_body_fwd, + init_loop_state, + None, + length=bubble_iterations, + ) + + # Backward scan to accumulate gradients + def scan_body_bwd(carry, saved_slice): + d_next_state = carry + + # Reconstruct current loop_state (input to step) + curr_loop_state = { + **saved_slice, + "bsw": init_loop_state["bsw"], + "weights": init_loop_state["weights"], + } + + # Define function to differentiate w.r.t loop_state + def step_fn(s): + out = model.run_one_iteration( + s, + positions, + segment_ids, + deterministic, + model_mode, + model.layers, + logical_partition_spec=logical_partition_spec, + ) + return out + + _, vjp_fun = jax.vjp(step_fn, curr_loop_state) + + # Backprop d_next_state + (d_curr_state,) = vjp_fun(d_next_state) + + return d_curr_state, None + + # Run backward scan + d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, saved_states, reverse=True) + + return (d_init_state, None, None) + + run_one_repeat_scanned_custom.defvjp(run_one_repeat_scanned_custom_fwd, run_one_repeat_scanned_custom_bwd) + + loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids) + else: + for _ in range(model.config.num_pipeline_microbatches): + loop_state, _ = run_iteration_scannable(model, loop_state) + return loop_state, None + + run_bubbles_scannable = nn.remat( + run_bubbles_scannable, + prevent_cse=not self.config.scan_pipeline_iterations, + policy=self.get_pipeline_remat_policy(), + ) + def run_all_iterations(model, loop_state): if self.config.scan_pipeline_repeats: run_repeats_scanned = nn.scan( @@ -1036,27 +1197,22 @@ def run_all_iterations(model, loop_state): "intermediates": 0, "hyper_params": 0, }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, split_rngs={"random": True}, length=model.config.num_pipeline_repeats, ) run_bubbles_scanned = nn.scan( - run_iteration_scannable, + run_bubbles_scannable, variable_axes={ "summaries": 0, "aux_loss": 0, "intermediates": 0, "hyper_params": 0, }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, split_rngs={"random": True}, - length=bubble_iterations, + length=1, ) loop_state, _ = run_repeats_scanned(model, loop_state) - loop_state["bsw"] = model.bsw_all_gather_over_fsdp(physical_partition_spec, loop_state["loop_iteration"]) loop_state, _ = run_bubbles_scanned(model, loop_state) else: for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop @@ -1065,18 +1221,6 @@ def run_all_iterations(model, loop_state): loop_state, _ = run_iteration_scannable(model, loop_state) return loop_state - # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized. - # if self.config.scan_pipeline_iterations: - variable_carry = [] - variable_broadcast = [ - "params", - "_overwrite_with_gradient", - ] # All loop iterations need the weights for the full pipeline. 
- if self.is_mutable_collection("non_trainable"): - variable_carry.append("non_trainable") - else: - variable_broadcast.append("non_trainable") - loop_state = run_all_iterations(self, loop_state) # The final output is located in the input/output array, however the output microbatches may be permuted relative to diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 0674041620..1f9a0a0854 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -901,12 +901,6 @@ xprof_e2e_enable_fw_throttle_event: False xprof_e2e_enable_fw_power_level_event: False xprof_e2e_enable_fw_thermal_event: False -# TPU power trace level for xprof. 0:POWER_TRACE_NONE, 1:POWER_TRACE_NORMAL, or 2:POWER_TRACE_SPI -xprof_tpu_power_trace_level: 0 -xprof_e2e_enable_fw_throttle_event: False -xprof_e2e_enable_fw_power_level_event: False -xprof_e2e_enable_fw_thermal_event: False - log_config: False # Prints the config (after defaults have been set by pyconfig logic) debug_sharding: False # Prints model weights sharding info From 7e0155b034f0f0ae7b2f0c0ea81cab2b0adcd4d4 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Fri, 13 Feb 2026 22:32:52 +0000 Subject: [PATCH 7/7] enable pp with batch split ds --- src/MaxText/layers/deepseek_batchsplit.py | 2 +- .../configs/models/deepseek3-671b-2dfsdp.yml | 8 +- src/maxtext/configs/types.py | 138 +++++++++--------- 3 files changed, 75 insertions(+), 73 deletions(-) diff --git a/src/MaxText/layers/deepseek_batchsplit.py b/src/MaxText/layers/deepseek_batchsplit.py index 9d89f29e20..4af366d2fa 100644 --- a/src/MaxText/layers/deepseek_batchsplit.py +++ b/src/MaxText/layers/deepseek_batchsplit.py @@ -755,7 +755,7 @@ def gmm( input_buffer_count, combine_scopes, ): - if config.use_qwix_quantization: + if config.use_qwix_quantization or config.using_pipeline_parallelism: output = megablox.gmm( lhs=inputs, rhs=kernel, diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml index c68c813b01..c137d94c98 100644 --- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml +++ b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml @@ -56,19 +56,21 @@ rope_truncate: True rope_attention_scaling: False override_logical_axis_rules: True -mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context'] -data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']] +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context'] +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], - ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], + ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_norm_length', ['context']], ['activation_heads', []], + ['activation_stage', 'stage'], ['embed', ['fsdp']], ['embed_no_exp', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], + ['layers', 'stage'], ['q_lora_up_proj', ['fsdp_transpose', 'expert']], ['kv_lora_up_proj', ['fsdp_transpose', 'expert']], ['q_heads', ['fsdp_transpose', 'expert']], diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 43d3eded20..0924742f6c 100644 --- a/src/maxtext/configs/types.py +++ 
b/src/maxtext/configs/types.py @@ -2400,75 +2400,75 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility. - if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage": - self.ici_parallelism = [ - self.ici_diloco_parallelism, - self.ici_pipeline_parallelism, - self.ici_data_parallelism, - self.ici_fsdp_parallelism, - self.ici_fsdp_transpose_parallelism, - self.ici_sequence_parallelism, - self.ici_context_parallelism, - self.ici_context_autoregressive_parallelism, - self.ici_tensor_parallelism, - self.ici_tensor_transpose_parallelism, - self.ici_tensor_sequence_parallelism, - self.ici_expert_parallelism, - self.ici_autoregressive_parallelism, - ] - self.dcn_parallelism = [ - self.dcn_diloco_parallelism, - self.dcn_pipeline_parallelism, - self.dcn_data_parallelism, - self.dcn_fsdp_parallelism, - self.dcn_fsdp_transpose_parallelism, - self.dcn_sequence_parallelism, - self.dcn_context_parallelism, - self.dcn_context_autoregressive_parallelism, - self.dcn_tensor_parallelism, - self.dcn_tensor_transpose_parallelism, - self.dcn_tensor_sequence_parallelism, - self.dcn_expert_parallelism, - self.dcn_autoregressive_parallelism, - ] - else: - ici_map = { - "diloco": self.ici_diloco_parallelism, - "data": self.ici_data_parallelism, - "stage": self.ici_pipeline_parallelism, - "fsdp": self.ici_fsdp_parallelism, - "fsdp_transpose": self.ici_fsdp_transpose_parallelism, - "sequence": self.ici_sequence_parallelism, - "context": self.ici_context_parallelism, - "context_autoregressive": self.ici_context_autoregressive_parallelism, - "tensor": self.ici_tensor_parallelism, - "tensor_transpose": self.ici_tensor_transpose_parallelism, - "tensor_sequence": self.ici_tensor_sequence_parallelism, - "model": self.ici_tensor_parallelism, - "expert": self.ici_expert_parallelism, - "autoregressive": self.ici_autoregressive_parallelism, - "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads - } - self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes] - - dcn_map = { - "diloco": self.dcn_diloco_parallelism, - "data": self.dcn_data_parallelism, - "stage": self.dcn_pipeline_parallelism, - "fsdp": self.dcn_fsdp_parallelism, - "fsdp_transpose": self.dcn_fsdp_transpose_parallelism, - "sequence": self.dcn_sequence_parallelism, - "context": self.dcn_context_parallelism, - "context_autoregressive": self.dcn_context_autoregressive_parallelism, - "tensor": self.dcn_tensor_parallelism, - "tensor_transpose": self.dcn_tensor_transpose_parallelism, - "tensor_sequence": self.dcn_tensor_sequence_parallelism, - "model": self.dcn_tensor_parallelism, - "expert": self.dcn_expert_parallelism, - "autoregressive": self.dcn_autoregressive_parallelism, - "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads - } - self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes] + # if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage": + # self.ici_parallelism = [ + # self.ici_diloco_parallelism, + # self.ici_pipeline_parallelism, + # self.ici_data_parallelism, + # self.ici_fsdp_parallelism, + # self.ici_fsdp_transpose_parallelism, + # self.ici_sequence_parallelism, + # self.ici_context_parallelism, + # self.ici_context_autoregressive_parallelism, + # self.ici_tensor_parallelism, + # self.ici_tensor_transpose_parallelism, + # 
self.ici_tensor_sequence_parallelism, + # self.ici_expert_parallelism, + # self.ici_autoregressive_parallelism, + # ] + # self.dcn_parallelism = [ + # self.dcn_diloco_parallelism, + # self.dcn_pipeline_parallelism, + # self.dcn_data_parallelism, + # self.dcn_fsdp_parallelism, + # self.dcn_fsdp_transpose_parallelism, + # self.dcn_sequence_parallelism, + # self.dcn_context_parallelism, + # self.dcn_context_autoregressive_parallelism, + # self.dcn_tensor_parallelism, + # self.dcn_tensor_transpose_parallelism, + # self.dcn_tensor_sequence_parallelism, + # self.dcn_expert_parallelism, + # self.dcn_autoregressive_parallelism, + # ] + # else: + ici_map = { + "diloco": self.ici_diloco_parallelism, + "data": self.ici_data_parallelism, + "stage": self.ici_pipeline_parallelism, + "fsdp": self.ici_fsdp_parallelism, + "fsdp_transpose": self.ici_fsdp_transpose_parallelism, + "sequence": self.ici_sequence_parallelism, + "context": self.ici_context_parallelism, + "context_autoregressive": self.ici_context_autoregressive_parallelism, + "tensor": self.ici_tensor_parallelism, + "tensor_transpose": self.ici_tensor_transpose_parallelism, + "tensor_sequence": self.ici_tensor_sequence_parallelism, + "model": self.ici_tensor_parallelism, + "expert": self.ici_expert_parallelism, + "autoregressive": self.ici_autoregressive_parallelism, + "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads + } + self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes] + + dcn_map = { + "diloco": self.dcn_diloco_parallelism, + "data": self.dcn_data_parallelism, + "stage": self.dcn_pipeline_parallelism, + "fsdp": self.dcn_fsdp_parallelism, + "fsdp_transpose": self.dcn_fsdp_transpose_parallelism, + "sequence": self.dcn_sequence_parallelism, + "context": self.dcn_context_parallelism, + "context_autoregressive": self.dcn_context_autoregressive_parallelism, + "tensor": self.dcn_tensor_parallelism, + "tensor_transpose": self.dcn_tensor_transpose_parallelism, + "tensor_sequence": self.dcn_tensor_sequence_parallelism, + "model": self.dcn_tensor_parallelism, + "expert": self.dcn_expert_parallelism, + "autoregressive": self.dcn_autoregressive_parallelism, + "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads + } + self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes] # Diloco params self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)
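# The scanned custom-VJP introduced in pipeline.py above follows a recompute-on-backward
# pattern: the forward pass saves only the initial carry (plus the scan inputs) as
# residuals, and the backward pass first re-runs the forward scan to recover every
# per-step carry, then scans in reverse calling jax.vjp on one step at a time.
# Below is a minimal, self-contained sketch of that pattern with a toy step function
# and toy state; the names (`step`, `run_scanned`) and shapes are illustrative
# assumptions, not the pipeline code itself.
import jax
import jax.numpy as jnp

NUM_STEPS = 8

def step(state, w):
  # Stand-in for one pipeline iteration: `state` plays the role of loop_state,
  # `w` the role of the broadcast weights.
  return jnp.tanh(state @ w)

@jax.custom_vjp
def run_scanned(state, w):
  final, _ = jax.lax.scan(lambda s, _: (step(s, w), None), state, None, length=NUM_STEPS)
  return final

def run_scanned_fwd(state, w):
  final, _ = jax.lax.scan(lambda s, _: (step(s, w), None), state, None, length=NUM_STEPS)
  return final, (state, w)  # residuals: only the initial carry and the weights

def run_scanned_bwd(res, g_final):
  state0, w = res
  # (1) Recompute: re-run the forward scan, stacking the carry fed into each step.
  _, saved = jax.lax.scan(lambda s, _: (step(s, w), s), state0, None, length=NUM_STEPS)
  # (2) Reverse scan: backprop through one step at a time with jax.vjp.
  def bwd_body(g, s_in):
    _, vjp_fn = jax.vjp(step, s_in, w)
    g_s, g_w = vjp_fn(g)
    return g_s, g_w
  g_state0, g_w_steps = jax.lax.scan(bwd_body, g_final, saved, reverse=True)
  return g_state0, g_w_steps.sum(axis=0)  # accumulate the weight gradient over steps

run_scanned.defvjp(run_scanned_fwd, run_scanned_bwd)

# Usage: gradients flow through the hand-written backward pass.
state0, w = jnp.ones((4,)), 0.5 * jnp.eye(4)
grads = jax.grad(lambda s, ww: run_scanned(s, ww).sum(), argnums=(0, 1))(state0, w)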