Skip to content

Commit aed8469

Browse files
[Attention] Make split_decodes_and_prefills(..., require_uniform=True) support padding (#29644)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
1 parent e4605d2 commit aed8469

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

tests/v1/attention/test_attention_splitting.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
154154

155155

156156
def apply_split_decodes_and_prefills(
157-
query_lens: list[int], decode_threshold: int, require_uniform: bool
157+
query_lens: list[int],
158+
decode_threshold: int,
159+
require_uniform: bool,
160+
padded_num_tokens: int | None = None,
158161
):
159162
"""Helper function to apply split_decodes_and_prefills and return
160163
the results."""
@@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
165168
block_size=16,
166169
device=device,
167170
)
171+
172+
if padded_num_tokens is not None:
173+
common_metadata.num_actual_tokens = padded_num_tokens
174+
168175
return split_decodes_and_prefills(
169176
common_metadata,
170177
decode_threshold=decode_threshold,
@@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
271278
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
272279

273280

281+
def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
282+
"""Uniform batch where all query lengths are identical, with 0-length padded requests."""
283+
# All query lengths are 2, with decode_threshold=3 (so 2 <= 3)
284+
# This triggers the padded uniform path at line 891
285+
query_lens = [2, 2, 2, 0]
286+
padded_num_tokens = 8
287+
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
288+
apply_split_decodes_and_prefills(query_lens, 3, True, padded_num_tokens)
289+
)
290+
# With uniform batch, all requests are treated as decodes
291+
assert num_decodes == 4
292+
assert num_prefills == 0
293+
assert num_decode_tokens == padded_num_tokens
294+
assert num_prefill_tokens == 0
295+
296+
274297
@pytest.mark.parametrize(
275298
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
276299
[

vllm/v1/attention/backends/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -883,11 +883,15 @@ def split_decodes_and_prefills(
883883
return 0, num_reqs, 0, num_tokens
884884

885885
if require_uniform:
886+
# check if we are in a padded uniform batch; this is used for full-CGs, some
887+
# requests may have a query length of 0 but since they are padding it's fine
888+
# to treat them as decodes (ensures num_decodes matches the captured size)
889+
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
890+
assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
891+
return num_reqs, 0, num_tokens, 0 # all decodes
886892
is_prefill = query_lens != query_lens[0]
887893
else:
888-
# 0-query len indicates a padded request; leave this at the back
889-
# of the batch with the prefills
890-
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
894+
is_prefill = query_lens > decode_threshold
891895

892896
if not torch.any(is_prefill):
893897
return num_reqs, 0, num_tokens, 0

0 commit comments

Comments (0)