Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions src/MaxText/generate_param_only_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16.
"""

import functools
import os.path
from typing import Sequence

Expand All @@ -31,8 +32,8 @@
from jax.sharding import Mesh
from MaxText import optimizers
from MaxText import pyconfig
from maxtext.common import checkpointing
from MaxText.common_types import DecoderBlockType, MODEL_MODE_TRAIN
from maxtext.common import checkpointing
from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.utils import gcs_utils
Expand All @@ -41,8 +42,6 @@
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils

Transformer = models.transformer_as_linen


def _possibly_unroll_params(config, training_state, training_state_annotations, mesh):
"""Unroll scanned input layers when force_unroll is set."""
Expand Down Expand Up @@ -92,12 +91,20 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
"""Read training checkpoint at path defined by load_full_state_path."""
# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
rng = random.PRNGKey(0)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
if config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng)
state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state(
model, None, tx, config, rng, mesh, checkpoint_manager
None, config, mesh, checkpoint_manager, init_state_fn
)
num_params = max_utils.calculate_num_params_from_pytree(state.params)
max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion")
Expand All @@ -108,7 +115,10 @@ def _generate_lora_decode_checkpoints(config, mesh):
"""Read lora checkpoints checkpoint at path defined by load_full_state_path."""
# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
rng = random.PRNGKey(0)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
Expand Down
36 changes: 33 additions & 3 deletions src/MaxText/gradient_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import jax.numpy as jnp
from jax.sharding import NamedSharding

from flax import nnx

from MaxText.common_types import ShardMode
from MaxText.sharding import maybe_shard_with_name

Expand Down Expand Up @@ -49,7 +51,8 @@ def gradient_accumulation_loss_and_grad(
config: Model and training configuration object. Must contain
`gradient_accumulation_steps` and `shard_optimizer_over_data`.
model: The model module.
params: The model parameters (PyTree).
params: The model parameters (PyTree). This is only used for Linen. For NNX,
we can get the params from the model.
params_shardings: The sharding constraints for the parameters (PyTree).
data: A PyTree of batched data. The leading dimension is assumed
to be the total batch size (microbatch_size * num_accumulations).
Expand All @@ -67,12 +70,18 @@ def _maybe_shard_with_name(inputs, sharding_names):
"""Wrapper of maybe_shard_with_name with fixed shard_mode"""
return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding)

is_nnx = isinstance(model, nnx.Module)

# For more efficient DP/ZeRO-1 + GA
if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
else:
ga_params_shardings = grad_shardings = params_shardings

if is_nnx:
graphdef, params, rest = nnx.split(model, nnx.Param, ...)

# When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints
# so that all-gather is done once in the lower precision before the gradient accumulation loop
if config.shard_optimizer_over_data:
Expand All @@ -87,11 +96,27 @@ def convert_to_bf16(param):
ga_params = params

ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings)
grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)
if is_nnx:
grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True)
else:
grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)

def accumulate_gradient(acc_grad_and_loss, data):
ga_params = acc_grad_and_loss["ga_params"]
(_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True)
if is_nnx:
# Reconstruct the model using the fixed parameters (ga_params)
# and the advancing non-parameter state (RNGs) from the carry.
local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"])
(_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True)
_, _, next_rest_state = nnx.split(local_model, nnx.Param, ...)
acc_grad_and_loss["rest_state"] = next_rest_state
else:
rng = (
jax.random.fold_in(dropout_rng, acc_grad_and_loss["total_weights"].astype(jnp.int32))
if dropout_rng is not None
else None
)
(_, aux), cur_batch_gradient = grad_func(model, config, data, rng, ga_params, *extra_dpo_args, is_train=True)
acc_grad_and_loss["loss"] += aux["total_loss"]
acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]
Expand All @@ -117,6 +142,8 @@ def reshape_to_microbatch_accumulations(batch_arr):
"mtp_loss": 0.0,
"ga_params": ga_params,
}
if is_nnx:
init_grad_and_loss["rest_state"] = rest

grad_and_loss, aux = jax.lax.scan(
accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps
Expand All @@ -131,6 +158,9 @@ def reshape_to_microbatch_accumulations(batch_arr):
raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads)
aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr

if is_nnx:
nnx.update(model, grad_and_loss["rest_state"])

return loss, aux, raw_grads


Expand Down
20 changes: 14 additions & 6 deletions src/MaxText/layerwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

"""

import functools
import os
from typing import Any, Sequence

Expand Down Expand Up @@ -174,12 +175,19 @@ def __init__(self, config: Any, rng: PRNGKeyType):

# Model and quantization config
self.quant = quantizations.configure_quantization(config)
model = models.transformer_as_linen(
config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
)
self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(
model, None, self.config, self.rng, self._mesh, False
)
if self.config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = models.transformer_as_linen(
config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
)
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng)

self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False)

