
Commit 71cbfe5

Fix attention fp8 output fusion for split attention path in v1 (ROCm#569)
1 parent 8cde510 commit 71cbfe5

File tree

1 file changed: +5 −1 lines changed

vllm/attention/ops/prefix_prefill.py

Lines changed: 5 additions & 1 deletion
@@ -14,6 +14,7 @@
 
 # To check compatibility
 IS_TURING = current_platform.get_device_capability() == (7, 5)
+float8_info = torch.finfo(current_platform.fp8_dtype())
 
 
 # Here's an example autotuner config for this kernel. This config does provide
@@ -82,7 +83,9 @@ def _fwd_kernel(Q,
                 SKIP_DECODE: tl.constexpr,
                 USE_FP8: tl.constexpr,
                 MAX_Q_LEN: tl.constexpr = 0,
-                MAX_CTX_LEN: tl.constexpr = 0):
+                MAX_CTX_LEN: tl.constexpr = 0,
+                FP8_MIN: tl.constexpr = float8_info.min,
+                FP8_MAX: tl.constexpr = float8_info.max):
 
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -277,6 +280,7 @@ def _fwd_kernel(Q,
     out_ptrs = Out + off_o
     if USE_FP8:
         acc = acc / tl.load(out_scale)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
     tl.store(out_ptrs,
              acc,
              mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
