1 file changed, +5 -1 lines changed

 # To check compatibility
 IS_TURING = current_platform.get_device_capability() == (7, 5)
+float8_info = torch.finfo(current_platform.fp8_dtype())
 
 
 # Here's an example autotuner config for this kernel. This config does provide
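As a rough illustration of what the new module-level lookup provides: torch.finfo reports the representable range of the platform's FP8 dtype, which the kernel below uses as the FP8_MIN/FP8_MAX defaults. The snippet is a sketch only and assumes torch.float8_e4m3fn; current_platform.fp8_dtype() may resolve to a different FP8 dtype on other hardware.

    import torch

    # Assumption for illustration: e4m3 is one common FP8 dtype;
    # vLLM resolves the actual dtype via current_platform.fp8_dtype().
    float8_info = torch.finfo(torch.float8_e4m3fn)
    print(float8_info.min, float8_info.max)  # -448.0 448.0 for e4m3fn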
@@ -82,7 +83,9 @@ def _fwd_kernel(Q,
                 SKIP_DECODE: tl.constexpr,
                 USE_FP8: tl.constexpr,
                 MAX_Q_LEN: tl.constexpr = 0,
-                MAX_CTX_LEN: tl.constexpr = 0):
+                MAX_CTX_LEN: tl.constexpr = 0,
+                FP8_MIN: tl.constexpr = float8_info.min,
+                FP8_MAX: tl.constexpr = float8_info.max):
 
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -277,6 +280,7 @@ def _fwd_kernel(Q,
     out_ptrs = Out + off_o
     if USE_FP8:
         acc = acc / tl.load(out_scale)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
     tl.store(out_ptrs,
              acc,
              mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
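The clamp is what keeps the FP8 store safe: after dividing by the output scale, the accumulator can still fall outside the FP8-representable range, and casting such values is not guaranteed to behave well. Below is a minimal out-of-kernel sketch of the same idea, assuming a torch.float8_e4m3fn output dtype and a hypothetical per-tensor out_scale.

    import torch

    # Sketch only; the real path runs inside the Triton kernel via tl.clamp/tl.store.
    fp8_dtype = torch.float8_e4m3fn          # assumption: platform FP8 dtype
    fp8_info = torch.finfo(fp8_dtype)

    acc = torch.randn(4, 8) * 1000.0         # accumulator that may exceed the FP8 range
    out_scale = torch.tensor(2.0)            # hypothetical output scale

    acc = acc / out_scale
    acc = torch.clamp(acc, fp8_info.min, fp8_info.max)  # mirrors tl.clamp(acc, FP8_MIN, FP8_MAX)
    out = acc.to(fp8_dtype)                  # cast is now within range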