76 changes: 60 additions & 16 deletions inf_cl/flash.py
@@ -7,6 +7,28 @@
import triton
import triton.language as tl

# utils_tf32.py: TF32 capability check
import os
import torch

def tf32_allowed() -> bool:
    """
    True  → we *can* and *want* to use TF32 (fast path, NVIDIA Ampere/Hopper).
    False → fall back to full FP32 and disable Triton’s TF32 emitter so the
            kernels still compile on AMD / pre-Ampere GPUs.
    """
    # 1) An explicit env var always wins
    env = os.getenv("INFCL_FORCE_TF32")
    if env is not None:
        return env.lower() in ("1", "true", "yes", "y")

    # 2) ROCm / CPU → never TF32
    if torch.version.hip is not None or not torch.cuda.is_available():
        return False

    # 3) On CUDA, require arch ≥ 80 (Ampere) and honour the usual PyTorch flag
    major, _ = torch.cuda.get_device_capability()
    return major >= 8 and torch.backends.cuda.matmul.allow_tf32
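
# Illustrative usage only (not part of the diff): INFCL_FORCE_TF32 overrides the
# hardware check, which helps when comparing numerics against a plain FP32
# PyTorch reference.
#
#     INFCL_FORCE_TF32=0 python your_script.py   # force the portable FP32 path
#     INFCL_FORCE_TF32=1 python your_script.py   # force the TF32 inline-PTX path
#
# or, from Python, before the first kernel launch:
#     os.environ["INFCL_FORCE_TF32"] = "0"
#     assert tf32_allowed() is False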


@triton.jit
def _prob_fwd_kernel(
@@ -79,6 +101,7 @@ def _dq_prob_bwd_kernel(
BLOCK_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
USE_TF32: tl.constexpr,
):
ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
# start index of sequence length
@@ -115,23 +138,31 @@ def _dq_prob_bwd_kernel(
qk_grad = tl.exp(qk - lse[:, None])
qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0)
qk_grad = qk_grad * dlse[:, None]
if tl.constexpr(USE_TF32):
    # Fast path for Ampere/Hopper: keep NVIDIA’s TF32 inline PTX
    qk_grad = tl.inline_asm_elementwise(
        ASM, "=r, r", [qk_grad],
        dtype=tl.float32, is_pure=True, pack=1
    )
else:
    # Portable path for ROCm / older CUDA: regular cast, no TF32
    qk_grad = qk_grad.to(tl.float32)

for off_h in range(nheads):
offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
# -- fetch q and k of a single head ----
q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
# -- compute q grad ----
# NOTE: with float32 inputs, tl.dot defaults to TF32, which causes a precision
# inconsistency with torch; see the two workarounds used below, discussed in
# https://github.com/triton-lang/triton/issues/4574 and
# https://github.com/triton-lang/triton/issues/376
if tl.constexpr(USE_TF32):
    # Fast path: keep NVIDIA TF32 w/ inline PTX
    k_cast = tl.inline_asm_elementwise(
        ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1
    )
    q_grad = tl.dot(qk_grad, k_cast)
else:
    # Portable path: cast in software, disable TF32 math
    q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False)
# -- store dq ----
dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q)
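
# ---------------------------------------------------------------------------
# Aside (illustrative, not part of this kernel): the TF32/FP32 gap that the
# NOTE above refers to can be reproduced in plain PyTorch on an Ampere+ GPU by
# toggling the matmul flag and comparing against an FP64 reference:
#
#     import torch
#     a = torch.randn(1024, 1024, device="cuda")
#     b = torch.randn(1024, 1024, device="cuda")
#     ref = (a.double() @ b.double()).float()
#     torch.backends.cuda.matmul.allow_tf32 = True
#     err_tf32 = (a @ b - ref).abs().max()
#     torch.backends.cuda.matmul.allow_tf32 = False
#     err_fp32 = (a @ b - ref).abs().max()
#     # err_tf32 is typically orders of magnitude larger than err_fp32
# ---------------------------------------------------------------------------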

@@ -149,6 +180,7 @@ def _dk_prob_bwd_kernel(
BLOCK_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
USE_TF32: tl.constexpr,
):
ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
# start index of sequence length
@@ -184,16 +216,26 @@ def _dk_prob_bwd_kernel(
qk += tl.dot(q, tl.trans(k))

qk_grad = tl.exp(qk - lse[:, None])
qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0)
qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0)
qk_grad = qk_grad * dlse[:, None]
if tl.constexpr(USE_TF32):
    # Fast path for Ampere/Hopper: keep NVIDIA’s TF32 inline PTX
    qk_grad = tl.inline_asm_elementwise(
        ASM, "=r, r", [qk_grad],
        dtype=tl.float32, is_pure=True, pack=1
    )
else:
    # Portable path for ROCm / older CUDA: regular cast, no TF32
    qk_grad = qk_grad.to(tl.float32)

for off_h in range(nheads):
offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
# -- fetch q and k of a single head ----
q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0)
k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
# -- compute k grad ----
if tl.constexpr(USE_TF32):
    # Fast path: keep NVIDIA TF32 w/ inline PTX
    q_cast = tl.inline_asm_elementwise(
        ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1
    )
    k_grad = tl.dot(tl.trans(qk_grad), q_cast)
else:
    # Portable path: cast in software, disable TF32 math
    k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32), allow_tf32=False)
# -- store dk ----
@@ -276,6 +318,7 @@ def _flash_prob_backward(q, k, lse, dlse):
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
USE_TF32=tf32_allowed(),
)

BLOCK_N = BLOCK_M
@@ -295,6 +338,7 @@ def _flash_prob_backward(q, k, lse, dlse):
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
USE_TF32=tf32_allowed(),
)
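# Note: USE_TF32 is a tl.constexpr, so Triton folds the branch at compile time
# and caches a separate specialization per value; flipping INFCL_FORCE_TF32
# between runs simply selects a different cached kernel, with no runtime cost
# in either path.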

dq = dq[:seqlen_q]