Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions src/MaxText/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,14 @@ def _prefill_jit(

full_true_length = start_position + true_length

input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE]
positions = jnp.expand_dims(jnp.arange(start_position, start_position + input_tokens.shape[1]), 0)
# [BHUMI PATCH] Support batched inputs: add a leading batch axis only when padded_tokens is 1-D; otherwise use padded_tokens unchanged.
if padded_tokens.ndim == 1:
input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE]
else:
input_tokens = padded_tokens

positions = jnp.arange(start_position, start_position + input_tokens.shape[1])
positions = jnp.expand_dims(positions, 0) # [1, SEQUENCE]

if self.config.use_multimodal and images is not None:
if images.ndim == 3:
Expand All @@ -494,10 +500,21 @@ def _prefill_jit(

# sequence_indicator will be concatenated to existing_prefix decoder_segment_ids
start_to_n = jnp.arange(start_position, start_position + input_tokens.shape[1])
ones_to_keep = start_to_n < full_true_length
# When full_true_length is 1-D (one length per batch row), broadcast it against the positions so the mask comes out as [BATCH, SEQUENCE].
full_true_len_bc = full_true_length
start_to_n_bc = start_to_n
if hasattr(full_true_length, 'ndim') and full_true_length.ndim == 1:
full_true_len_bc = full_true_length[:, None]
start_to_n_bc = start_to_n[None, :]

ones_to_keep = start_to_n_bc < full_true_len_bc
one_d_output = ones_to_keep * DECODING_ACTIVE_SEQUENCE_INDICATOR
sequence_indicator = jnp.expand_dims(one_d_output, 0)


if one_d_output.ndim == 1:
sequence_indicator = jnp.expand_dims(one_d_output, 0)
else:
sequence_indicator = one_d_output

rng, new_rng = jax.random.split(rng)
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
flat_logits, new_vars = self.model.apply(
Expand All @@ -522,12 +539,22 @@ def _prefill_jit(
else:
prompt_logp = None

generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
selected_logits = jax.lax.dynamic_slice(
flat_logits,
(0, true_length - 1, 0),
(flat_logits.shape[0], 1, flat_logits.shape[2]),
)
# [BHUMI PATCH] Dynamic batch size
batch_size = flat_logits.shape[0]
generated_tokens = jnp.zeros((batch_size, 1), dtype=jnp.int32)

if hasattr(true_length, 'ndim') and true_length.ndim == 1:
# [BHUMI PATCH] Batch gather: for each batch row, take the logits at that row's own index true_length - 1.
batch_indices = jnp.arange(batch_size)
seq_indices = true_length - 1
selected_logits = flat_logits[batch_indices, seq_indices, :]
selected_logits = selected_logits[:, None, :] # [Batch, 1, Vocab]
else:
selected_logits = jax.lax.dynamic_slice(
flat_logits,
(0, true_length - 1, 0),
(flat_logits.shape[0], 1, flat_logits.shape[2]),
)
selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding)

# sampling first token
Expand Down Expand Up @@ -562,7 +589,12 @@ def _prefill_jit(

cache = new_vars["cache"]
cache = self._maybe_stack_prefill_result_cache(cache)
next_pos = jnp.full((1, 1), full_true_length, dtype=jnp.int32)

# [BHUMI PATCH] Handle batched next_pos: a 1-D full_true_length yields one row per batch entry; a scalar keeps the original (1, 1) shape.
if hasattr(full_true_length, 'ndim') and full_true_length.ndim == 1:
next_pos = full_true_length[:, None].astype(jnp.int32)
else:
next_pos = jnp.full((1, 1), full_true_length, dtype=jnp.int32)
return {
"logits": selected_logits,
"cache": cache,
Expand Down