
Commit 68efcf1

deepseek sharding and mla attention plumbing
1 parent b43d692 commit 68efcf1

File tree: 13 files changed, +281 -96 lines changed

src/MaxText/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@
 from maxtext.trainers.post_train.dpo import dpo_utils
 from maxtext.utils import maxtext_utils
 from maxtext.utils import model_creation_utils
-from maxtext.utils.model_creation_utils import from_config

 Transformer = models.Transformer
 transformer_as_linen = models.transformer_as_linen

src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
@@ -430,6 +430,7 @@ logical_axis_rules: [
   ['decode_length', ['sequence']],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
+  ['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -1047,6 +1048,8 @@ use_jax_splash: false
 # vLLM Adapter Configurations
 # Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
 vllm_hf_config_path: ""
+# Path to yaml file for loading vLLM config
+vllm_config_path: ""
 # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
 vllm_additional_config: {}
 # When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
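For context, the new 'moe_mlp' entry follows the same pattern as the rules around it: each pair maps a logical tensor dimension name to the physical mesh axes it may be sharded over. Below is a minimal sketch, not code from this repo, of how such a rule is typically consumed through Flax's logical partitioning helpers; the rule values and dimension names in the sketch are illustrative assumptions, not the full rule set above.

import flax.linen as nn

# Illustrative, reduced rules: logical dim name -> mesh axis name (assumed values).
rules = (("embed", "fsdp"), ("moe_mlp", "tensor"))

# A MoE MLP weight whose dims are logically named ("embed", "moe_mlp") would then have
# "embed" sharded over the "fsdp" mesh axis and "moe_mlp" over the "tensor" mesh axis.
spec = nn.logical_to_mesh_axes(("embed", "moe_mlp"), rules)
print(spec)  # PartitionSpec('fsdp', 'tensor')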

src/MaxText/configs/rl.yml

Lines changed: 2 additions & 0 deletions
@@ -149,6 +149,8 @@ enable_dp_attention: False
 # Performance tuning for samplers
 max_num_batched_tokens: null
 max_num_seqs: null
+# path to initialize vllm config
+vllm_config_path: 'src/MaxText/configs/vllm.yml'

 # ====== Checkpoint Configuration ======
 enable_checkpointing: True

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -1490,6 +1490,7 @@ class VLLM(BaseModel):
   max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
   vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
   vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
+  vllm_config_path: str = Field("src/MaxText/configs/vllm.yml", description="path to yaml file for loading vLLM config.")


 class RL(BaseModel):
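The changes above only declare where the vLLM YAML lives. As a hedged sketch of how a file like the one rl.yml points at could be resolved, given that it carries a `base_config:` key: this is not MaxText's actual loader, and `load_vllm_config` is a hypothetical helper written only to illustrate the base-plus-overlay pattern.

import os
import yaml  # pyyaml

def load_vllm_config(path: str) -> dict:
  """Recursively merge a YAML config onto its declared base_config, child values winning."""
  with open(path, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f) or {}
  base_name = cfg.pop("base_config", None)
  if base_name:
    base_path = os.path.join(os.path.dirname(path), base_name)
    base = load_vllm_config(base_path)
    base.update(cfg)  # shallow overlay; child keys (e.g. logical_axis_rules) replace base keys
    return base
  return cfg

# cfg = load_vllm_config("src/MaxText/configs/vllm.yml")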

src/MaxText/configs/vllm.yml

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ logical_axis_rules: [
   ['decode_length', []],
   ['mlp', ['model', 'attn_dp']],
   ['mlp_no_fsdp', ['model', 'attn_dp']],
+  ['moe_mlp', ['model', 'attn_dp']],
   ['vocab', ['model', 'attn_dp']],
   ['heads', ['model']],
   ['q_heads', ['model']],
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+base_config: "vllm.yml"
+
+logical_axis_rules: [
+  ['activation_batch', ['']],
+  ['activation_batch_no_exp', []],
+  ['activation_embed_and_logits_batch', ['expert']],
+  ['activation_embed_and_logits_batch_sequence', ['expert']],
+  ['activation_heads', ['model']],
+  ['activation_kv_heads', ['model']],
+  ['activation_attn_length', ['expert']],
+  ['activation_attn_length_no_exp', []],
+  ['activation_length', ['data', 'expert']],
+  ['activation_length_no_exp', 'data'],
+  ['activation_q_length', ['expert']],
+  ['activation_attn_embed', 'model'],
+  ['activation_embed', ['model', 'attn_dp']],
+  ['activation_mlp', ['model', 'attn_dp', 'expert']],
+  ['activation_kv', ['model']],
+  ['activation_prefill_kv_batch', ['expert']],
+  ['activation_kv_batch', ['']],
+  ['activation_kv_batch_no_exp', []],
+  ['activation_kv_head_dim', ['model', 'attn_dp', 'expert']],
+  ['activation_vocab', ['model', 'attn_dp']],
+  ['activation_norm_length', []],
+  ['activation_exp', ['expert']],
+  ['decode_batch', ['expert']],
+  ['decode_length', []],
+  ['mlp_no_fsdp', ['model', 'attn_dp', 'expert']],
+  ['vocab', ['model', 'attn_dp', 'expert']],
+  ['heads', ['expert', 'attn_dp', 'model']],
+  ['q_heads', []],
+  ['kv_heads', []],
+  ['kv_head_dim', ['model', 'attn_dp', 'expert']],
+  ['kv', ['model', 'attn_dp', 'expert']],
+  ['kv', []],
+  ['embed', []],
+  ['mlp', ['model', 'attn_dp', 'expert']],
+  ['moe_mlp', []],
+  ['embed_tensor_transpose', ['attn_dp', 'model']],
+  ['embed_no_exp', []],
+  ['q_lora', []],
+  ['kv_lora', []],
+  ['norm', []],
+  ['cache_heads', ['model']],
+  ['exp', ['expert', 'attn_dp', 'model']],
+  ['paged_kv_heads', ['model']],
+  ['cache_batch_prefill', []],
+  ['cache_batch', []],
+  ['cache_sequence', []],
+  ['cache_heads_none', []],
+  ['cache_kv', []],
+  ['kv_lora_up_proj',['expert', 'attn_dp', 'model']],
+  ['q_lora_up_proj',['expert', 'attn_dp', 'model']],
+]
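These rules extend vllm.yml with an `expert` mesh axis alongside the existing `model` and `attn_dp` axes. As a toy illustration only, the snippet below shows the kind of device mesh those names refer to and how a rule such as ['exp', ['expert', 'attn_dp', 'model']] becomes a PartitionSpec over it; the 2x2x2 mesh shape is made up, and the XLA flag exists only so the demo has 8 (fake CPU) devices to work with.

import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")  # assumes 8 visible devices; this flag fakes 8 on CPU

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 2x2x2 mesh with the three axis names used by these sharding rules (sizes are illustrative).
devices = np.array(jax.devices()[:8]).reshape(2, 2, 2)
mesh = Mesh(devices, axis_names=("model", "attn_dp", "expert"))

# Under ['exp', ['expert', 'attn_dp', 'model']], the expert-indexed dimension of a weight is
# split across all three axes at once (8-way here); dimensions not listed stay replicated.
spec = PartitionSpec(("expert", "attn_dp", "model"))
print(NamedSharding(mesh, spec))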

src/MaxText/kernels/sort_activations.py

Lines changed: 5 additions & 12 deletions
@@ -73,9 +73,7 @@ def _unroute_fwd(
   )


-def _unroute_bwd(
-    use_custom_mosaic_kernel: bool, residuals: jax.Array, grads: jax.Array
-) -> tuple[jax.Array, None]:
+def _unroute_bwd(use_custom_mosaic_kernel: bool, residuals: jax.Array, grads: jax.Array) -> tuple[jax.Array, None]:
   selected_experts = residuals
   return _route_impl(grads, selected_experts, use_custom_mosaic_kernel), None

@@ -90,8 +88,7 @@ def _route_impl(
 ) -> jax.Array:
   """Gather `tokens` according to `selected_experts`."""
   assert (
-      tokens.shape[0] == selected_experts.shape[0]
-      and selected_experts.ndim == 2
+      tokens.shape[0] == selected_experts.shape[0] and selected_experts.ndim == 2
   ), f"{tokens.shape=}, {selected_experts.shape=}"
   if use_custom_mosaic_kernel:
     raise NotImplementedError("Custom Mosaic kernel not implemented.")
@@ -104,10 +101,8 @@ def _unroute_impl(
     selected_experts: jax.Array,
     use_custom_mosaic_kernel: bool,
 ) -> jax.Array:
-  assert (
-      tokens.shape[0] == selected_experts.shape[0] * selected_experts.shape[1]
-      and selected_experts.ndim == 2
-  )
+  """Reverse the routing operation, restoring tokens to their original order."""
+  assert tokens.shape[0] == selected_experts.shape[0] * selected_experts.shape[1] and selected_experts.ndim == 2
   inds = jnp.argsort(jnp.argsort(jnp.ravel(selected_experts)))
   return jnp.sum(
       jnp.reshape(
@@ -118,9 +113,7 @@ def _unroute_impl(
   )


-def _sort_impl(
-    tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool
-) -> jax.Array:
+def _sort_impl(tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool) -> jax.Array:
   if use_custom_mosaic_kernel:
     raise NotImplementedError("Custom Mosaic kernel not implemented.")
   else:
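The reformatting above keeps the double-argsort trick in `_unroute_impl` intact. As a quick standalone check of why `jnp.argsort(jnp.argsort(...))` undoes the routing permutation (toy top-1 data, not code from this repo):

import jax.numpy as jnp

tokens = jnp.arange(6.0)                              # 6 tokens, identified by their value
selected = jnp.array([[2], [0], [1], [0], [2], [1]])  # hypothetical top-1 expert choice per token
order = jnp.argsort(jnp.ravel(selected))              # forward routing: group tokens by expert
routed = tokens[order]
inv = jnp.argsort(jnp.argsort(jnp.ravel(selected)))   # inverse of `order`
print(jnp.array_equal(routed[inv], tokens))           # True: original token order restored

The real kernel handles top-k rather than top-1, which is why it additionally reshapes the unrouted copies and sums them per token, as in the `jnp.sum(jnp.reshape(...))` call visible in the diff above.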

src/MaxText/layers/attention_mla.py

Lines changed: 128 additions & 5 deletions
@@ -21,6 +21,8 @@
 import jax
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import layout
+from jax.sharding import PartitionSpec as P
+from jax.experimental import shard_map
 import jax.numpy as jnp
 from jax.sharding import Mesh, NamedSharding

@@ -619,7 +621,11 @@ def __init__(
     )

     # Module attribute names must match names previously passed to Linen for checkpointing
-    self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
+    self.MlaKVCache_0 = (
+        self.init_mla_kv_caches(inputs_kv_shape)
+        if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa"
+        else None
+    )

   def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
     """Initializes the MLA-specific projections."""
@@ -937,15 +943,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm

     key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
     cached_values = [None, None]
-    if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
+    if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
       if self.config.mla_naive_kvcache:
         cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
       else:
         cached_values = self.update_mla_kv_caches(
             low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
         )

-    return key, value, cached_values
+    return key, value, cached_values, low_rank_main, key_rope
+
+  def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata):
+    """Forward function for vLLM serving with MLA attention.
+
+    Args:
+      q_nope: Query nope part [T, N, qk_nope_head_dim]
+      q_rope: Query rope part [T, N, qk_rope_head_dim]
+      k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope)
+      k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension)
+      mla_kv_cache: The KV cache
+      mla_metadata: Attention metadata
+    """
+    md = mla_metadata
+    try:
+      # pylint: disable=import-outside-toplevel
+      # pytype: disable=import-error
+      from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
+      from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if mla_kv_cache is None or mla_metadata is None:
+      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
+
+    wkv_b_kernel = self.wkv_b.kernel.value
+    wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
+    wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
+    q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)
+
+    def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
+      def _initialize_block_sizes():
+        # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
+        max_num_tokens = q_absorbed.shape[0]
+        max_num_seqs = md.seq_lens.shape[0]
+        num_page_indices = md.block_tables.shape[0]
+        assert num_page_indices % max_num_seqs == 0
+        pages_per_seq = num_page_indices // max_num_seqs
+        # num_kv_pages_per_block = min(pages_per_seq, 16)
+        bkv_p, bq_sz = get_tuned_block_sizes(
+            q_nope.dtype,
+            q_nope.dtype,  # changed to q_nope dtype from mla_kv_cache.dtype
+            self.num_query_heads,
+            1,  # num_kv_heads for MLA kernel
+            self.qk_nope_head_dim,
+            q_nope.shape[1],  # page size ?? kv_cache.shape[1]
+            max_num_tokens,
+            pages_per_seq,
+        )
+        num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
+        num_queries_per_block = min(max_num_tokens, bq_sz, 4)  # OOMS at 8
+        return num_kv_pages_per_block, num_queries_per_block
+
+      num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
+      output, kv_cache = mla_ragged_paged_attention(
+          q,
+          q_rope,
+          k,
+          k_rope,
+          kv_cache,
+          *args,
+          sm_scale=1.0,
+          num_kv_pages_per_block=num_kv_pages_per_block,
+          num_queries_per_block=num_queries_per_block,
+      )
+      return kv_cache, output
+
+    in_specs = (
+        P(("attn_dp", "model", "expert"), None, None),  # q
+        P(("attn_dp", "model", "expert"), None, None),  # q_rope
+        P(("attn_dp", "model", "expert"), None),  # k
+        P(("attn_dp", "model", "expert"), None),  # k_rope
+        P(("attn_dp", "model", "expert")),  # kv_cache
+        P(("data", "attn_dp")),  # md.seq_lens: Replicated
+        P(("data", "attn_dp")),  # page_indices_flat: Replicated
+        P(("data", "attn_dp")),  # query_start_loc: Replicated
+        P(("data", "attn_dp")),  # distribution: Replicated
+    )
+
+    out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert")))
+
+    kv_cache, output = jax.jit(
+        shard_map.shard_map(
+            _mla_ragged_paged_attention,
+            mesh=self.mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_rep=False,
+        ),
+    )(
+        q_absorbed,
+        q_rope,
+        k_latent,
+        k_rope,
+        mla_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
+    return kv_cache, output

   def __call__(
       self,
@@ -1001,7 +1110,7 @@ def __call__(
     query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
     if self.config.force_q_layout:
       query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
-    key, value, cached_values = self.mla_kv_projection(
+    key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection(
         inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
     )
     query = checkpoint_name(query, "query_proj")
@@ -1034,8 +1143,22 @@ def __call__(
       )
       unnormalized_out = unnormalized_out[..., : self.v_head_dim]
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None:
+      batch, seq_len, num_heads, _ = query.shape
+      query = query.reshape(-1, query.shape[2], query.shape[3])
+      q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1)
+
+      k_latent = low_rank_main.reshape(-1, self.kv_lora_rank)
+      k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim)
+
+      updated_kv, attn_out = self.mla_rpa_vllm(
+          q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim)
+      kv_cache = updated_kv
     else:
-      # Pass the index_mask to the Attention Op
+      if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN:
+        model_mode = MODEL_MODE_TRAIN
       out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)

     if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
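The two einsums in `mla_rpa_vllm` are the MLA weight-absorption trick: folding the key up-projection into the query lets the kernel attend directly over the latent cache [S, kv_lora_rank] instead of materializing per-head keys. A small self-contained check of that equivalence follows; the toy shapes and random weights are assumptions for illustration, not values from this commit.

import jax
import jax.numpy as jnp

T, S, N, H, A = 3, 5, 2, 4, 6   # tokens, cache slots, heads, qk_nope_head_dim, kv_lora_rank (toy sizes)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q_nope = jax.random.normal(k1, (T, N, H))
k_latent = jax.random.normal(k2, (S, A))
wk_b = jax.random.normal(k3, (A, N, H))  # per-head up-projection of the latent kv

# Naive path: expand the latent cache to full per-head keys, then score q against k.
k_nope = jnp.einsum("SA,ANH->SNH", k_latent, wk_b)
scores_naive = jnp.einsum("TNH,SNH->TNS", q_nope, k_nope)

# Absorbed path: fold wk_b into the query once, then score against the latent cache directly.
q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b)
scores_absorbed = jnp.einsum("TNA,SA->TNS", q_absorbed, k_latent)

print(jnp.allclose(scores_naive, scores_absorbed, atol=1e-5))  # True

The symmetric einsum on the kernel output ("TNA,ANH->TNH" with wv_b_kernel) applies the value up-projection after attention, which is why the kernel can return a latent-sized result per head.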
