diff --git a/src/MaxText/maxengine.py b/src/MaxText/maxengine.py index 403f34885b..d725aef9fd 100644 --- a/src/MaxText/maxengine.py +++ b/src/MaxText/maxengine.py @@ -479,8 +479,14 @@ def _prefill_jit( full_true_length = start_position + true_length - input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] - positions = jnp.expand_dims(jnp.arange(start_position, start_position + input_tokens.shape[1]), 0) + # [BHUMI PATCH] Support batched inputs + if padded_tokens.ndim == 1: + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] + else: + input_tokens = padded_tokens + + positions = jnp.arange(start_position, start_position + input_tokens.shape[1]) + positions = jnp.expand_dims(positions, 0) # [1, SEQUENCE] if self.config.use_multimodal and images is not None: if images.ndim == 3: @@ -494,10 +500,21 @@ def _prefill_jit( # sequence_indicator will be concatenated to existing_prefix decoder_segment_ids start_to_n = jnp.arange(start_position, start_position + input_tokens.shape[1]) - ones_to_keep = start_to_n < full_true_length + # Handle broadcasting for batched true_length + full_true_len_bc = full_true_length + start_to_n_bc = start_to_n + if hasattr(full_true_length, 'ndim') and full_true_length.ndim == 1: + full_true_len_bc = full_true_length[:, None] + start_to_n_bc = start_to_n[None, :] + + ones_to_keep = start_to_n_bc < full_true_len_bc one_d_output = ones_to_keep * DECODING_ACTIVE_SEQUENCE_INDICATOR - sequence_indicator = jnp.expand_dims(one_d_output, 0) - + + if one_d_output.ndim == 1: + sequence_indicator = jnp.expand_dims(one_d_output, 0) + else: + sequence_indicator = one_d_output + rng, new_rng = jax.random.split(rng) with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): flat_logits, new_vars = self.model.apply( @@ -522,12 +539,22 @@ def _prefill_jit( else: prompt_logp = None - generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32) - selected_logits = jax.lax.dynamic_slice( - flat_logits, - (0, true_length - 1, 0), - (flat_logits.shape[0], 1, flat_logits.shape[2]), - ) + # [BHUMI PATCH] Dynamic batch size + batch_size = flat_logits.shape[0] + generated_tokens = jnp.zeros((batch_size, 1), dtype=jnp.int32) + + if hasattr(true_length, 'ndim') and true_length.ndim == 1: + # [BHUMI PATCH] Batch gather + batch_indices = jnp.arange(batch_size) + seq_indices = true_length - 1 + selected_logits = flat_logits[batch_indices, seq_indices, :] + selected_logits = selected_logits[:, None, :] # [Batch, 1, Vocab] + else: + selected_logits = jax.lax.dynamic_slice( + flat_logits, + (0, true_length - 1, 0), + (flat_logits.shape[0], 1, flat_logits.shape[2]), + ) selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding) # sampling first token @@ -562,7 +589,12 @@ def _prefill_jit( cache = new_vars["cache"] cache = self._maybe_stack_prefill_result_cache(cache) - next_pos = jnp.full((1, 1), full_true_length, dtype=jnp.int32) + + # [BHUMI PATCH] Handle batched next_pos + if hasattr(full_true_length, 'ndim') and full_true_length.ndim == 1: + next_pos = full_true_length[:, None].astype(jnp.int32) + else: + next_pos = jnp.full((1, 1), full_true_length, dtype=jnp.int32) return { "logits": selected_logits, "cache": cache,