|
21 | 21 | import jax |
22 | 22 | from jax.ad_checkpoint import checkpoint_name |
23 | 23 | from jax.experimental import layout |
| 24 | +from jax.sharding import PartitionSpec as P |
| 25 | +from jax.experimental import shard_map |
24 | 26 | import jax.numpy as jnp |
25 | 27 | from jax.sharding import Mesh, NamedSharding |
26 | 28 |
|
@@ -619,7 +621,11 @@ def __init__( |
619 | 621 | ) |
620 | 622 |
|
621 | 623 | # Module attribute names must match names previously passed to Linen for checkpointing |
622 | | - self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None |
| 624 | + self.MlaKVCache_0 = ( |
| 625 | + self.init_mla_kv_caches(inputs_kv_shape) |
| 626 | + if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa" |
| 627 | + else None |
| 628 | + ) |
623 | 629 |
|
624 | 630 | def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: |
625 | 631 | """Initializes the MLA-specific projections.""" |
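With `attention == "vllm_rpa"`, the MLA KV cache is supplied externally by vLLM at call time (see the `kv_cache` handling in `__call__` below), so the layer skips allocating `MlaKVCache_0`. A minimal, self-contained sketch of the predicate; the constant value and helper name are assumptions for illustration:

```python
MODEL_MODE_TRAIN = "train"  # assumed value, for illustration only


def needs_internal_mla_cache(model_mode: str, attention_kind: str) -> bool:
  """True when the layer should allocate its own MLA KV cache.

  Training never caches, and the vLLM RPA path receives its cache externally,
  so both cases skip allocation.
  """
  return model_mode != MODEL_MODE_TRAIN and attention_kind != "vllm_rpa"


assert not needs_internal_mla_cache("train", "dot_product")
assert not needs_internal_mla_cache("autoregressive", "vllm_rpa")
assert needs_internal_mla_cache("autoregressive", "dot_product")
```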
@@ -937,15 +943,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm |
937 | 943 |
|
938 | 944 | key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode) |
939 | 945 | cached_values = [None, None] |
940 | | - if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN: |
| 946 | +    if self.config.attention not in ("paged", "vllm_rpa") and model_mode != MODEL_MODE_TRAIN: |
941 | 947 | if self.config.mla_naive_kvcache: |
942 | 948 | cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk) |
943 | 949 | else: |
944 | 950 | cached_values = self.update_mla_kv_caches( |
945 | 951 | low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk |
946 | 952 | ) |
947 | 953 |
|
948 | | - return key, value, cached_values |
| 954 | + return key, value, cached_values, low_rank_main, key_rope |
| 955 | + |
| 956 | + def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata): |
| 957 | + """Forward function for vLLM serving with MLA attention. |
| 958 | +
|
| 959 | + Args: |
| 960 | + q_nope: Query nope part [T, N, qk_nope_head_dim] |
| 961 | + q_rope: Query rope part [T, N, qk_rope_head_dim] |
| 962 | + k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope) |
| 963 | + k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension) |
| 964 | + mla_kv_cache: The KV cache |
| 965 | + mla_metadata: Attention metadata |
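| | + |
| | +    Returns: |
| | +      A (mla_kv_cache, output) tuple: the updated cache and the attention output of shape [T, N, v_head_dim]. |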
| 966 | + """ |
| 967 | + md = mla_metadata |
| 968 | + try: |
| 969 | + # pylint: disable=import-outside-toplevel |
| 970 | + # pytype: disable=import-error |
| 971 | + from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention |
| 972 | + from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes |
| 973 | + except ImportError as e: |
| 974 | + raise ImportError( |
| 975 | + "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`." |
| 976 | + ) from e |
| 977 | + |
| 978 | + if mla_kv_cache is None or mla_metadata is None: |
| 979 | +      raise ValueError("mla_kv_cache and mla_metadata must be provided when using the vLLM RPA attention path.") |
| 980 | + |
| 981 | + wkv_b_kernel = self.wkv_b.kernel.value |
| 982 | + wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim] |
| 983 | + wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :] |
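| | +    # Weight absorption: fold the k_nope up-projection (wk_b) into the query so attention can be computed directly against the latent KV cache. |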
| 984 | + q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel) |
| 985 | + |
| 986 | + def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args): |
| 987 | + def _initialize_block_sizes(): |
| 988 | +        # Pick conservative starting block sizes from the tuned lookup table. (TODO(gpolovets): rely on the tuned sizes directly) |
| 989 | + max_num_tokens = q_absorbed.shape[0] |
| 990 | + max_num_seqs = md.seq_lens.shape[0] |
| 991 | + num_page_indices = md.block_tables.shape[0] |
| 992 | + assert num_page_indices % max_num_seqs == 0 |
| 993 | + pages_per_seq = num_page_indices // max_num_seqs |
| 995 | + bkv_p, bq_sz = get_tuned_block_sizes( |
| 996 | + q_nope.dtype, |
| 997 | +            q_nope.dtype,  # kv dtype: use q_nope.dtype here rather than mla_kv_cache.dtype |
| 998 | + self.num_query_heads, |
| 999 | + 1, # num_kv_heads for MLA kernel |
| 1000 | + self.qk_nope_head_dim, |
| 1001 | +            q_nope.shape[1],  # page size (TODO: confirm whether this should be kv_cache.shape[1]) |
| 1002 | + max_num_tokens, |
| 1003 | + pages_per_seq, |
| 1004 | + ) |
| 1005 | + num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4) |
| 1006 | +        num_queries_per_block = min(max_num_tokens, bq_sz, 4)  # capped at 4; a block size of 8 runs out of memory |
| 1007 | + return num_kv_pages_per_block, num_queries_per_block |
| 1008 | + |
| 1009 | + num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes() |
| 1010 | + output, kv_cache = mla_ragged_paged_attention( |
| 1011 | + q, |
| 1012 | + q_rope, |
| 1013 | + k, |
| 1014 | + k_rope, |
| 1015 | + kv_cache, |
| 1016 | + *args, |
| 1017 | + sm_scale=1.0, |
| 1018 | + num_kv_pages_per_block=num_kv_pages_per_block, |
| 1019 | + num_queries_per_block=num_queries_per_block, |
| 1020 | + ) |
| 1021 | + return kv_cache, output |
| 1022 | + |
| 1023 | + in_specs = ( |
| 1024 | + P(("attn_dp", "model", "expert"), None, None), # q |
| 1025 | + P(("attn_dp", "model", "expert"), None, None), # q_rope |
| 1026 | + P(("attn_dp", "model", "expert"), None), # k |
| 1027 | + P(("attn_dp", "model", "expert"), None), # k_rope |
| 1028 | + P(("attn_dp", "model", "expert")), # kv_cache |
| 1029 | +        P(("data", "attn_dp")),  # md.seq_lens: sharded over (data, attn_dp), replicated on other axes |
| 1030 | +        P(("data", "attn_dp")),  # page_indices_flat (md.block_tables): same sharding |
| 1031 | +        P(("data", "attn_dp")),  # query_start_loc: same sharding |
| 1032 | +        P(("data", "attn_dp")),  # distribution (md.request_distribution): same sharding |
| 1033 | + ) |
| 1034 | + |
| 1035 | + out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert"))) |
| 1036 | + |
| 1037 | + kv_cache, output = jax.jit( |
| 1038 | + shard_map.shard_map( |
| 1039 | + _mla_ragged_paged_attention, |
| 1040 | + mesh=self.mesh, |
| 1041 | + in_specs=in_specs, |
| 1042 | + out_specs=out_specs, |
| 1043 | + check_rep=False, |
| 1044 | + ), |
| 1045 | + )( |
| 1046 | + q_absorbed, |
| 1047 | + q_rope, |
| 1048 | + k_latent, |
| 1049 | + k_rope, |
| 1050 | + mla_kv_cache, |
| 1051 | + md.seq_lens, |
| 1052 | + md.block_tables, |
| 1053 | + md.query_start_loc, |
| 1054 | + md.request_distribution, |
| 1055 | + ) |
| 1056 | + output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel) |
| 1057 | + return kv_cache, output |
949 | 1058 |
|
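The core of `mla_rpa_vllm` is MLA weight absorption: rather than expanding the latent cache into per-head `k_nope` and `v` tensors, the `wk_b` slice of `wkv_b` is folded into the query (the `"TNH,ANH->TNA"` einsum) so the kernel attends directly over the latent cache, and the `wv_b` slice expands the result back to `v_head_dim` afterwards. A shape-only sketch under assumed sizes, ignoring RoPE, softmax, scaling, and paging (all handled inside `mla_ragged_paged_attention`):

```python
import jax.numpy as jnp

# Illustrative sizes: T tokens, N heads, H = qk_nope_head_dim, A = kv_lora_rank,
# S cached tokens; v_head_dim is set equal to H only to keep the sketch short.
T, N, H, A, S = 4, 8, 128, 512, 16
q_nope = jnp.ones((T, N, H))
wk_b = jnp.ones((A, N, H))    # k_nope slice of the wkv_b kernel
wv_b = jnp.ones((A, N, H))    # value slice of the wkv_b kernel
k_latent = jnp.ones((S, A))   # per-token latent KV, shared across heads

# Fold wk_b into the query so scores are taken directly against the latent cache.
q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b)       # [T, N, A]
scores = jnp.einsum("TNA,SA->TNS", q_absorbed, k_latent)    # equals q_nope against expanded k_nope
attn_latent = jnp.einsum("TNS,SA->TNA", scores, k_latent)   # attention output, still latent
out = jnp.einsum("TNA,ANH->TNH", attn_latent, wv_b)         # expand back to per-head v_head_dim
assert out.shape == (T, N, H)
```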
950 | 1059 | def __call__( |
951 | 1060 | self, |
@@ -1001,7 +1110,7 @@ def __call__( |
1001 | 1110 | query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) |
1002 | 1111 | if self.config.force_q_layout: |
1003 | 1112 | query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) |
1004 | | - key, value, cached_values = self.mla_kv_projection( |
| 1113 | + key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection( |
1005 | 1114 | inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk |
1006 | 1115 | ) |
1007 | 1116 | query = checkpoint_name(query, "query_proj") |
@@ -1034,8 +1143,22 @@ def __call__( |
1034 | 1143 | ) |
1035 | 1144 | unnormalized_out = unnormalized_out[..., : self.v_head_dim] |
1036 | 1145 | out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out |
| 1146 | + elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None: |
| 1147 | + batch, seq_len, num_heads, _ = query.shape |
| 1148 | + query = query.reshape(-1, query.shape[2], query.shape[3]) |
| 1149 | + q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1) |
| 1150 | + |
| 1151 | + k_latent = low_rank_main.reshape(-1, self.kv_lora_rank) |
| 1152 | + k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim) |
| 1153 | + |
| 1154 | + updated_kv, attn_out = self.mla_rpa_vllm( |
| 1155 | + q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata |
| 1156 | + ) |
| 1157 | + out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim) |
| 1158 | + kv_cache = updated_kv |
1037 | 1159 | else: |
1038 | | - # Pass the index_mask to the Attention Op |
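| | +      # vLLM RPA was requested but no external cache was provided; fall back to the standard attention path by treating this call as train mode. |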
| 1160 | + if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN: |
| 1161 | + model_mode = MODEL_MODE_TRAIN |
1039 | 1162 | out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask) |
1040 | 1163 |
|
1041 | 1164 | if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
|
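A shape-only sketch of the reshapes in the `vllm_rpa` branch above, with assumed dimensions (the real values come from the model config, and `key_rope` is assumed to carry a singleton head axis):

```python
import jax.numpy as jnp

B, S, N = 2, 3, 8
qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank, v_head_dim = 128, 64, 512, 128

query = jnp.ones((B, S, N, qk_nope_head_dim + qk_rope_head_dim))
low_rank_main = jnp.ones((B, S, kv_lora_rank))
key_rope = jnp.ones((B, S, 1, qk_rope_head_dim))

# Flatten batch and sequence into one token axis, as the ragged kernel expects.
q_flat = query.reshape(-1, N, query.shape[-1])                   # [T, N, Dq]
q_nope, q_rope = jnp.split(q_flat, [qk_nope_head_dim], axis=-1)  # [T, N, 128], [T, N, 64]
k_latent = low_rank_main.reshape(-1, kv_lora_rank)               # [T, kv_lora_rank]
k_rope = key_rope.reshape(-1, qk_rope_head_dim)                  # [T, qk_rope_head_dim]

# The kernel returns token-major output; restore the layout expected downstream.
attn_out = jnp.ones((B * S, N, v_head_dim))  # stand-in for the kernel output
out = attn_out.reshape(B, S, N, v_head_dim)
assert out.shape == (B, S, N, v_head_dim)
```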