
Commit d08b78b
Properly initializing the new field in the attn metadata (ROCm#337)
1 parent 399016d

17 files changed, +38 −9 lines changed

tests/kernels/utils.py

Lines changed: 2 additions & 0 deletions
@@ -914,6 +914,7 @@ def make_test_metadata(
         num_prefills=num_prefills,
         slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
         multi_modal_placeholder_index_maps=None,
+        enable_kv_scales_calculation=True,
         num_prefill_tokens=num_prefill_tokens,
         num_decode_tokens=num_decode_tokens,
         seq_lens=seq_lens,

@@ -963,6 +964,7 @@ def make_test_metadata(
         num_prefills=num_prefills,
         slot_mapping=kv_mmap.slot_mapping,
         multi_modal_placeholder_index_maps=None,
+        enable_kv_scales_calculation=True,
         num_prefill_tokens=num_prefill_tokens,
         num_decode_tokens=num_decode_tokens,
         seq_lens=seq_lens,

tests/worker/test_model_input.py

Lines changed: 3 additions & 0 deletions
@@ -74,6 +74,7 @@ def test_model_runner_input():
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
         multi_modal_placeholder_index_maps=None,
+        enable_kv_scales_calculation=True,
     )
     model_input = ModelInputForGPUWithSamplingMetadata(
         input_tokens=torch.ones(10),

@@ -126,6 +127,7 @@ def test_embedding_model_runner_input():
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
         multi_modal_placeholder_index_maps=None,
+        enable_kv_scales_calculation=True,
     )
     model_input = ModelInputForGPUWithPoolingMetadata(
         input_tokens=torch.ones(10),

@@ -177,6 +179,7 @@ def test_multi_step_model_runner_input():
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
         multi_modal_placeholder_index_maps=None,
+        enable_kv_scales_calculation=True,
     )
     frozen_model_input = ModelInputForGPUWithSamplingMetadata(
         input_tokens=torch.ones(10),

vllm/attention/backends/abstract.py

Lines changed: 2 additions & 3 deletions
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, fields
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                     Tuple, Type, TypeVar)
 

@@ -126,8 +126,7 @@ class AttentionMetadata:
 
     # Enable/disable KV scales calculation. This is so that we can disable the
     # calculation until after prefill and cuda graph capture.
-    enable_kv_scales_calculation: bool = field(init=False,
-                                               default_factory=lambda: True)
+    enable_kv_scales_calculation: bool
 
     @property
     @abstractmethod
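
This hunk is the core of the commit: `field(init=False, ...)` excludes a dataclass field from the generated `__init__`, so none of the `enable_kv_scales_calculation=...` keyword arguments added at the call sites in this commit would have been accepted. A minimal sketch of the before/after behavior, using hypothetical stand-in classes rather than the real AttentionMetadata:

from dataclasses import dataclass, field

@dataclass
class OldMetadata:
    num_prefills: int
    # init=False drops the field from the generated __init__, so callers
    # cannot set it at construction time; it always starts out as True.
    enable_kv_scales_calculation: bool = field(init=False,
                                               default_factory=lambda: True)

@dataclass
class NewMetadata:
    num_prefills: int
    enable_kv_scales_calculation: bool  # ordinary required __init__ argument

OldMetadata(num_prefills=1)  # ok, but the flag silently defaults to True
# OldMetadata(num_prefills=1, enable_kv_scales_calculation=False)
#   -> TypeError: unexpected keyword argument 'enable_kv_scales_calculation'
NewMetadata(num_prefills=1, enable_kv_scales_calculation=False)  # ok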

vllm/attention/backends/blocksparse_attn.py

Lines changed: 2 additions & 0 deletions
@@ -222,6 +222,7 @@ def prefill_metadata(
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             multi_modal_placeholder_index_maps=self.
             multi_modal_placeholder_index_maps,
+            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,

@@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_query_len=None,

vllm/attention/backends/flash_attn.py

Lines changed: 3 additions & 0 deletions
@@ -224,6 +224,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=self.
             multi_modal_placeholder_index_maps,
+            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,

@@ -268,6 +269,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=seq_lens_tensor,
             max_decode_query_len=self.max_decode_query_len,

@@ -550,6 +552,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=True,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_decode_query_len=max_decode_query_len,

vllm/attention/backends/flashinfer.py

Lines changed: 2 additions & 0 deletions
@@ -218,6 +218,7 @@ def graph_capture_get_metadata_for_batch(
             num_prefills=0,
             slot_mapping=self._graph_slot_mapping[:batch_size],
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
             num_prefill_tokens=0,
             num_decode_tokens=batch_size,
             max_prefill_seq_len=0,

@@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=False,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             max_prefill_seq_len=max_prefill_seq_len,

vllm/attention/backends/placeholder_attn.py

Lines changed: 3 additions & 0 deletions
@@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=self.
             multi_modal_placeholder_index_maps,
+            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_decode_query_len=0,

@@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_decode_query_len=self.max_decode_query_len,

@@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -165,6 +165,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             multi_modal_placeholder_index_maps=self.
             multi_modal_placeholder_index_maps,
+            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,

@@ -202,6 +203,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_query_len=None,

vllm/attention/backends/torch_sdpa.py

Lines changed: 1 addition & 0 deletions
@@ -372,6 +372,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             prefill_block_tables=prefill_block_tables,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=False,
         )
 
         return attn_metadata

vllm/attention/backends/utils.py

Lines changed: 2 additions & 0 deletions
@@ -274,6 +274,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,

@@ -326,6 +327,7 @@ def graph_capture_get_metadata_for_batch(
             num_decode_tokens=batch_size,
             slot_mapping=self._graph_slot_mapping[:batch_size],
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
             max_query_len=1,
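
Because the flag is now an ordinary required dataclass field, every construction site must choose a value explicitly, which is why this one-field change fans out across 17 files. A quick sketch of the failure mode this enforces, again using a hypothetical stand-in class rather than the real AttentionMetadata:

from dataclasses import dataclass

@dataclass
class Meta:
    num_decode_tokens: int
    enable_kv_scales_calculation: bool  # required; no silent default

try:
    Meta(num_decode_tokens=3)  # a call site that was not updated
except TypeError as exc:
    # e.g. "missing 1 required positional argument:
    # 'enable_kv_scales_calculation'"
    print(exc)

Meta(num_decode_tokens=3, enable_kv_scales_calculation=False)  # updated site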
