133 changes: 128 additions & 5 deletions src/MaxText/layers/attention_mla.py
@@ -21,6 +21,8 @@
import jax
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import layout
from jax.sharding import PartitionSpec as P
from jax.experimental import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding

@@ -619,7 +621,11 @@ def __init__(
)

# Module attribute names must match names previously passed to Linen for checkpointing
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
self.MlaKVCache_0 = (
self.init_mla_kv_caches(inputs_kv_shape)
if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa"
else None
)
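# Note: with attention == "vllm_rpa" the KV cache is allocated and passed in by the vLLM runtime
# at call time, so no MaxText-side MLA cache module is created here.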

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
"""Initializes the MLA-specific projections."""
@@ -937,15 +943,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm

key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
cached_values = [None, None]
if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
if self.config.mla_naive_kvcache:
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
else:
cached_values = self.update_mla_kv_caches(
low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
)

return key, value, cached_values
return key, value, cached_values, low_rank_main, key_rope

def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata):
"""Forward function for vLLM serving with MLA attention.

Args:
q_nope: Query nope part [T, N, qk_nope_head_dim]
q_rope: Query rope part [T, N, qk_rope_head_dim]
k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope)
k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension)
mla_kv_cache: The KV cache
mla_metadata: Attention metadata
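
Returns:
A tuple of (updated mla_kv_cache, attention output of shape [T, N, v_head_dim]).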
"""
md = mla_metadata
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes
except ImportError as e:
raise ImportError(
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
) from e

if mla_kv_cache is None or mla_metadata is None:
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")

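# MLA weight absorption: split the combined kv_b projection into its k_nope and v parts, then
# fold wk_b into the query so attention scores can be computed directly against the compressed
# latent KV (q_absorbed has shape [T, N, kv_lora_rank]).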
wkv_b_kernel = self.wkv_b.kernel.value
wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)

def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
def _initialize_block_sizes():
# Derive block sizes from the tuned table, then cap them conservatively (TODO(gpolovets): revisit the caps once tuned sizes are validated for MLA).
max_num_tokens = q_absorbed.shape[0]
max_num_seqs = md.seq_lens.shape[0]
num_page_indices = md.block_tables.shape[0]
assert num_page_indices % max_num_seqs == 0
pages_per_seq = num_page_indices // max_num_seqs
# num_kv_pages_per_block = min(pages_per_seq, 16)
bkv_p, bq_sz = get_tuned_block_sizes(
q_nope.dtype,
q_nope.dtype, # changed to q_nope dtype from mla_kv_cache.dtype
self.num_query_heads,
1, # num_kv_heads for MLA kernel
self.qk_nope_head_dim,
q_nope.shape[1], # page size ?? kv_cache.shape[1]
max_num_tokens,
pages_per_seq,
)
num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
num_queries_per_block = min(max_num_tokens, bq_sz, 4) # OOMs at 8
return num_kv_pages_per_block, num_queries_per_block

num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
output, kv_cache = mla_ragged_paged_attention(
q,
q_rope,
k,
k_rope,
kv_cache,
*args,
sm_scale=1.0,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
)
return kv_cache, output

in_specs = (
P(("attn_dp", "model", "expert"), None, None), # q
P(("attn_dp", "model", "expert"), None, None), # q_rope
P(("attn_dp", "model", "expert"), None), # k
P(("attn_dp", "model", "expert"), None), # k_rope
P(("attn_dp", "model", "expert")), # kv_cache
P(("data", "attn_dp")), # md.seq_lens: Replicated
P(("data", "attn_dp")), # page_indices_flat: Replicated
P(("data", "attn_dp")), # query_start_loc: Replicated
P(("data", "attn_dp")), # distribution: Replicated
)

out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert")))

kv_cache, output = jax.jit(
shard_map.shard_map(
_mla_ragged_paged_attention,
mesh=self.mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
),
)(
q_absorbed,
q_rope,
k_latent,
k_rope,
mla_kv_cache,
md.seq_lens,
md.block_tables,
md.query_start_loc,
md.request_distribution,
)
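# Project the latent-space attention output back to per-head value space via the absorbed wv_b.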
output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
return kv_cache, output

def __call__(
self,
@@ -1001,7 +1110,7 @@ def __call__(
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
if self.config.force_q_layout:
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
key, value, cached_values = self.mla_kv_projection(
key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection(
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
)
query = checkpoint_name(query, "query_proj")
@@ -1032,8 +1141,22 @@ def __call__(
)
unnormalized_out = unnormalized_out[..., : self.v_head_dim]
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None:
batch, seq_len, num_heads, _ = query.shape
query = query.reshape(-1, query.shape[2], query.shape[3])
q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1)

k_latent = low_rank_main.reshape(-1, self.kv_lora_rank)
k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim)

updated_kv, attn_out = self.mla_rpa_vllm(
q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata
)
out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim)
kv_cache = updated_kv
else:
# Pass the index_mask to the Attention Op
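# With vllm_rpa but no runtime-provided KV cache (e.g. during model initialization or tracing),
# fall back to the cache-free train-mode attention path.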
if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN:
model_mode = MODEL_MODE_TRAIN
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)

out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
26 changes: 19 additions & 7 deletions src/MaxText/layers/deepseek.py
@@ -16,7 +16,7 @@
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from typing import Optional
from typing import Optional, Any

from flax import nnx
from jax.ad_checkpoint import checkpoint_name
@@ -154,9 +154,11 @@ def attention_op(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache: None | jnp.ndarray = None,
attention_metadata: None | dict[str, Any] = None,
):
"""Executes the attention layer."""
attention_result, _ = self.self_attention(
attention_result, kv_cache = self.self_attention(
x,
x,
decoder_positions,
@@ -167,8 +169,10 @@
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
return self.with_logical_constraint(attention_result)
return self.with_logical_constraint(attention_result), kv_cache

@property
def logical_axis_names(self):
@@ -229,23 +233,27 @@ def self_attention_with_norm_op(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache: None | jnp.ndarray = None,
attention_metadata: None | dict[str, Any] = None,
):
"""self-attention with normalization"""
lnx = self.pre_attention_norm_op(inputs)

attention_lnx = self.attention_op(
attention_lnx, kv_cache = self.attention_op(
lnx,
decoder_segment_ids,
decoder_positions,
deterministic,
previous_chunk,
page_state,
slot,
kv_cache,
attention_metadata,
)
intermediate_inputs = inputs + attention_lnx
# Normalization
hidden_states = self.post_attention_norm_op(intermediate_inputs)
return hidden_states, intermediate_inputs
return hidden_states, intermediate_inputs, kv_cache


class DeepSeekDenseLayer(DeepSeekGenericLayer):
@@ -298,14 +306,16 @@ def __call__(
x = self.with_logical_constraint(inputs)
x = checkpoint_name(x, "decoder_layer_input")

hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op(
x,
decoder_segment_ids,
decoder_positions,
deterministic,
previous_chunk,
page_state,
slot,
kv_cache,
attention_metadata,
)

mlp_lnx = self.mlp_op(hidden_states, deterministic)
@@ -384,14 +394,16 @@ def __call__(
x = self.with_logical_constraint(inputs)
x = checkpoint_name(x, "decoder_layer_input")

hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op(
x,
decoder_segment_ids,
decoder_positions,
deterministic,
previous_chunk,
page_state,
slot,
kv_cache,
attention_metadata,
)

mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
20 changes: 10 additions & 10 deletions src/MaxText/layers/moe.py
@@ -351,16 +351,16 @@

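# MoE expert FFN weights use the 'moe_mlp' logical axis (see logical_axis_rules), kept separate
# from the dense 'mlp' axis so expert sharding can be configured independently.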
if self.config.shard_exp_on_fsdp:
# special sharding for dsv3
self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
self.wi_kernel_axes = ("embed_no_exp", None, "moe_mlp")
self.wo_kernel_axes = ("embed_no_exp", "moe_mlp", None)
elif self.config.use_2d_fsdp_sharding:
self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
elif self.config.use_batch_split_schedule:
self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
else:
self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
self.wi_kernel_axes = ("exp", "embed_no_exp", "moe_mlp")
self.wo_kernel_axes = ("exp", "moe_mlp", "embed_no_exp")

if self.config.attention == "vllm_rpa":
# vLLM uses 'model' as the tensor parallelism axis name
@@ -1377,11 +1377,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):

if self.config.moe_fsdp_use_two_stage_all_gather:
# Unshard on fsdp axis
w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))
w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))

# Unshard on fsdp_transpose axis
wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))
wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "moe_mlp", "embed_tensor_transpose"))

# Make sure XLA does not optimize by combining above All-Gather to unshard
# on FSDP axis and the subsequent unshard on fsdp_transpose axis
@@ -1829,7 +1829,7 @@ def dense_matmul(
dispatch_axis,
)
with jax.named_scope("wi_0"):
w0_kernel_axes = ("exp", None, "mlp")
w0_kernel_axes = ("exp", None, "moe_mlp")
w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
@@ -1846,7 +1846,7 @@
)
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
with jax.named_scope("wi_1"):
w1_kernel_axes = ("exp", None, "mlp")
w1_kernel_axes = ("exp", None, "moe_mlp")
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
@@ -1863,7 +1863,7 @@
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
with jax.named_scope("wo"):
wo_kernel_axes = ("exp", "mlp", None)
wo_kernel_axes = ("exp", "moe_mlp", None)
wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
mlp_down_einsum,
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
@@ -443,6 +443,7 @@ logical_axis_rules: [
['decode_length', ['sequence']],
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -1083,6 +1084,8 @@ use_jax_splash: false
# vLLM Adapter Configurations
# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
vllm_hf_config_path: ""
# Path to the YAML file used to load the vLLM config
vllm_config_path: ""
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
vllm_additional_config: {}
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
1 change: 1 addition & 0 deletions src/maxtext/configs/inference/vllm.yml
@@ -53,6 +53,7 @@ logical_axis_rules: [
['decode_length', []],
['mlp', ['model', 'attn_dp']],
['mlp_no_fsdp', ['model', 'attn_dp']],
['moe_mlp', ['model', 'attn_dp']],
['vocab', ['model', 'attn_dp']],
['heads', ['model']],
['q_heads', ['model']],
2 changes: 2 additions & 0 deletions src/maxtext/configs/post_train/rl.yml
@@ -149,6 +149,8 @@ enable_dp_attention: False
# Performance tuning for samplers
max_num_batched_tokens: null
max_num_seqs: null
# Path to the YAML file used to initialize the vLLM config
vllm_config_path: 'src/MaxText/configs/vllm.yml'

# ====== Checkpoint Configuration ======
enable_checkpointing: True
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
@@ -1557,6 +1557,7 @@ class VLLM(BaseModel):
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
vllm_config_path: str = Field("src/MaxText/configs/vllm.yml", description="Path to the YAML file for loading the vLLM config.")


class RL(BaseModel):
Expand Down