diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py
index d82bc065ca..07b1a1ff02 100644
--- a/src/MaxText/layers/decoders.py
+++ b/src/MaxText/layers/decoders.py
@@ -741,12 +741,9 @@ def __call__(
         model_mode,
     )
     if cfg.using_pipeline_parallelism:
-      if cfg.pipeline_fsdp_ag_once:
-        logical_partition_spec = self.pipeline_module.get_weight_sharding(
-            y, decoder_segment_ids, decoder_positions, deterministic, model_mode
-        )
-      else:
-        logical_partition_spec = None  # This partition spec is only used for the fsdp_ag_once feature.
+      logical_partition_spec = self.pipeline_module.get_weight_sharding(
+          y, decoder_segment_ids, decoder_positions, deterministic, model_mode
+      )
     if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
       assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
       dense_layer = RemattedBlockLayers[0]
@@ -946,6 +943,13 @@ def __call__(
     else:
       logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)

+    logits = sharding.maybe_shard_with_logical(
+        logits,
+        ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
+        mesh=self.mesh,
+        shard_mode=self.config.shard_mode,
+        debug_sharding=self.config.debug_sharding,
+    )
     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
diff --git a/src/MaxText/layers/deepseek_batchsplit.py b/src/MaxText/layers/deepseek_batchsplit.py
index 9d89f29e20..4af366d2fa 100644
--- a/src/MaxText/layers/deepseek_batchsplit.py
+++ b/src/MaxText/layers/deepseek_batchsplit.py
@@ -755,7 +755,7 @@ def gmm(
       input_buffer_count,
       combine_scopes,
   ):
-    if config.use_qwix_quantization:
+    if config.use_qwix_quantization or config.using_pipeline_parallelism:
       output = megablox.gmm(
           lhs=inputs,
           rhs=kernel,
diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py
index 8e12df3bea..0d57b9846d 100644
--- a/src/MaxText/layers/pipeline.py
+++ b/src/MaxText/layers/pipeline.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-""" Pipeline layer wrapping a decoder layer(s). Supports circular pipelining """
+"""Pipeline layer wrapping a decoder layer(s). Supports circular pipelining"""

 import functools
 from typing import Any

 import numpy as np

@@ -23,12 +24,14 @@
 from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 import jax
 import jax.ad_checkpoint
+from jax._src.lax.parallel import all_gather_invariant

 from flax.core import meta
 from flax import linen as nn
 from flax.linen.spmd import LogicallyPartitioned

 from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
 from MaxText.sharding import (
     maybe_shard_with_logical,
     maybe_shard_with_name,
@@ -154,7 +157,9 @@ def init_states(self, inputs):
       state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed]
       circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None
       circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None
-      loop_iteration: scalar set initially to 0.
+      loop_iteration: scalar set initially to 0
+      bsw: pytree with the same structure as the weights, except that each leaf's leading num_repeats
+        dimension is replaced by 2, e.g. a leaf of shape [num_repeats, num_stages, mlp, embed] maps to
+        [2, num_stages, mlp, embed].
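+      weights: the pipeline weight pytree (self.layers.variables), carried through the loop state so
+        bsw can be rebuilt from it once per repeat.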
""" # Shift is used to rotate the output of each pipeline into the input of the next @@ -199,6 +204,20 @@ def init_states(self, inputs): else: circ_storage_mover = None + def _init_bsw_from_weights(variables): + """Buffer space for two copies of weights.""" + # take idx 0 slice assuming num_layers_per_pipeline_stage=1 + return ( + jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables), + jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables), + ) + + if self.is_initializing(): + bsw = None + else: + variables = self._remove_logically_partition(self.layers.variables) + bsw = _init_bsw_from_weights(variables) + init_loop_state = { "state_io": state_io, "shift": shift, @@ -206,6 +225,8 @@ def init_states(self, inputs): "circ_storage_mover": circ_storage_mover, "loop_iteration": 0, "prev_outputs": prev_outputs, + "bsw": bsw, + "weights": self.layers.variables, } return init_loop_state @@ -256,30 +277,6 @@ def select_state_or_input(first_stage_in, shift): stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical) return stages_in - def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False): - """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at - the specified dimension.""" - placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED - if physical_partition_spec is None: - dims_mapping = [placeholder] * x.ndim - else: - physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec) - dims_mapping = list(physical_partition_spec) - # If not a stage weight, we handle the repeat dimension offset - if not is_stage_weight: - dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:] # inflat one dimension for num_repeats - dims_mapping[dim] = "stage" - dims_mapping = tuple(dims_mapping) - # We add reduced rule only when pspec is given for a stage weight - if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT: - batch_mesh_axis = ["data", "fsdp"] - reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1] - pspec = P(*dims_mapping, reduced=set(reduced_mark)) - else: - pspec = P(*dims_mapping) - sharding = jax.sharding.NamedSharding(self.mesh, pspec) - return self._maybe_shard_with_name(x, sharding) - def get_microbatch_and_repeat_ids(self, loop_iteration): """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. 
Works for both circular and non-circular"""
@@ -311,18 +308,9 @@ def _gather_one(x, repeat_id):
       return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)

     gathered_weights_stage_dim = 0
-    repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
-    # num_repeats x num_stages x *param_dim
-    weights = self.shard_dim_by_stages(
-        weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
-    )
     stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(
         weights, repeat_ids
     )
-    # num_stages x *param_dim
-    stage_weights = self.shard_dim_by_stages(
-        stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True
-    )
     return stage_weights

   def vmap_gather(self, xs, ids, ids_dim):
@@ -346,9 +334,7 @@ def _gather_one(x, i):
       replicated_sharding = NamedSharding(self.mesh, P())
       return x.at[idx].get(out_sharding=replicated_sharding)

-    ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
-    outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
-    return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None)
+    return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)

   def get_new_loop_state(self, output, loop_state):
     """
@@ -452,6 +438,7 @@ def _shift_left(arr, stage_size, output):
         mesh=self.mesh,
         in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()),
         out_specs=self.state_io_spec,
+        check_vma=True,
     )
     def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
       # Shift the current slice to the left, then fill the last stage with the final output.
@@ -469,6 +456,8 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
         "circ_storage_mover": new_circ_storage_mover,
         "loop_iteration": loop_iteration + 1,
         "prev_outputs": new_prev_outputs,
+        "bsw": loop_state["bsw"],  # bsw is updated outside of this inner loop, only once per outer loop iteration
+        "weights": loop_state["weights"],  # Pass weights through
     }
     return new_loop_state

@@ -483,28 +472,52 @@ def permute_output_micro_per_stage_dim(self, output):
       output = output[:, permutation]
     return output

-  def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None):
+  def get_current_stage_weights(
+      self, pipeline_weights, bsw, loop_iteration, physical_partition_spec=None, is_initializing=None
+  ):
     """
     Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g.
     {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc.
     For non-circular pipelines, this simply returns all weights - every weight is used in every iteration. However for
     circular pipelines each stage grabs only the weights corresponding to the current repeat.
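+    When num_pipeline_repeats > 1, the per-stage weights are selected from the bsw double buffer (see
+    get_current_weights_from_bsw), which holds fsdp-all-gathered weights built at the current and next
+    loop iteration (see bsw_all_gather_over_fsdp).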
""" + pipeline_weights = self._remove_logically_partition(pipeline_weights) if self.config.num_pipeline_repeats > 1: - return self.get_current_repeat_from_stages( - pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec + pipeline_weights = self.get_current_weights_from_bsw( + bsw, loop_iteration, physical_partition_spec=physical_partition_spec, is_initializing=is_initializing ) - else: - return pipeline_weights + return pipeline_weights - def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): - """get current repeat from stages""" + def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None): + """Collect and gather weights from given bsw (buffer sliding window)""" + bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + target_repeat_id = repeat_ids[0] + + @jax.shard_map( + mesh=self.mesh, + in_specs=((bsw_pps, bsw_pps), P("stage")), + out_specs=(bsw_pps), + check_vma=True, + ) + def select_weights_from_bsw(bsw, repeat_id): + weights = jax.tree.map( + lambda x, y: jax.lax.select(repeat_id[0] == target_repeat_id, y, x), + bsw[0], + bsw[1], + ) + + return weights + + weights = select_weights_from_bsw(bsw, repeat_ids) + + if is_initializing is None: + is_initializing = self.is_initializing() circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), + "is_initializing": is_initializing, "x_times": self.config.num_pipeline_repeats, "optimizer_dims_mapping": None, } @@ -512,22 +525,91 @@ def get_current_repeat_from_stages(self, weights, loop_iteration, physical_parti weights, 0, circular_metadata_params ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular # entry per stage. - weights = self._remove_logically_partition(weights) - def gather_weights_for_stages_in(w, spec=None): + return weights + + @staticmethod + def get_fsdp_index_pytree(physical_partition_spec): + """ + Finds the index of 'fsdp' within each PartitionSpec in a Pytree. + + Args: + physical_partition_spec: A Pytree where leaves are PartitionSpecs. + + Returns: + A Pytree of the same structure where leaves are the integer index + of 'fsdp' or -1 if not found. 
+    """
+
+    def find_fsdp(pspec):
+      # Ensure we are handling a PartitionSpec or a tuple/list of strings
+      if pspec is None:
+        return -1
+
+      # PartitionSpecs are essentially tuples (e.g., PartitionSpec('data', 'fsdp'))
+      for i, axis in enumerate(pspec):
+        # Handle cases where an axis might be a tuple itself (e.g., ('fsdp', 'tensor'))
+        if isinstance(axis, (list, tuple)):
+          if "fsdp" in axis:
+            return i
+        elif axis == "fsdp":
+          return i
+      return -1
+
+    return jax.tree.map(find_fsdp, physical_partition_spec)
+
+  def from_all_variables_to_repeat_weights(self, weights, loop_iteration, physical_partition_spec):
+    """Gather the weights of a single repeat from the full set of variables."""
+    _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
+
+    def gather_weights_for_stages_in(w, spec):
       return self.vmap_parallel_gather(
-          w,
-          repeat_ids=repeat_ids,
-          repeat_dim_in_weights=0,
-          stages_dim_in_weights=1,
-          physical_partition_spec=spec,
+          w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
       )

+    weights = self._remove_logically_partition(weights)
     if physical_partition_spec is None:
       weights = jax.tree.map(gather_weights_for_stages_in, weights)
     else:
       weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
-    return weights
+    circular_metadata_params = {
+        nn.PARTITION_NAME: "circular_repeats",
+        "sub_weight_split_dims_mapping": (None,),
+        "is_initializing": self.is_initializing(),
+        "x_times": self.config.num_pipeline_repeats,
+        "optimizer_dims_mapping": None,
+    }
+    repeat_weights = meta.remove_axis(weights, 0, circular_metadata_params)
+    return repeat_weights
+
+  def from_all_variables_to_bsw(self, weights, loop_iteration, physical_partition_spec):
+    """All-gather one bsw buffer over the fsdp axis using shard_map."""
+    repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
+    bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
+    repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
+    fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)
+
+    @jax.shard_map(
+        mesh=self.mesh,
+        in_specs=(repeat_weights_pps, None),
+        out_specs=bsw_pps,
+        check_vma=True,
+    )
+    def _all_gather_inner(sharded_weights, fsdp_idx):
+      def _all_gather_invariant(x, i):
+        if i >= 0:
+          return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
+        return x
+
+      return jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
+
+    return _all_gather_inner(repeat_weights, fsdp_idx)
+
+  def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration):
+    """All-gather both bsw buffers over the fsdp mesh axis using shard_map."""
+    bsw_0 = self.from_all_variables_to_bsw(weights, loop_iteration, physical_partition_spec)
+    bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
+    return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")

   def get_vmap_func_for_init(self):
     """This vmap func is used to initialize the weights only on init."""
@@ -592,13 +674,14 @@ def func_to_vmap(
   def run_one_iteration(
       self,
       loop_state,
-      pipeline_weights,
       positions,
       segment_ids,
       deterministic,
       model_mode,
       decoder_layer_instance,
-      logical_partition_spec=None,
+      logical_partition_spec,
+      vmap_func=None,
+      is_initializing=None,
   ):
     """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the
     loop state."""
     state_io = loop_state["state_io"]
     shift = loop_state["shift"]
     circ_storage = loop_state["circ_storage"]
     loop_iteration = loop_state["loop_iteration"]
+    pipeline_weights = loop_state["weights"]

     microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration)

@@ -619,49 +703,15 @@
     stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None
     stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None

-    vmap_func = self.get_main_vmap_func_for_iterations()
-
-    if self.config.num_pipeline_repeats > 1:
-      _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
-
-      def prepare_vars_for_main_vmap(weights, physical_partition_spec=None):
-
-        circular_metadata_params = {
-            nn.PARTITION_NAME: "circular_repeats",
-            "sub_weight_split_dims_mapping": (None,),
-            "is_initializing": self.is_initializing(),
-            "x_times": self.config.num_pipeline_repeats,
-            "optimizer_dims_mapping": None,
-        }
-        weights = meta.remove_axis(
-            weights, 0, circular_metadata_params
-        )  # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one
-        # circular entry per stage.
-        weights = self._remove_logically_partition(weights)
-
-        def gather_weights_for_stages_in(w, spec=None):
-          return self.vmap_parallel_gather(
-              w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
-          )
-
-        if physical_partition_spec is None:
-          weights = jax.tree.map(gather_weights_for_stages_in, weights)
-        else:
-          weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
-        return weights
-
-      prepare_vars_for_main_vmap_partial = functools.partial(
-          prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec
-      )
-      vmap_func = nn.map_variables(
-          vmap_func,
-          mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"],
-          mutable=True,
-          trans_in_fn=prepare_vars_for_main_vmap_partial,
-      )
+    if vmap_func is None:
+      vmap_func = self.get_main_vmap_func_for_iterations()

     stage_weights = self.get_current_stage_weights(
-        pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec
+        pipeline_weights,
+        loop_state["bsw"],
+        loop_iteration,
+        physical_partition_spec=physical_partition_spec,
+        is_initializing=is_initializing,
     )

     stages_output = vmap_func(
@@ -761,17 +811,11 @@ def _remove_fsdp_from_physical_partition_spec(pps):
       return P(*new_spec)
     return pps

-  def all_gather_over_fsdp(self, variables, logical_partition_spec):
-    physical_partition_spec = logical_to_mesh(
-        logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
-    )
-    physical_partition_spec_no_fsdp = jax.tree.map(
-        self._remove_fsdp_from_physical_partition_spec, physical_partition_spec
-    )
+  def _generate_bsw_pps_from_pps(self, physical_partition_spec):
+    """Create bsw physical partition spec from weight physical partition spec."""
     return jax.tree.map(
-        lambda w, p: self._maybe_shard_with_name(w, NamedSharding(self.mesh, p)),
-        variables,
-        physical_partition_spec_no_fsdp,
+        lambda pps: P(*self._remove_fsdp_from_physical_partition_spec(pps)[1:]),
+        physical_partition_spec,
     )

   @nn.compact
@@ -825,6 +869,9 @@ def __call__(
       segment_idx = None

     loop_state = self.init_states(inputs)
+    physical_partition_spec = logical_to_mesh(
+        logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules
+    )

     # Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats)
     # compute to perform
@@ -836,8 +883,8 @@
     # Thus the total iterations is num_micro * repeat + num_stages - 1, & we may consider the num_stages - 1 as bubble.
     # The bubble doubles when we use forwarding delay.
     bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
-    real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats
-    total_iterations = real_iterations + bubble_iterations

     if self.is_initializing():
       vmap_func = self.get_vmap_func_for_init()
@@ -899,21 +946,14 @@ def __call__(
           out_sharding=self.output_sharding,
       )

-    if self.config.pipeline_fsdp_ag_once:
-      variables = self._remove_logically_partition(self.layers.variables)
-      all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec)
-    else:
-      all_pipeline_weights = self.layers.variables
-
     logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec)

-    def run_iteration_scannable(model, loop_state, xs):
+    def run_iteration_scannable(model, loop_state):
       # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we
       # explicitly wrap the run_one_iteration in this method - the 1st argument model (`self`) is a nn.module instance.
       return (
           model.run_one_iteration(
               loop_state,
-              all_pipeline_weights,
               positions,
               segment_ids,
               deterministic,
@@ -927,39 +967,261 @@
     if self.config.set_remat_policy_on_pipeline_iterations:
       run_iteration_scannable = nn.remat(
           run_iteration_scannable,
-          prevent_cse=not self.config.scan_pipeline_iterations,  # prevent_cse not used with scan
+          prevent_cse=not self.config.scan_pipeline_iterations,
          policy=self.get_pipeline_remat_policy(),
       )

-    # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized.
-    if self.config.scan_pipeline_iterations:
-      variable_carry = []
-      variable_broadcast = [
-          "params",
-          "_overwrite_with_gradient",
-      ]  # All loop iterations need the weights for the full pipeline.
-      if self.is_mutable_collection("non_trainable"):
-        variable_carry.append("non_trainable")
+    def run_one_repeat_scannable(model, loop_state):
+      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
+          loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
+      )
+
+      if model.config.scan_pipeline_iterations:
+        run_one_repeat_scanned = nn.scan(
+            run_iteration_scannable,
+            variable_axes={
+                "summaries": 0,
+                "aux_loss": 0,
+                "intermediates": 0,
+                "hyper_params": 0,
+            },
+            # Dropout/aqt keys will be split for each iteration.
+            split_rngs={"random": True},
+            length=model.config.num_pipeline_microbatches,
+        )
+
+        @jax.custom_vjp
+        def run_one_repeat_scanned_custom(loop_state, positions, segment_ids):
+          final_state, _ = run_one_repeat_scanned(model, loop_state)
+          return final_state
+
+        def run_one_repeat_scanned_custom_fwd(loop_state, positions, segment_ids):
+          final_state, _ = run_one_repeat_scanned(model, loop_state)
+          # Save the initial loop_state (plus positions/segment_ids) as residuals; `model` is closed
+          # over by the bwd function rather than passed as an argument.
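+          # (The per-microbatch activations are deliberately not saved: the bwd pass re-runs the
+          # forward scan and rebuilds them step by step under jax.vjp, i.e. rematerialization.)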
+ return final_state, ( + loop_state, + positions, + segment_ids, + ) + + def run_one_repeat_scanned_custom_bwd(residuals, g_final_state): + init_loop_state, positions, segment_ids = residuals + + # Re-run forward pass to get saved states (checkpointing) + def scan_body_fwd(carry, _): + new_state = model.run_one_iteration( + carry, + positions, + segment_ids, + deterministic, + model_mode, + model.layers, + logical_partition_spec=logical_partition_spec, + ) + # Return lightweight state for saving (exclude bsw/weights) + saved = {k: v for k, v in carry.items() if k not in ["bsw", "weights"]} + return new_state, saved + + _, saved_states = jax.lax.scan( + scan_body_fwd, + init_loop_state, + None, + length=model.config.num_pipeline_microbatches, + ) + + # Backward scan to accumulate gradients + def scan_body_bwd(carry, saved_slice): + d_next_state = carry + + # Reconstruct current loop_state (input to step) + curr_loop_state = { + **saved_slice, + "bsw": init_loop_state["bsw"], + "weights": init_loop_state["weights"], + } + + # Define function to differentiate w.r.t loop_state + def step_fn(s): + out = model.run_one_iteration( + s, + positions, + segment_ids, + deterministic, + model_mode, + model.layers, + logical_partition_spec=logical_partition_spec, + ) + return out + + _, vjp_fun = jax.vjp(step_fn, curr_loop_state) + + # Backprop d_next_state + (d_curr_state,) = vjp_fun(d_next_state) + + return d_curr_state, None + + # Run backward scan + d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, saved_states, reverse=True) + + return (d_init_state, None, None) + + run_one_repeat_scanned_custom.defvjp(run_one_repeat_scanned_custom_fwd, run_one_repeat_scanned_custom_bwd) + + loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids) else: - variable_broadcast.append("non_trainable") - run_all_iterations_scanned = nn.scan( - run_iteration_scannable, - variable_axes={ - "summaries": 0, - "aux_loss": 0, - "intermediates": 0, - "hyper_params": 0, - }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, - # Dropout/aqt keys will be split for each iteration. - split_rngs={"random": True}, - length=total_iterations, + for _ in range(model.config.num_pipeline_microbatches): + loop_state, _ = run_iteration_scannable(model, loop_state) + return loop_state, None + + run_one_repeat_scannable = nn.remat( + run_one_repeat_scannable, + prevent_cse=not self.config.scan_pipeline_iterations, + policy=self.get_pipeline_remat_policy(), + ) + + def run_bubbles_scannable(model, loop_state): + loop_state["bsw"] = model.bsw_all_gather_over_fsdp( + loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"] ) - loop_state, _ = run_all_iterations_scanned(self, loop_state, None) - else: - for _ in range(total_iterations): - loop_state, _ = run_iteration_scannable(self, loop_state, None) + + if model.config.scan_pipeline_iterations: + run_one_repeat_scanned = nn.scan( + run_iteration_scannable, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + # Dropout/aqt keys will be split for each iteration. 
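+            # Unlike the removed run_all_iterations_scanned, no variable_broadcast/variable_carry is
+            # passed here: the weights now travel through loop_state rather than flax collections.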
+            split_rngs={"random": True},
+            length=bubble_iterations,
+        )
+
+        @jax.custom_vjp
+        def run_one_repeat_scanned_custom(loop_state, positions, segment_ids):
+          final_state, _ = run_one_repeat_scanned(model, loop_state)
+          return final_state
+
+        def run_one_repeat_scanned_custom_fwd(loop_state, positions, segment_ids):
+          final_state, _ = run_one_repeat_scanned(model, loop_state)
+          # Save the initial loop_state (plus positions/segment_ids) as residuals; `model` is closed
+          # over by the bwd function rather than passed as an argument.
+          return final_state, (
+              loop_state,
+              positions,
+              segment_ids,
+          )
+
+        def run_one_repeat_scanned_custom_bwd(residuals, g_final_state):
+          init_loop_state, positions, segment_ids = residuals
+
+          # Re-run forward pass to get saved states (checkpointing)
+          def scan_body_fwd(carry, _):
+            new_state = model.run_one_iteration(
+                carry,
+                positions,
+                segment_ids,
+                deterministic,
+                model_mode,
+                model.layers,
+                logical_partition_spec=logical_partition_spec,
+            )
+            # Return lightweight state for saving (exclude bsw/weights)
+            saved = {k: v for k, v in carry.items() if k not in ["bsw", "weights"]}
+            return new_state, saved
+
+          _, saved_states = jax.lax.scan(
+              scan_body_fwd,
+              init_loop_state,
+              None,
+              length=bubble_iterations,
+          )
+
+          # Backward scan to accumulate gradients
+          def scan_body_bwd(carry, saved_slice):
+            d_next_state = carry
+
+            # Reconstruct current loop_state (input to step)
+            curr_loop_state = {
+                **saved_slice,
+                "bsw": init_loop_state["bsw"],
+                "weights": init_loop_state["weights"],
+            }
+
+            # Define function to differentiate w.r.t loop_state
+            def step_fn(s):
+              out = model.run_one_iteration(
+                  s,
+                  positions,
+                  segment_ids,
+                  deterministic,
+                  model_mode,
+                  model.layers,
+                  logical_partition_spec=logical_partition_spec,
+              )
+              return out
+
+            _, vjp_fun = jax.vjp(step_fn, curr_loop_state)
+
+            # Backprop d_next_state
+            (d_curr_state,) = vjp_fun(d_next_state)
+
+            return d_curr_state, None
+
+          # Run backward scan
+          d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, saved_states, reverse=True)
+
+          return (d_init_state, None, None)
+
+        run_one_repeat_scanned_custom.defvjp(run_one_repeat_scanned_custom_fwd, run_one_repeat_scanned_custom_bwd)
+
+        loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
+      else:
+        for _ in range(bubble_iterations):
+          loop_state, _ = run_iteration_scannable(model, loop_state)
+      return loop_state, None
+
+    run_bubbles_scannable = nn.remat(
+        run_bubbles_scannable,
+        prevent_cse=not self.config.scan_pipeline_iterations,
+        policy=self.get_pipeline_remat_policy(),
+    )
+
+    def run_all_iterations(model, loop_state):
+      if model.config.scan_pipeline_repeats:
+        run_repeats_scanned = nn.scan(
+            run_one_repeat_scannable,
+            variable_axes={
+                "summaries": 0,
+                "aux_loss": 0,
+                "intermediates": 0,
+                "hyper_params": 0,
+            },
+            split_rngs={"random": True},
+            length=model.config.num_pipeline_repeats,
+        )
+
+        run_bubbles_scanned = nn.scan(
+            run_bubbles_scannable,
+            variable_axes={
+                "summaries": 0,
+                "aux_loss": 0,
+                "intermediates": 0,
+                "hyper_params": 0,
+            },
+            split_rngs={"random": True},
+            length=1,
+        )
+        loop_state, _ = run_repeats_scanned(model, loop_state)
+        loop_state, _ = run_bubbles_scanned(model, loop_state)
+      else:
+        for _ in range(model.config.num_pipeline_repeats):  # remat and scan outer loop
+          loop_state, _ = run_one_repeat_scannable(model, loop_state)
+        for _ in range(bubble_iterations):
+          loop_state, _ = run_iteration_scannable(model, loop_state)
+      return loop_state
+
+    loop_state = run_all_iterations(self, loop_state)

     # The final output is located in the input/output array, however the output microbatches may be permuted relative to
     # the input
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 632f3345b7..1f9a0a0854 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -290,6 +290,7 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
 # It may be useful to do the reverse when the layers_per_stage is very large.
 # The below settings only have effect when using pipeline parallelism.
 scan_pipeline_iterations: True
+scan_pipeline_repeats: True
 scan_layers_per_stage: False
 set_remat_policy_on_pipeline_iterations: True
 set_remat_policy_on_layers_per_stage: False
@@ -900,7 +901,7 @@ xprof_e2e_enable_fw_throttle_event: False
 xprof_e2e_enable_fw_power_level_event: False
 xprof_e2e_enable_fw_thermal_event: False

-log_config: True # Prints the config (after defaults have been set by pyconfig logic)
+log_config: False # Prints the config (after defaults have been set by pyconfig logic)
 debug_sharding: False # Prints model weights sharding info

 # Checkpoint Structured logging
diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
index c68c813b01..c137d94c98 100644
--- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
+++ b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
@@ -56,19 +56,21 @@ rope_truncate: True
 rope_attention_scaling: False

 override_logical_axis_rules: True
-mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
-data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
-  ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_norm_length', ['context']],
   ['activation_heads', []],
+  ['activation_stage', 'stage'],
   ['embed', ['fsdp']],
   ['embed_no_exp', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
+  ['layers', 'stage'],
   ['q_lora_up_proj', ['fsdp_transpose', 'expert']],
   ['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
   ['q_heads', ['fsdp_transpose', 'expert']],
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index e651293a19..0924742f6c 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -842,6 +842,7 @@ class PipelineParallelism(BaseModel):
   )
   pipeline_fsdp_ag_once: bool = Field(False, description="If True, all-gather FSDP weights once per pipeline repeat.")
   scan_pipeline_iterations: bool = Field(True, description="Use jax.lax.scan over pipeline iterations.")
+  scan_pipeline_repeats: bool = Field(True, description="Use jax.lax.scan over pipeline repeats.")
   scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
   set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
   set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")
@@ -2399,75 +2400,75 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de

     # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
     # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
-    if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
-      self.ici_parallelism = [
-          self.ici_diloco_parallelism,
-          self.ici_pipeline_parallelism,
-          self.ici_data_parallelism,
-          self.ici_fsdp_parallelism,
-          self.ici_fsdp_transpose_parallelism,
-          self.ici_sequence_parallelism,
-          self.ici_context_parallelism,
-          self.ici_context_autoregressive_parallelism,
-          self.ici_tensor_parallelism,
-          self.ici_tensor_transpose_parallelism,
-          self.ici_tensor_sequence_parallelism,
-          self.ici_expert_parallelism,
-          self.ici_autoregressive_parallelism,
-      ]
-      self.dcn_parallelism = [
-          self.dcn_diloco_parallelism,
-          self.dcn_pipeline_parallelism,
-          self.dcn_data_parallelism,
-          self.dcn_fsdp_parallelism,
-          self.dcn_fsdp_transpose_parallelism,
-          self.dcn_sequence_parallelism,
-          self.dcn_context_parallelism,
-          self.dcn_context_autoregressive_parallelism,
-          self.dcn_tensor_parallelism,
-          self.dcn_tensor_transpose_parallelism,
-          self.dcn_tensor_sequence_parallelism,
-          self.dcn_expert_parallelism,
-          self.dcn_autoregressive_parallelism,
-      ]
-    else:
-      ici_map = {
-          "diloco": self.ici_diloco_parallelism,
-          "data": self.ici_data_parallelism,
-          "stage": self.ici_pipeline_parallelism,
-          "fsdp": self.ici_fsdp_parallelism,
-          "fsdp_transpose": self.ici_fsdp_transpose_parallelism,
-          "sequence": self.ici_sequence_parallelism,
-          "context": self.ici_context_parallelism,
-          "context_autoregressive": self.ici_context_autoregressive_parallelism,
-          "tensor": self.ici_tensor_parallelism,
-          "tensor_transpose": self.ici_tensor_transpose_parallelism,
-          "tensor_sequence": self.ici_tensor_sequence_parallelism,
-          "model": self.ici_tensor_parallelism,
-          "expert": self.ici_expert_parallelism,
-          "autoregressive": self.ici_autoregressive_parallelism,
-          "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
-      }
-      self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
-
-      dcn_map = {
-          "diloco": self.dcn_diloco_parallelism,
-          "data": self.dcn_data_parallelism,
-          "stage": self.dcn_pipeline_parallelism,
-          "fsdp": self.dcn_fsdp_parallelism,
-          "fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
-          "sequence": self.dcn_sequence_parallelism,
-          "context": self.dcn_context_parallelism,
-          "context_autoregressive": self.dcn_context_autoregressive_parallelism,
-          "tensor": self.dcn_tensor_parallelism,
-          "tensor_transpose": self.dcn_tensor_transpose_parallelism,
-          "tensor_sequence": self.dcn_tensor_sequence_parallelism,
-          "model": self.dcn_tensor_parallelism,
-          "expert": self.dcn_expert_parallelism,
-          "autoregressive": self.dcn_autoregressive_parallelism,
-          "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
-      }
-      self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
+    ici_map = {
+        "diloco": self.ici_diloco_parallelism,
+        "data": self.ici_data_parallelism,
+        "stage": self.ici_pipeline_parallelism,
+        "fsdp": self.ici_fsdp_parallelism,
+        "fsdp_transpose": self.ici_fsdp_transpose_parallelism,
+        "sequence": self.ici_sequence_parallelism,
+        "context": self.ici_context_parallelism,
+        "context_autoregressive": self.ici_context_autoregressive_parallelism,
+        "tensor": self.ici_tensor_parallelism,
+        "tensor_transpose": self.ici_tensor_transpose_parallelism,
+        "tensor_sequence": self.ici_tensor_sequence_parallelism,
+        "model": self.ici_tensor_parallelism,
+        "expert": self.ici_expert_parallelism,
+        "autoregressive": self.ici_autoregressive_parallelism,
+        "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
+    }
+    self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
+
+    dcn_map = {
+        "diloco": self.dcn_diloco_parallelism,
+        "data": self.dcn_data_parallelism,
+        "stage": self.dcn_pipeline_parallelism,
+        "fsdp": self.dcn_fsdp_parallelism,
+        "fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
+        "sequence": self.dcn_sequence_parallelism,
+        "context": self.dcn_context_parallelism,
+        "context_autoregressive": self.dcn_context_autoregressive_parallelism,
+        "tensor": self.dcn_tensor_parallelism,
+        "tensor_transpose": self.dcn_tensor_transpose_parallelism,
+        "tensor_sequence": self.dcn_tensor_sequence_parallelism,
+        "model": self.dcn_tensor_parallelism,
+        "expert": self.dcn_expert_parallelism,
+        "autoregressive": self.dcn_autoregressive_parallelism,
+        "attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
+    }
+    self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]

     # Diloco params
     self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)
diff --git a/tests/unit/pipeline_parallelism_test.py b/tests/unit/pipeline_parallelism_test.py
index 98e49d5050..816eeb83e3 100644
--- a/tests/unit/pipeline_parallelism_test.py
+++ b/tests/unit/pipeline_parallelism_test.py
@@ -213,7 +213,7 @@ def test_circular_minimum_microbatches_same_output_and_grad(self):
         run_name="circular_minimum_microbatches",
         max_target_length=128,
         base_emb_dim=28,
-        ici_pipeline_parallelism=4,
+        ici_pipeline_parallelism=2,
         base_num_decoder_layers=8,
         num_pipeline_microbatches=4,
         per_device_batch_size=4,
@@ -230,7 +230,7 @@ def test_circular_extra_microbatches_same_output_and_grad(self):
         run_name="circular_extra_microbatches",
         max_target_length=128,
         base_emb_dim=28,
-        ici_pipeline_parallelism=4,
+        ici_pipeline_parallelism=2,
         base_num_decoder_layers=8,
         num_pipeline_microbatches=8,
         per_device_batch_size=4,
@@ -247,7 +247,7 @@ def test_circular_deepseek_megablox_same_output_and_grad(self):
         run_name="circular_moe",
         max_target_length=128,
         base_emb_dim=28,
-        ici_pipeline_parallelism=4,
+        ici_pipeline_parallelism=2,
         base_num_decoder_layers=8,
         num_pipeline_microbatches=8,
per_device_batch_size=4, @@ -287,7 +287,7 @@ def test_non_circular_same_output_and_grad(self): run_name="non_circular", max_target_length=128, base_emb_dim=28, - ici_pipeline_parallelism=4, + ici_pipeline_parallelism=2, base_num_decoder_layers=4, num_pipeline_microbatches=4, per_device_batch_size=4, @@ -336,7 +336,7 @@ def test_delay_activation_forwarding_same_output_and_grad(self): run_name="activation_forwarding", max_target_length=128, base_emb_dim=28, - ici_pipeline_parallelism=4, + ici_pipeline_parallelism=2, base_num_decoder_layers=8, num_pipeline_microbatches=8, per_device_batch_size=4,
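
Editor's note: the custom-VJP scans added to pipeline.py above follow a general JAX pattern: the
forward scan saves only each step's lightweight carry, and the backward pass recomputes each step
under jax.vjp in a reverse scan, trading recompute for activation memory. The toy sketch below is
not MaxText code (names like `step` and `scanned` are illustrative); it shows the pattern in
isolation under those assumptions:

import jax
import jax.numpy as jnp

def step(carry, w):
  # One toy "pipeline iteration": carry is the activation state, w this step's weights.
  return jnp.tanh(carry @ w)

@jax.custom_vjp
def scanned(carry, ws):
  final, _ = jax.lax.scan(lambda c, w: (step(c, w), None), carry, ws)
  return final

def scanned_fwd(carry, ws):
  # Save only the input carry of each step (cheap), not the intermediate activations.
  final, saved = jax.lax.scan(lambda c, w: (step(c, w), c), carry, ws)
  return final, (saved, ws)

def scanned_bwd(res, g):
  saved, ws = res

  def body(dc, xs):
    c, w = xs
    # Recompute the step, then backprop the incoming cotangent through it.
    _, vjp_fun = jax.vjp(step, c, w)
    dc_prev, dw = vjp_fun(dc)
    return dc_prev, dw

  # reverse=True walks the saved carries from the last step back to the first.
  dcarry, dws = jax.lax.scan(body, g, (saved, ws), reverse=True)
  return dcarry, dws

scanned.defvjp(scanned_fwd, scanned_bwd)

x = jnp.ones((4,))
ws = jnp.stack([0.5 * jnp.eye(4)] * 3)
print(jax.grad(lambda c: scanned(c, ws).sum())(x))  # gradient w.r.t. the initial carry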