def load_and_quantize(self) -> None:
"""
Expand Down
79 changes: 76 additions & 3 deletions src/MaxText/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,20 @@
# pylint: disable=line-too-long, disable=bare-except, consider-using-generator
""" Utils that are only interesting to MaxText and sharding related. """

from flax import linen as nn

from collections.abc import Iterable

import jax
from jax.core import Tracer
from jax.sharding import PartitionSpec as P, NamedSharding, reshard

from flax import linen as nn, nnx
import optax

from MaxText.common_types import ShardMode
from MaxText import pyconfig
from maxtext.utils import max_logging
from maxtext.utils import max_utils


_LOGGED_ACTIVATION_SHARDINGS = set()
_LOGGED_LOGICAL_AXES = set()

Expand Down Expand Up @@ -419,6 +418,8 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings):
- updated_state_mesh_shardings: State mesh shardings with updated params field
(unchanged if shard_optimizer_over_data is False)
"""
if config.pure_nnx:
return maybe_update_params_sharding_with_opt_nnx(config, state_mesh_shardings)
prev_params_shardings = state_mesh_shardings.params
if config.shard_optimizer_over_data:
if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState):
Expand All @@ -437,6 +438,78 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings):
return prev_params_shardings, state_mesh_shardings


def maybe_update_params_sharding_with_opt_nnx(
    config: pyconfig.HyperParameters, state_mesh_shardings: nnx.State
) -> tuple[nnx.State, nnx.State]:
  """NNX version of parameter sharding update for Zero-1 style optimizer sharding.

  When ``config.shard_optimizer_over_data`` is enabled, this function extracts the
  sharding annotations of the Adam optimizer's first moment (``mu``) from the
  optimizer branch of ``state_mesh_shardings`` and merges them into the parameter
  shardings, so that parameter sharding is consistent with how the optimizer state
  is distributed across the compute mesh. When the flag is disabled, the state is
  returned unchanged.

  Args:
    config: Configuration with the ``shard_optimizer_over_data`` flag.
    state_mesh_shardings: The sharding state for a TrainStateNNX container.
      Expected to have a ``model`` branch (parameters) and an ``optimizer``
      branch (nnx.Optimizer state).

  Returns:
    A tuple of ``(prev_params_shardings, updated_state_mesh_shardings)``:
      - prev_params_shardings: Original parameter shardings before the update.
      - updated_state_mesh_shardings: State mesh shardings with the ``model``
        branch updated (unchanged if ``shard_optimizer_over_data`` is False).

  Raises:
    NotImplementedError: If ``shard_optimizer_over_data`` is set but no Adam-style
      optimizer state (``ScaleByAdamState`` or its flattened ``mu``/``nu`` form)
      can be found in the optimizer branch.
  """
  # In TrainStateNNX, parameters are under 'model'
  model_shardings = state_mesh_shardings.model
  # Keep only nnx.Param leaves; the discarded remainder (RNGs, batch stats, ...)
  # is not re-sharded by this function.
  _, prev_params_shardings, _ = nnx.split(model_shardings, nnx.Param, ...)

  if config.shard_optimizer_over_data:
    sharded_fp32_params = None
    # Check if the optimizer has any state at all (stateless optimizers like SGD omit this key)
    if "opt_state" in state_mesh_shardings.optimizer:
      # Access the optimizer branch to find the optax state
      # state_mesh_shardings.optimizer contains the sharding for the nnx.Optimizer
      opt_state = state_mesh_shardings.optimizer.opt_state

      def find_adam_mu(obj):
        """Depth-first search for the Adam first-moment (`mu`) shardings.

        Handles both the structured optax representation and the flattened
        dict-like form NNX may store; returns None when no Adam state exists
        anywhere in the (possibly nested) container.
        """
        # 1. Direct hit on ScaleByAdamState (Linen path or unflattened NNX)
        if isinstance(obj, optax.ScaleByAdamState):
          return obj.mu

        # 2. Check for flattened ScaleByAdamState (nnx.State/dict)
        # These nodes contain 'mu', 'nu', and 'count' as keys.
        if hasattr(obj, "__getitem__") and "mu" in obj and "nu" in obj:
          return obj["mu"]

        # 3. Recursive search through containers (nnx.State, dict, list, tuple)
        values = None
        if hasattr(obj, "values"):  # Handles nnx.State and dict
          values = obj.values()
        elif isinstance(obj, (list, tuple)):
          values = obj

        if values:
          # First match wins; chained optax transforms hold at most one
          # ScaleByAdamState — assumed here, TODO confirm for exotic optimizers.
          for v in values:
            res = find_adam_mu(v)
            if res is not None:
              return res
        return None

      sharded_fp32_params = find_adam_mu(opt_state)
    if sharded_fp32_params is None:
      # NOTE: when 'opt_state' is absent entirely, .get falls back to the string
      # "None", so the reported type is `str` — the message then just signals
      # "no optax state present" rather than naming a real container type.
      actual_type = type(state_mesh_shardings.optimizer.get("opt_state", "None"))
      raise NotImplementedError(f"Could not find Adam optimizer state in: {actual_type}")

    # In NNX, the mu structure matches the model state structure directly.
    # We create a new model sharding branch by merging existing with the mu shardings.
    # `mu` entries win on key collision (dict union is right-biased).
    updated_model_shardings = dict(prev_params_shardings) | dict(sharded_fp32_params)

    # Update the top-level state container
    state_mesh_shardings = nnx.State(dict(state_mesh_shardings) | {"model": updated_model_shardings})

  return prev_params_shardings, state_mesh_shardings


def logical_axis_rules_pp_act_as_dp(logical_rules):
"""Add stage as a physical axes before data for each rule, so stage acts just like data instead of PP.
This is used when we want to pipeline only a subset of layers, and leave the rest like DP.
Expand Down
Loading
Loading