@@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
140140 slot_mapping = slot_mapping ,
141141 multi_modal_placeholder_index_maps = self .
142142 multi_modal_placeholder_index_maps ,
143+ enable_kv_scales_calculation = self .enable_kv_scales_calculation ,
143144 seq_lens = self .seq_lens [:self .num_prefills ],
144145 seq_lens_tensor = self .seq_lens_tensor [:self .num_prefills ],
145146 max_decode_query_len = 0 ,
@@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
173174 num_decode_tokens = self .num_decode_tokens ,
174175 slot_mapping = slot_mapping ,
175176 multi_modal_placeholder_index_maps = None ,
177+ enable_kv_scales_calculation = True ,
176178 seq_lens = None ,
177179 seq_lens_tensor = self .seq_lens_tensor [self .num_prefills :],
178180 max_decode_query_len = self .max_decode_query_len ,
@@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
378380 num_prefills = self .num_prefills ,
379381 slot_mapping = slot_mapping ,
380382 multi_modal_placeholder_index_maps = placeholder_index_maps ,
383+ enable_kv_scales_calculation = True ,
381384 num_prefill_tokens = self .num_prefill_tokens ,
382385 num_decode_tokens = num_decode_tokens ,
383386 seq_lens = seq_lens ,
0 commit comments