diff --git a/src/MaxText/layers/multi_token_prediction.py b/src/MaxText/layers/multi_token_prediction.py
index 7a407721e4..f6c25a1e70 100644
--- a/src/MaxText/layers/multi_token_prediction.py
+++ b/src/MaxText/layers/multi_token_prediction.py
@@ -16,20 +16,21 @@
 from typing import Type
 
+from flax import linen as nn
+from flax import nnx
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
 
-from flax import linen as nn
-from flax import nnx
-
+from MaxText import sharding
 from MaxText.common_types import Config, MODEL_MODE_TRAIN
-from MaxText.layers.linears import DenseGeneral
-from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.decoders import DecoderLayer
-from MaxText.layers import nnx_wrappers
 from MaxText.globals import EPS
+from MaxText.layers import nnx_wrappers
+from MaxText.layers.decoders import DecoderLayer
 from MaxText.layers.initializers import variable_to_logically_partitioned
+from MaxText.layers.linears import DenseGeneral
+from MaxText.layers.normalizations import RMSNorm
+
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
@@ -84,7 +85,7 @@ def __init__(
     cfg = self.config
 
     self.embedding_norm = RMSNorm(
-        num_features=cfg.base_emb_dim,
+        num_features=cfg.emb_dim,
         epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
@@ -92,7 +93,7 @@ def __init__(
         rngs=rngs,
     )
     self.hidden_state_norm = RMSNorm(
-        num_features=cfg.base_emb_dim,
+        num_features=cfg.emb_dim,
         epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
@@ -100,8 +101,8 @@ def __init__(
         rngs=rngs,
     )
     self.projection_layer = DenseGeneral(
-        in_features_shape=2 * cfg.base_emb_dim,
-        out_features_shape=cfg.base_emb_dim,
+        in_features_shape=2 * cfg.emb_dim,
+        out_features_shape=cfg.emb_dim,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         use_bias=False,
@@ -118,10 +119,11 @@ def __init__(
     self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
 
     # ToNNX requires explicit initialization with sample inputs for proper parameter setup.
+    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
     self.transformer_layer.lazy_init(
-        inputs=jnp.zeros((1, 1, cfg.base_emb_dim), dtype=cfg.dtype),
+        inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
         decoder_segment_ids=None,
-        decoder_positions=jnp.zeros((1, 1), dtype=jnp.int32),
+        decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
         deterministic=True,
         model_mode=MODEL_MODE_TRAIN,
     )
@@ -149,6 +151,14 @@ def __call__(
     Returns:
       Processed hidden state. Shape [batch, seq_len, hidden_size].
     """
+    target_token_embedding = sharding.maybe_shard_with_logical(
+        target_token_embedding,
+        ("activation_batch", "activation_length", "activation_embed"),
+        self.mesh,
+        self.config.shard_mode,
+        self.config.logical_axis_rules,
+    )
+
     embedding_norm = self.embedding_norm(target_token_embedding)
     hidden_state_norm = self.hidden_state_norm(prev_hidden_state)
     concatenated_features = jnp.concatenate([embedding_norm, hidden_state_norm], axis=-1)
diff --git a/tests/unit/multi_token_prediction_test.py b/tests/unit/multi_token_prediction_test.py
index 9f62504918..e8b5c0e1cf 100644
--- a/tests/unit/multi_token_prediction_test.py
+++ b/tests/unit/multi_token_prediction_test.py
@@ -199,6 +199,7 @@ class MultiTokenPredictionBlockTest(unittest.TestCase):
   def setUp(self):
     super().setUp()
     # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode
+    num_devices = jax.device_count()
     extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {}
     self.cfg = pyconfig.initialize(
         [None, get_test_config_path()],
@@ -215,7 +216,7 @@ def setUp(self):
     self.mesh = Mesh(devices_array, self.cfg.mesh_axes)
 
     data_rng, self.init_rng = jax.random.split(self.rng)
-    self.batch_size, self.seq_len, self.embed_dim = 2, 8, self.cfg.base_emb_dim
+    self.batch_size, self.seq_len, self.embed_dim = num_devices, 8, self.cfg.base_emb_dim
     key1, key2, key3 = jax.random.split(data_rng, 3)
     self.main_hidden_state = jax.random.normal(key1, (self.batch_size, self.seq_len, self.embed_dim))
     self.input_ids = jax.random.randint(key2, (self.batch_size, self.seq_len), 0, self.cfg.vocab_size)