diff --git a/inf_cl/flash.py b/inf_cl/flash.py
index 99e3fd8..dfc6fba 100644
--- a/inf_cl/flash.py
+++ b/inf_cl/flash.py
@@ -7,6 +7,28 @@
 import triton
 import triton.language as tl
 
+# -- TF32 capability check: fast TF32 path vs. portable FP32 path ----
+import torch, os
+
+def tf32_allowed() -> bool:
+    """
+    True  → we can and want to use TF32 (fast path on NVIDIA Ampere/Hopper).
+    False → use the portable full-FP32 path, so the kernels emit no TF32
+            inline PTX and still compile on AMD / pre-Ampere GPUs.
+    """
+    # 1) Explicit env-var always wins
+    env = os.getenv("INFCL_FORCE_TF32")
+    if env is not None:
+        return env.lower() in ("1", "true", "yes", "y")
+
+    # 2) On CUDA (not ROCm), require Ampere+ and the usual PyTorch TF32 flag
+    if torch.cuda.is_available() and torch.version.hip is None:
+        major, _ = torch.cuda.get_device_capability()
+        return major >= 8 and torch.backends.cuda.matmul.allow_tf32
+
+    # 3) ROCm / CPU → never TF32
+    return False
+
 
 @triton.jit
 def _prob_fwd_kernel(
@@ -79,6 +101,7 @@ def _dq_prob_bwd_kernel(
     BLOCK_HEADDIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_TF32: tl.constexpr,
 ):
     ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
     # start index of sequence length
@@ -115,23 +138,31 @@ def _dq_prob_bwd_kernel(
         qk_grad = tl.exp(qk - lse[:, None])
         qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0)
         qk_grad = qk_grad * dlse[:, None]
-        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
+        if USE_TF32:
+            # Fast path for Ampere/Hopper: keep NVIDIA's TF32 inline PTX
+            qk_grad = tl.inline_asm_elementwise(
+                ASM, "=r, r", [qk_grad],
+                dtype=tl.float32, is_pure=True, pack=1
+            )
+        else:
+            # Portable path for ROCm / older CUDA: plain FP32, no TF32
+            qk_grad = qk_grad.to(tl.float32)
+
         for off_h in range(nheads):
             offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
             # -- fetch q and k of a single head ----
             q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
             k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
-            # -- compute q grad ----
-            # NOTE: tl.float32 adopt tf32, which causes precision inconsistency with torch
-            # A solution for this problem
-            # Refer to issue: https://github.com/triton-lang/triton/issues/4574
-            # if allow_tf32:
-            k = tl.inline_asm_elementwise(ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1)
-            q_grad = tl.dot(qk_grad, k)
-            # Another solution for this problem
-            # Refer to https://github.com/triton-lang/triton/issues/376
-            # q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False)
-            # -- store dq ----
+
+            if USE_TF32:
+                # Fast path: keep NVIDIA's TF32 inline PTX
+                k_cast = tl.inline_asm_elementwise(
+                    ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1
+                )
+                q_grad = tl.dot(qk_grad, k_cast)
+            else:
+                # Portable path: plain FP32 dot with TF32 disabled
+                q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False)
             dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
             tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q)
 
@@ -149,6 +180,7 @@ def _dk_prob_bwd_kernel(
     BLOCK_HEADDIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_TF32: tl.constexpr,
 ):
     ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
     # start index of sequence length
@@ -184,16 +216,32 @@
             qk += tl.dot(q, tl.trans(k))
         qk_grad = tl.exp(qk - lse[:, None])
         qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0)
         qk_grad = qk_grad * dlse[:, None]
-        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
+
+        if USE_TF32:
+            # Fast path for Ampere/Hopper: keep NVIDIA's TF32 inline PTX
+            qk_grad = tl.inline_asm_elementwise(
+                ASM, "=r, r", [qk_grad],
+                dtype=tl.float32, is_pure=True, pack=1
+            )
+        else:
+            # Portable path for ROCm / older CUDA: plain FP32, no TF32
+            qk_grad = qk_grad.to(tl.float32)
+
         for off_h in range(nheads):
             offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
             # -- fetch q and k of a single head ----
             q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0)
             k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
-            # -- compute k grad ----
-            q = tl.inline_asm_elementwise(ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1)
-            k_grad = tl.dot(tl.trans(qk_grad), q)
-            # k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32))
+
+            if USE_TF32:
+                # Fast path: keep NVIDIA's TF32 inline PTX
+                q_cast = tl.inline_asm_elementwise(
+                    ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1
+                )
+                k_grad = tl.dot(tl.trans(qk_grad), q_cast)
+            else:
+                # Portable path: plain FP32 dot with TF32 disabled
+                k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32), allow_tf32=False)
             # -- store dk ----
 
@@ -276,6 +324,7 @@ def _flash_prob_backward(q, k, lse, dlse):
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        USE_TF32=tf32_allowed(),
     )
 
     BLOCK_N = BLOCK_M
@@ -295,6 +344,7 @@ def _flash_prob_backward(q, k, lse, dlse):
        BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        USE_TF32=tf32_allowed(),
     )
 
     dq = dq[:seqlen_q]
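
For reference, a quick way to exercise both branches is the INFCL_FORCE_TF32 override that tf32_allowed() reads. The snippet below is only a sketch and not part of the patch; it assumes the patch above has been applied and that the inf_cl package is importable.

    # Sketch: toggle the TF32 fast path via the env override read by tf32_allowed()
    import os

    import inf_cl.flash as flash

    os.environ["INFCL_FORCE_TF32"] = "0"   # force the portable FP32 path (USE_TF32=False)
    assert flash.tf32_allowed() is False

    os.environ["INFCL_FORCE_TF32"] = "1"   # force the TF32 fast path (Ampere/Hopper only)
    assert flash.tf32_allowed() is True

    del os.environ["INFCL_FORCE_TF32"]     # fall back to the hardware / torch.backends check
    print("TF32 fast path enabled:", flash.tf32_allowed())

Because USE_TF32 is a tl.constexpr, Triton specializes and caches a separate kernel per value, so on an Ampere/Hopper GPU flipping the variable between calls is also a cheap way to A/B the TF32 and full-FP32 backward kernels for precision.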