Commit 1dcd9fe

mawong-amd and gshtras authored
Ingest FP8 attn scales and use them in ROCm FlashAttention (ROCm#338)

* Ingest FP8 attn scales and use them in Triton FA, if present
* Disable calc_kv_scales if the checkpoint has them; enable FP8 attention for dynamic quantization
* Add q_range as an env
* Format
* Dedupe FA/PA attn toggles; set FA off by default
* Lint again, to fixed point
* Don't calculate KV scales dynamically if Q scale is included

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent d08b78b

20 files changed, +157 −81 lines changed
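For orientation before the per-file hunks: every attention backend's forward() swaps the single fp8_out_scale tensor for an optional fp8_comp_scales tuple ordered (q_scale, prob_scale, fp8_out_scale), matching the unpacking added in the ROCm backend below. A hypothetical call-site sketch of the before/after shape; attn_impl and the scale values are placeholders, not code from this commit:

import torch

q_scale = torch.tensor(0.02)    # placeholder Q activation scale
prob_scale = torch.tensor(0.5)  # placeholder softmax-prob scale
out_scale = torch.tensor(1.0)   # placeholder FP8 output scale

# Before: attn_impl.forward(..., fp8_out_scale=out_scale)
# After:  the scales travel together as one optional tuple.
fp8_comp_scales = (q_scale, prob_scale, out_scale)
# attn_impl.forward(..., fp8_comp_scales=fp8_comp_scales)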

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 1 deletion
@@ -252,6 +252,6 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         raise NotImplementedError

vllm/attention/backends/blocksparse_attn.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

vllm/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -642,7 +642,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion
@@ -777,7 +777,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:

         # TODO: directly write to output tensor

vllm/attention/backends/hpu_attn.py

Lines changed: 1 addition & 1 deletion
@@ -154,7 +154,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

vllm/attention/backends/ipex_attn.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.

vllm/attention/backends/pallas.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 10 additions & 2 deletions
@@ -551,7 +551,7 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: torch.Tensor = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -601,6 +601,8 @@ def forward(
             Returns:
                 shape = [num_tokens, num_heads * head_size]
         """
+        q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None,
+                                                                  None)

         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
@@ -681,6 +683,12 @@ def forward(
                         query.dtype,
                         seq_lens,
                         make_attn_mask=False)  # type: ignore
+                full_scales = (
+                    1.0 / q_scale.item(), 1.0 / k_scale.item(),
+                    1.0 / v_scale.item(), 1.0 / prob_scale.item(),
+                    fp8_out_scale.item()) if (
+                        fp8_out_scale
+                        and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
                 out, _ = self.attn_func(
                     query,
                     key,
@@ -694,7 +702,7 @@ def forward(
                     self.scale,
                     attn_masks[0][None]
                     if attn_masks is not None else None,
-                    None,
+                    full_scales,
                 )
             elif self.use_naive_attn:
                 if self.num_kv_heads != self.num_heads:
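
A standalone sketch of the scale handling these ROCm hunks add. The tuple layout, the reciprocal input scales, and the VLLM_USE_ROCM_FP8_FLASH_ATTN gate follow the diff above; the build_full_scales helper name, the dummy tensors, and the plain-boolean reading of the environment variable (vLLM's envs module handles this on its own) are assumptions for illustration:

import os
import torch

# Stand-in for envs.VLLM_USE_ROCM_FP8_FLASH_ATTN referenced in the diff.
USE_ROCM_FP8_FLASH_ATTN = os.environ.get("VLLM_USE_ROCM_FP8_FLASH_ATTN",
                                         "0") == "1"

def build_full_scales(fp8_comp_scales, k_scale, v_scale):
    # Callers without ingested FP8 scales pass None; unpack defensively.
    q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None, None)
    if fp8_out_scale is None or not USE_ROCM_FP8_FLASH_ATTN:
        return None
    # Reciprocal Q/K/V/prob scales plus the raw output scale, as in the hunk above.
    return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
            1.0 / v_scale.item(), 1.0 / prob_scale.item(),
            fp8_out_scale.item())

# Example with dummy 0-d scale tensors:
scales = (torch.tensor(0.02), torch.tensor(0.5), torch.tensor(1.0))
print(build_full_scales(scales, torch.tensor(0.03), torch.tensor(0.04)))

When no FP8 scales are ingested (fp8_comp_scales is None) or the flag is off, the result is None and the Triton attention function is called exactly as before this change.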

vllm/attention/backends/torch_sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -434,7 +434,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.

vllm/attention/backends/xformers.py

Lines changed: 1 addition & 1 deletion
@@ -420,7 +420,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
