From 333f2515880de1ab34d7c997c4ba011f36edc431 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Tue, 3 Feb 2026 19:10:01 +0000 Subject: [PATCH] deepseek sharding and mla attention plumbing --- src/MaxText/layers/attention_mla.py | 133 ++++++++++++++++++++++++- src/MaxText/layers/deepseek.py | 26 +++-- src/MaxText/layers/moe.py | 20 ++-- src/maxtext/configs/base.yml | 3 + src/maxtext/configs/inference/vllm.yml | 1 + src/maxtext/configs/post_train/rl.yml | 2 + src/maxtext/configs/types.py | 1 + src/maxtext/configs/vllm_deepseek.yml | 69 +++++++++++++ src/maxtext/vllm_decode.py | 18 +++- 9 files changed, 250 insertions(+), 23 deletions(-) create mode 100644 src/maxtext/configs/vllm_deepseek.yml diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py index 7ab6d241c2..799db9bc9f 100644 --- a/src/MaxText/layers/attention_mla.py +++ b/src/MaxText/layers/attention_mla.py @@ -21,6 +21,8 @@ import jax from jax.ad_checkpoint import checkpoint_name from jax.experimental import layout +from jax.sharding import PartitionSpec as P +from jax.experimental import shard_map import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding @@ -619,7 +621,11 @@ def __init__( ) # Module attribute names must match names previously passed to Linen for checkpointing - self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None + self.MlaKVCache_0 = ( + self.init_mla_kv_caches(inputs_kv_shape) + if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa" + else None + ) def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: """Initializes the MLA-specific projections.""" @@ -937,7 +943,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode) cached_values = [None, None] - if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN: + if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN: if self.config.mla_naive_kvcache: cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk) else: @@ -945,7 +951,110 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk ) - return key, value, cached_values + return key, value, cached_values, low_rank_main, key_rope + + def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata): + """Forward function for vLLM serving with MLA attention. + + Args: + q_nope: Query nope part [T, N, qk_nope_head_dim] + q_rope: Query rope part [T, N, qk_rope_head_dim] + k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope) + k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension) + mla_kv_cache: The KV cache + mla_metadata: Attention metadata + """ + md = mla_metadata + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention + from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes + except ImportError as e: + raise ImportError( + "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`." 
+ ) from e + + if mla_kv_cache is None or mla_metadata is None: + raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.") + + wkv_b_kernel = self.wkv_b.kernel.value + wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim] + wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :] + q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel) + + def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args): + def _initialize_block_sizes(): + # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes) + max_num_tokens = q_absorbed.shape[0] + max_num_seqs = md.seq_lens.shape[0] + num_page_indices = md.block_tables.shape[0] + assert num_page_indices % max_num_seqs == 0 + pages_per_seq = num_page_indices // max_num_seqs + # num_kv_pages_per_block = min(pages_per_seq, 16) + bkv_p, bq_sz = get_tuned_block_sizes( + q_nope.dtype, + q_nope.dtype, # changed to q_nope dtype from mla_kv_cache.dtype + self.num_query_heads, + 1, # num_kv_heads for MLA kernel + self.qk_nope_head_dim, + q_nope.shape[1], # page size ?? kv_cache.shape[1] + max_num_tokens, + pages_per_seq, + ) + num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4) + num_queries_per_block = min(max_num_tokens, bq_sz, 4) # OOMS at 8 + return num_kv_pages_per_block, num_queries_per_block + + num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes() + output, kv_cache = mla_ragged_paged_attention( + q, + q_rope, + k, + k_rope, + kv_cache, + *args, + sm_scale=1.0, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + ) + return kv_cache, output + + in_specs = ( + P(("attn_dp", "model", "expert"), None, None), # q + P(("attn_dp", "model", "expert"), None, None), # q_rope + P(("attn_dp", "model", "expert"), None), # k + P(("attn_dp", "model", "expert"), None), # k_rope + P(("attn_dp", "model", "expert")), # kv_cache + P(("data", "attn_dp")), # md.seq_lens: Replicated + P(("data", "attn_dp")), # page_indices_flat: Replicated + P(("data", "attn_dp")), # query_start_loc: Replicated + P(("data", "attn_dp")), # distribution: Replicated + ) + + out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert"))) + + kv_cache, output = jax.jit( + shard_map.shard_map( + _mla_ragged_paged_attention, + mesh=self.mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ), + )( + q_absorbed, + q_rope, + k_latent, + k_rope, + mla_kv_cache, + md.seq_lens, + md.block_tables, + md.query_start_loc, + md.request_distribution, + ) + output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel) + return kv_cache, output def __call__( self, @@ -1001,7 +1110,7 @@ def __call__( query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) if self.config.force_q_layout: query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) - key, value, cached_values = self.mla_kv_projection( + key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection( inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk ) query = checkpoint_name(query, "query_proj") @@ -1032,8 +1141,22 @@ def __call__( ) unnormalized_out = unnormalized_out[..., : self.v_head_dim] out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out + elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None: + batch, seq_len, num_heads, _ = query.shape + query = query.reshape(-1, query.shape[2], 
query.shape[3]) + q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1) + + k_latent = low_rank_main.reshape(-1, self.kv_lora_rank) + k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim) + + updated_kv, attn_out = self.mla_rpa_vllm( + q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata + ) + out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim) + kv_cache = updated_kv else: - # Pass the index_mask to the Attention Op + if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN: + model_mode = MODEL_MODE_TRAIN out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask) out = jax.ad_checkpoint.checkpoint_name(out, "attention_out") diff --git a/src/MaxText/layers/deepseek.py b/src/MaxText/layers/deepseek.py index cb473e445e..910aa507ad 100644 --- a/src/MaxText/layers/deepseek.py +++ b/src/MaxText/layers/deepseek.py @@ -16,7 +16,7 @@ # pylint: disable=arguments-differ # pylint: disable=no-name-in-module -from typing import Optional +from typing import Optional, Any from flax import nnx from jax.ad_checkpoint import checkpoint_name @@ -154,9 +154,11 @@ def attention_op( previous_chunk=None, page_state: None | page_manager.PageState = None, slot: None | int = None, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, ): """Executes the attention layer.""" - attention_result, _ = self.self_attention( + attention_result, kv_cache = self.self_attention( x, x, decoder_positions, @@ -167,8 +169,10 @@ def attention_op( previous_chunk=previous_chunk, page_state=page_state, slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, ) - return self.with_logical_constraint(attention_result) + return self.with_logical_constraint(attention_result), kv_cache @property def logical_axis_names(self): @@ -229,11 +233,13 @@ def self_attention_with_norm_op( previous_chunk=None, page_state: None | page_manager.PageState = None, slot: None | int = None, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, ): """self-attention with normalization""" lnx = self.pre_attention_norm_op(inputs) - attention_lnx = self.attention_op( + attention_lnx, kv_cache = self.attention_op( lnx, decoder_segment_ids, decoder_positions, @@ -241,11 +247,13 @@ def self_attention_with_norm_op( previous_chunk, page_state, slot, + kv_cache, + attention_metadata, ) intermediate_inputs = inputs + attention_lnx # Normalization hidden_states = self.post_attention_norm_op(intermediate_inputs) - return hidden_states, intermediate_inputs + return hidden_states, intermediate_inputs, kv_cache class DeepSeekDenseLayer(DeepSeekGenericLayer): @@ -298,7 +306,7 @@ def __call__( x = self.with_logical_constraint(inputs) x = checkpoint_name(x, "decoder_layer_input") - hidden_states, intermediate_inputs = self.self_attention_with_norm_op( + hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op( x, decoder_segment_ids, decoder_positions, @@ -306,6 +314,8 @@ def __call__( previous_chunk, page_state, slot, + kv_cache, + attention_metadata, ) mlp_lnx = self.mlp_op(hidden_states, deterministic) @@ -384,7 +394,7 @@ def __call__( x = self.with_logical_constraint(inputs) x = checkpoint_name(x, "decoder_layer_input") - hidden_states, intermediate_inputs = self.self_attention_with_norm_op( + hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op( x, 
decoder_segment_ids, decoder_positions, @@ -392,6 +402,8 @@ def __call__( previous_chunk, page_state, slot, + kv_cache, + attention_metadata, ) mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic) diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index 7c54faf5a0..1c1d49cbcd 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -351,16 +351,16 @@ def __init__( if self.config.shard_exp_on_fsdp: # special sharding for dsv3 - self.wi_kernel_axes = ("embed_no_exp", None, "mlp") - self.wo_kernel_axes = ("embed_no_exp", "mlp", None) + self.wi_kernel_axes = ("embed_no_exp", None, "moe_mlp") + self.wo_kernel_axes = ("embed_no_exp", "moe_mlp", None) elif self.config.use_2d_fsdp_sharding: self.wi_kernel_axes = ("embed_no_exp", "mlp", None) self.wo_kernel_axes = ("embed_no_exp", "mlp", None) elif self.config.use_batch_split_schedule: self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes() else: - self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp") - self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp") + self.wi_kernel_axes = ("exp", "embed_no_exp", "moe_mlp") + self.wo_kernel_axes = ("exp", "moe_mlp", "embed_no_exp") if self.config.attention == "vllm_rpa": # vLLM uses 'model' as the tensor parallelism axis name @@ -1377,11 +1377,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): if self.config.moe_fsdp_use_two_stage_all_gather: # Unshard on fsdp axis - w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp")) - w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp")) + w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp")) + w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp")) # Unshard on fsdp_transpose axis - wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose")) + wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "moe_mlp", "embed_tensor_transpose")) # Make sure XLA does not optimize by combining above All-Gather to unshard # on FSDP axis and the subsequent unshard on fsdp_transpose axis @@ -1829,7 +1829,7 @@ def dense_matmul( dispatch_axis, ) with jax.named_scope("wi_0"): - w0_kernel_axes = ("exp", None, "mlp") + w0_kernel_axes = ("exp", None, "moe_mlp") w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes) layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision @@ -1846,7 +1846,7 @@ def dense_matmul( ) layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") with jax.named_scope("wi_1"): - w1_kernel_axes = ("exp", None, "mlp") + w1_kernel_axes = ("exp", None, "moe_mlp") w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes) layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision @@ -1863,7 +1863,7 @@ def dense_matmul( layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1) with jax.named_scope("wo"): - wo_kernel_axes = ("exp", "mlp", None) + wo_kernel_axes = ("exp", "moe_mlp", None) wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes) intermediate_layer = 
self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( mlp_down_einsum, diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 632f3345b7..9d3dd1bbb2 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -443,6 +443,7 @@ logical_axis_rules: [ ['decode_length', ['sequence']], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], + ['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -1083,6 +1084,8 @@ use_jax_splash: false # vLLM Adapter Configurations # Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter) vllm_hf_config_path: "" +# Path to yaml file for loading vLLM config +vllm_config_path: "" # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}') vllm_additional_config: {} # When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH] diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index 21ca47410e..67e32f89ee 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -53,6 +53,7 @@ logical_axis_rules: [ ['decode_length', []], ['mlp', ['model', 'attn_dp']], ['mlp_no_fsdp', ['model', 'attn_dp']], + ['moe_mlp', ['model', 'attn_dp']], ['vocab', ['model', 'attn_dp']], ['heads', ['model']], ['q_heads', ['model']], diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 9d741e7a8c..410d1a13e7 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -149,6 +149,8 @@ enable_dp_attention: False # Performance tuning for samplers max_num_batched_tokens: null max_num_seqs: null +# path to initialize vllm config +vllm_config_path: 'src/MaxText/configs/vllm.yml' # ====== Checkpoint Configuration ====== enable_checkpointing: True diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e651293a19..4a7c8035c9 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1557,6 +1557,7 @@ class VLLM(BaseModel): max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.") vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.") vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.") + vllm_config_path: str = Field("src/MaxText/configs/vllm.yml", description="path to yaml file for loading vLLM config.") class RL(BaseModel): diff --git a/src/maxtext/configs/vllm_deepseek.yml b/src/maxtext/configs/vllm_deepseek.yml new file mode 100644 index 0000000000..1b6f7f2a54 --- /dev/null +++ b/src/maxtext/configs/vllm_deepseek.yml @@ -0,0 +1,69 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +base_config: "vllm.yml" + +logical_axis_rules: [ + ['activation_batch', ['']], + ['activation_batch_no_exp', []], + ['activation_embed_and_logits_batch', ['expert']], + ['activation_embed_and_logits_batch_sequence', ['expert']], + ['activation_heads', ['model']], + ['activation_kv_heads', ['model']], + ['activation_attn_length', ['expert']], + ['activation_attn_length_no_exp', []], + ['activation_length', ['data', 'expert']], + ['activation_length_no_exp', 'data'], + ['activation_q_length', ['expert']], + ['activation_attn_embed', 'model'], + ['activation_embed', ['model', 'attn_dp']], + ['activation_mlp', ['model', 'attn_dp', 'expert']], + ['activation_kv', ['model']], + ['activation_prefill_kv_batch', ['expert']], + ['activation_kv_batch', ['']], + ['activation_kv_batch_no_exp', []], + ['activation_kv_head_dim', ['model', 'attn_dp', 'expert']], + ['activation_vocab', ['model', 'attn_dp']], + ['activation_norm_length', []], + ['activation_exp', ['expert']], + ['decode_batch', ['expert']], + ['decode_length', []], + ['mlp_no_fsdp', ['model', 'attn_dp', 'expert']], + ['vocab', ['model', 'attn_dp', 'expert']], + ['heads', ['expert', 'attn_dp', 'model']], + ['q_heads', []], + ['kv_heads', []], + ['kv_head_dim', ['model', 'attn_dp', 'expert']], + ['kv', ['model', 'attn_dp', 'expert']], + ['kv', []], + ['embed', []], + ['mlp', ['model', 'attn_dp', 'expert']], + ['moe_mlp', []], + ['embed_tensor_transpose', ['attn_dp', 'model']], + ['embed_no_exp', []], + ['q_lora', []], + ['kv_lora', []], + ['norm', []], + ['cache_heads', ['model']], + ['exp', ['expert', 'attn_dp', 'model']], + ['paged_kv_heads', ['model']], + ['cache_batch_prefill', []], + ['cache_batch', []], + ['cache_sequence', []], + ['cache_heads_none', []], + ['cache_kv', []], + ['kv_lora_up_proj',['expert', 'attn_dp', 'model']], + ['q_lora_up_proj',['expert', 'attn_dp', 'model']], + ] \ No newline at end of file diff --git a/src/maxtext/vllm_decode.py b/src/maxtext/vllm_decode.py index 2e532d63af..ad0fef508e 100644 --- a/src/maxtext/vllm_decode.py +++ b/src/maxtext/vllm_decode.py @@ -80,15 +80,24 @@ flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.") flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.") +# vllm config variables +flags.DEFINE_integer("vllm_swap_space", 2, "per device swap space in GB") +flags.DEFINE_integer("vllm_async_scheduling", 1, "Async DP Scheduler for vLLM") + # Decoding flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.") flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.") flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.") flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.") flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.") +flags.DEFINE_string( + "vllm_config_path", + "src/MaxText/configs/vllm.yml", + "Path to vLLM config file. 
Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.", +) # Mark required flags -flags.mark_flag_as_required("hf_config_path") +# flags.mark_flag_as_required("hf_config_path") def decode_with_vllm( @@ -103,6 +112,8 @@ def decode_with_vllm( max_prefill_length: int, max_target_length: int, gpu_memory_utilization: float, + vllm_swap_space: int, + vllm_async_scheduling: int, enable_expert_parallel: bool, prompt: str, decode_sampling_temperature: float, @@ -145,6 +156,8 @@ def decode_with_vllm( vllm_args["enable_expert_parallel"] = enable_expert_parallel vllm_args["hf_config_path"] = hf_config_path vllm_args["gpu_memory_utilization"] = gpu_memory_utilization + vllm_args["swap_space"] = vllm_swap_space + vllm_args["async_scheduling"] = vllm_async_scheduling # Prepare MaxText and sharding configs (Parallelism is dynamic) vllm_args["additional_config"]["maxtext_config"] = { @@ -291,12 +304,15 @@ def main(argv: Sequence[str]) -> None: max_target_length=FLAGS.max_target_length, max_prefill_length=FLAGS.max_prefill_length, gpu_memory_utilization=FLAGS.gpu_memory_utilization, + vllm_swap_space=FLAGS.vllm_swap_space, + vllm_async_scheduling=FLAGS.vllm_async_scheduling, enable_expert_parallel=FLAGS.enable_expert_parallel, prompt=FLAGS.prompt, decode_sampling_temperature=FLAGS.decode_sampling_temperature, decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p, decode_sampling_top_k=FLAGS.decode_sampling_top_k, debug_sharding=FLAGS.debug_sharding, + vllm_config_path=FLAGS.vllm_config_path, )
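A minimal standalone sketch (not taken from the patch) of the weight-absorption identity that the new mla_rpa_vllm path relies on: the key up-projection wk_b is folded into the query (the "TNH,ANH->TNA" einsum) so attention runs directly against the latent KV cache, and the value up-projection wv_b is applied only after the attention-weighted sum (the "TNA,ANH->TNH" einsum). All shapes below, and the simplifying assumption that v_head_dim equals qk_nope_head_dim, are illustrative choices for the sketch rather than values from the patch.

# Sketch: check that attending in the latent space with an absorbed wk_b, then
# projecting with wv_b, matches first expanding the latent cache into per-head
# keys/values. Names and shapes are illustrative assumptions.
import jax
import jax.numpy as jnp

T, S, N = 4, 6, 2            # query tokens, cached kv tokens, heads
H, A, R = 8, 16, 4           # qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim

kq, kl, kr1, kr2, kw = jax.random.split(jax.random.PRNGKey(0), 5)
q_nope   = jax.random.normal(kq,  (T, N, H))
q_rope   = jax.random.normal(kr1, (T, N, R))
k_latent = jax.random.normal(kl,  (S, A))          # compressed KV, shared across heads
k_rope   = jax.random.normal(kr2, (S, R))          # rope part, no head dimension
wkv_b    = jax.random.normal(kw,  (A, N, H + H))   # up-projection, assuming v_head_dim == H
wk_b, wv_b = wkv_b[..., :H], wkv_b[..., H:]

# Reference path: expand the latent cache into per-head keys and values first.
k_nope = jnp.einsum("SA,ANH->SNH", k_latent, wk_b)
v      = jnp.einsum("SA,ANH->SNH", k_latent, wv_b)
scores = jnp.einsum("TNH,SNH->TNS", q_nope, k_nope) + jnp.einsum("TNR,SR->TNS", q_rope, k_rope)
w      = jax.nn.softmax(scores, axis=-1)
out_ref = jnp.einsum("TNS,SNH->TNH", w, v)

# Absorbed path: fold wk_b into the query, attend against the latent cache,
# then apply wv_b once after the weighted sum.
q_abs   = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b)
scores2 = jnp.einsum("TNA,SA->TNS", q_abs, k_latent) + jnp.einsum("TNR,SR->TNS", q_rope, k_rope)
w2      = jax.nn.softmax(scores2, axis=-1)
out_abs = jnp.einsum("TNA,ANH->TNH", jnp.einsum("TNS,SA->TNA", w2, k_latent), wv_b)

print(jnp.allclose(out_ref, out_abs, atol=1e-4))   # True

This equivalence is what lets the paged cache in this path hold only the latent vector plus the rope part per token (kv_lora_rank + qk_rope_head_dim values) instead of per-head expanded keys and values; the patch wraps the same two einsums around the mla_ragged_paged_attention kernel call.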