36 changes: 23 additions & 13 deletions src/MaxText/layers/multi_token_prediction.py
@@ -16,20 +16,21 @@
 
 from typing import Type
 
-from flax import linen as nn
-from flax import nnx
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
 
+from flax import linen as nn
+from flax import nnx
+
+from MaxText import sharding
 from MaxText.common_types import Config, MODEL_MODE_TRAIN
-from MaxText.layers.linears import DenseGeneral
-from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.decoders import DecoderLayer
-from MaxText.layers import nnx_wrappers
 from MaxText.globals import EPS
+from MaxText.layers import nnx_wrappers
+from MaxText.layers.decoders import DecoderLayer
+from MaxText.layers.initializers import variable_to_logically_partitioned
+from MaxText.layers.linears import DenseGeneral
+from MaxText.layers.normalizations import RMSNorm
 
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
 
@@ -84,24 +85,24 @@ def __init__(
     cfg = self.config
 
     self.embedding_norm = RMSNorm(
-        num_features=cfg.base_emb_dim,
+        num_features=cfg.emb_dim,
         epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         kernel_axes=("norm",),
         rngs=rngs,
     )
     self.hidden_state_norm = RMSNorm(
-        num_features=cfg.base_emb_dim,
+        num_features=cfg.emb_dim,
         epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         kernel_axes=("norm",),
         rngs=rngs,
     )
     self.projection_layer = DenseGeneral(
-        in_features_shape=2 * cfg.base_emb_dim,
-        out_features_shape=cfg.base_emb_dim,
+        in_features_shape=2 * cfg.emb_dim,
+        out_features_shape=cfg.emb_dim,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         use_bias=False,
@@ -118,10 +119,11 @@ def __init__(
     self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
 
     # ToNNX requires explicit initialization with sample inputs for proper parameter setup.
+    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
     self.transformer_layer.lazy_init(
-        inputs=jnp.zeros((1, 1, cfg.base_emb_dim), dtype=cfg.dtype),
+        inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
         decoder_segment_ids=None,
-        decoder_positions=jnp.zeros((1, 1), dtype=jnp.int32),
+        decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
         deterministic=True,
         model_mode=MODEL_MODE_TRAIN,
     )
@@ -149,6 +151,14 @@ def __call__(
     Returns:
       Processed hidden state. Shape [batch, seq_len, hidden_size].
     """
+    target_token_embedding = sharding.maybe_shard_with_logical(
+        target_token_embedding,
+        ("activation_batch", "activation_length", "activation_embed"),
+        self.mesh,
+        self.config.shard_mode,
+        self.config.logical_axis_rules,
+    )
+
     embedding_norm = self.embedding_norm(target_token_embedding)
     hidden_state_norm = self.hidden_state_norm(prev_hidden_state)
     concatenated_features = jnp.concatenate([embedding_norm, hidden_state_norm], axis=-1)
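The maybe_shard_with_logical call added to __call__ pins target_token_embedding to the (activation_batch, activation_length, activation_embed) logical axes before normalization, so the MTP block's activations follow the same sharding as the rest of the model. A minimal sketch of what such a helper typically reduces to, assuming a plain rules mapping (shard_with_logical, the rule tuples, and the mesh layout below are illustrative, not MaxText's implementation):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def shard_with_logical(x, logical_axes, mesh, rules):
  """Translate logical axis names into mesh axes and constrain x's sharding."""
  rule_map = dict(rules)
  spec = PartitionSpec(*(rule_map.get(name) for name in logical_axes))
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "model"))
rules = (
    ("activation_batch", "data"),
    ("activation_length", None),
    ("activation_embed", None),
)

@jax.jit
def f(x):
  return shard_with_logical(
      x, ("activation_batch", "activation_length", "activation_embed"), mesh, rules
  )

out = f(jnp.zeros((8, 16, 32)))  # batch dim constrained to the "data" mesh axis
```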
3 changes: 2 additions & 1 deletion tests/unit/multi_token_prediction_test.py
@@ -199,6 +199,7 @@ class MultiTokenPredictionBlockTest(unittest.TestCase):
   def setUp(self):
     super().setUp()
     # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode
+    num_devices = jax.device_count()
     extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {}
     self.cfg = pyconfig.initialize(
         [None, get_test_config_path()],
@@ -215,7 +216,7 @@ def setUp(self):
     self.mesh = Mesh(devices_array, self.cfg.mesh_axes)
     data_rng, self.init_rng = jax.random.split(self.rng)
 
-    self.batch_size, self.seq_len, self.embed_dim = 2, 8, self.cfg.base_emb_dim
+    self.batch_size, self.seq_len, self.embed_dim = num_devices, 8, self.cfg.base_emb_dim
     key1, key2, key3 = jax.random.split(data_rng, 3)
     self.main_hidden_state = jax.random.normal(key1, (self.batch_size, self.seq_len, self.embed_dim))
     self.input_ids = jax.random.randint(key2, (self.batch_size, self.seq_len), 0, self.cfg.vocab_size)
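On the test change: with ici_fsdp_parallelism set to jax.device_count(), the batch axis is sharded across every device, and the previous hard-coded batch of 2 fails whenever more than two devices are present, because a sharded dimension must divide evenly by the number of shards. A small standalone illustration of that constraint (the mesh axis name and tensor shapes are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("fsdp",))

# Batch dim equals the device count, so it splits evenly: one slice per device.
batch = jnp.zeros((num_devices, 8, 32))
sharded = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("fsdp")))
print(sharded.sharding)  # each device holds a (1, 8, 32) shard
# A fixed batch of 2 on >2 devices would raise: 2 is not divisible by num_devices.
```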