289 changes: 184 additions & 105 deletions benchmarks/attention/benchmark_attention.py
@@ -1,6 +1,10 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# Note: A "flash-attn v3" warning may appear from Transformer Engine. The script does not
# install flash-attn; it uses whatever is already installed (e.g. v2). The warning suggests
# installing v3 for Hopper+ for better support; timings are still from the active backend.
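
# --- Illustrative sketch (not part of this diff): a quick check of which flash-attn build,
# if any, would back the FlashAttention path here. Assumes flash_attn exposes __version__,
# as recent releases do; the benchmark never installs or imports it directly.
try:
    import flash_attn
    print("flash-attn version:", flash_attn.__version__)
except ImportError:
    print("flash-attn not installed; only the cuDNN fused-attention backend will run")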

import os, sys, time
import subprocess
@@ -9,10 +13,14 @@
import torch
import nvtx
import transformer_engine
from tests.pytorch.utils import (
ModelConfig,
get_available_attention_backends,
)

# Add project root so "tests" can be imported when run from any directory
_script_dir = os.path.dirname(os.path.abspath(__file__))
_project_root = os.path.dirname(os.path.dirname(_script_dir)) # repo root (parent of benchmarks/)
if _project_root not in sys.path:
sys.path.insert(0, _project_root)

from tests.pytorch.utils import ModelConfig, get_available_attention_backends
from tests.pytorch.attention.test_attention import _run_dot_product_attention

pd.set_option("display.precision", 4)
@@ -32,16 +40,39 @@
# training mode
is_training = True

# Substrings to match kernel names in nsys cuda_gpu_trace CSV (case-insensitive).
# If profiling output changes, update these (e.g. cuDNN may use "cudnn" or "cuda", flash may use "flash" or "fmha").
KERNEL_NAME_CUDNN = "cudnn"
KERNEL_NAME_FLASH = "flash"

model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
# ModelConfig(batch_size, max_seqlen_q, num_heads, head_dim_qk, max_seqlen_kv, num_gqa_groups, ...)
"test_0": ModelConfig(
2, 512, 16, 64, 512, 16, dropout_p=0.0, attn_mask_type="no_mask", attn_bias_type="no_bias"
), # short seq
"test_1": ModelConfig(
2, 2048, 16, 128, 2048, 16, dropout_p=0.0, attn_mask_type="causal", attn_bias_type="no_bias"
), # longer seq, mask
"test_2": ModelConfig(
2,
2048,
16,
128,
2048,
16,
dropout_p=0.0,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # bias; FlashAttention does not support post_scale_bias, so only cuDNN runs
"test_3": ModelConfig(
2, 8192, 32, 128, 8192, 4, dropout_p=0.0, attn_mask_type="causal", attn_bias_type="no_bias"
), # GQA
}


def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
def benchmark_dot_product_attention(
model, fused_attn_supported, flash_attn_supported, append_csv=True
):
config = model_configs[model]
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
@@ -53,7 +84,7 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
warmup_iters = 3
for i in range(warmup_iters):
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -64,7 +95,7 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
is_training,
)
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
@@ -77,14 +108,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_bwd[i] is not None and flash_attn_bwd[i] is not None:
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

torch.cuda.cudart().cudaProfilerStart()
torch.cuda.synchronize()
fused_attn_start = time.time()
if fused_attn_supported:
for i in range(num_iters):
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
_run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -101,7 +133,7 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
flash_attn_start = time.time()
if flash_attn_supported:
for i in range(num_iters):
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
_run_dot_product_attention(
dtype,
config,
"FlashAttention",
@@ -114,61 +146,75 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
torch.cuda.synchronize()
flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

df = pd.read_csv("times.csv")
df = pd.concat(
[
df,
pd.DataFrame(
[
if append_csv:
df = pd.read_csv("times.csv")
df = pd.concat(
[
df,
pd.DataFrame(
[
fused_attn_time * 1e3 / num_iters,
0,
0,
0,
flash_attn_time * 1e3 / num_iters,
0,
0,
0,
0,
]
],
columns=df.columns,
),
],
ignore_index=True,
)
df.to_csv("times.csv", index=False)
[
fused_attn_time * 1e3 / num_iters,
0,
0,
0,
flash_attn_time * 1e3 / num_iters,
0,
0,
0,
0,
]
],
columns=df.columns,
),
],
ignore_index=True,
)
df.to_csv("times.csv", index=False)
torch.cuda.cudart().cudaProfilerStop()


def parse_results(per_cudnn, per_flash, model):
bench_dir = os.path.dirname(os.path.abspath(__file__))
filename = f"prof_{model}_cuda_gpu_trace.csv"
df = pd.read_csv(os.path.join("./", filename))
df_times = pd.read_csv("times.csv")
filepath = os.path.join(bench_dir, filename)
if not os.path.isfile(filepath):
return
df = pd.read_csv(filepath)
df_times = pd.read_csv(os.path.join(bench_dir, "times.csv"))
row = len(df_times.index) - 1

# Match kernel names case-insensitively; column may be "Name" or "Kernel Name" in nsys output
name_col = "Name" if "Name" in df.columns else "Kernel Name"
names = df[name_col].astype(str).str.lower()

if per_cudnn > 0:
t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
t_cudnn_avg = np.average(t_cudnn_all, axis=0)
df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
cudnn_mask = names.str.contains(KERNEL_NAME_CUDNN.lower(), regex=False)
if cudnn_mask.any():
t_cudnn_all = df.loc[cudnn_mask, "Duration (ns)"].to_numpy()
t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
t_cudnn_avg = np.average(t_cudnn_all, axis=0)
df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6

if per_flash > 0:
t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6
flash_mask = names.str.contains(KERNEL_NAME_FLASH.lower(), regex=False)
if flash_mask.any():
t_flash_all = df.loc[flash_mask, "Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6

if per_cudnn > 0 and per_flash > 0:
df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
/ df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
)
df_times.to_csv("times.csv", index=False)
fwd_bwd = df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
if fwd_bwd and fwd_bwd > 0:
df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] / fwd_bwd
)
df_times.to_csv(os.path.join(bench_dir, "times.csv"), index=False)
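
# --- Illustrative sketch (not part of this diff): the reshape/average step above assumes the
# matched kernels appear in the trace in launch order, num_iters blocks of `per_cudnn` (or
# `per_flash`) kernels each. With made-up durations and 4 kernels per call:
import numpy as np  # hypothetical standalone example
durations_ns = np.array([
    100e3, 210e3, 220e3, 190e3,  # iteration 0: fwd kernel, then three bwd kernels
    110e3, 205e3, 215e3, 185e3,  # iteration 1
    105e3, 195e3, 225e3, 195e3,  # iteration 2
])
avg = np.average(durations_ns.reshape(-1, 4), axis=0)  # mean duration per kernel position
fwd_ms = avg[0] / 1e6            # first matched kernel -> forward time in ms
bwd_ms = avg[1:4].sum() / 1e6    # remaining kernels -> backward time in ms
print(f"fwd {fwd_ms:.4f} ms, bwd {bwd_ms:.4f} ms")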


def main():
@@ -201,7 +247,7 @@ def main():
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
# window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
@@ -211,6 +257,17 @@
f'{" and flash-attention" if flash_attn_supported else ""}...'
)

# Run benchmark in main process so times.csv always gets a row (works without nsys)
benchmark_dot_product_attention(
model, fused_attn_supported, flash_attn_supported, append_csv=True
)

# Optional: run under nsys to get kernel-level stats; subprocess must not append again
bench_code = (
"import benchmark_attention; "
"benchmark_attention.benchmark_dot_product_attention("
f"'{model}', {fused_attn_supported}, {flash_attn_supported}, append_csv=False)"
)
prof_cmd = [
"nsys",
"profile",
@@ -220,58 +277,80 @@
f"--output=prof_{model}",
"python",
"-c",
f""" "import benchmark_attention;""",
f"""benchmark_attention.benchmark_dot_product_attention("""
f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
]
prof_cmd = " ".join(prof_cmd)
subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
stats_cmd = [
"nsys",
"stats",
"-q",
"-r",
"cuda_gpu_trace",
"--format",
"csv,column",
"--force-overwrite=true",
"--force-export=true",
f"--output=prof_{model}",
f"prof_{model}.nsys-rep",
]
if fused_attn_supported:
num_kernels_cudnn = 4
if config.attn_bias_type == "post_scale_bias":
num_kernels_cudnn = num_kernels_cudnn + 1
if config.num_heads != config.num_gqa_groups:
num_kernels_cudnn = num_kernels_cudnn + 2
else:
num_kernels_cudnn = 0
num_kernels_flash = 4 if flash_attn_supported else 0
stats_cmd = " ".join(stats_cmd)
subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
parse_cmd = [
"python",
"-c",
f""" "import benchmark_attention;""",
f"""benchmark_attention.parse_results("""
f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
bench_code,
]
parse_cmd = " ".join(parse_cmd)
subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
bench_dir = os.path.dirname(os.path.abspath(__file__))
prof_ret = subprocess.call(
prof_cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
cwd=bench_dir,
)
if prof_ret == 0:
stats_cmd = [
"nsys",
"stats",
"-q",
"-r",
"cuda_gpu_trace",
"--format",
"csv,column",
"--force-overwrite=true",
"--force-export=true",
f"--output=prof_{model}",
f"prof_{model}.nsys-rep",
]
if fused_attn_supported:
num_kernels_cudnn = 4
if config.attn_bias_type == "post_scale_bias":
num_kernels_cudnn = num_kernels_cudnn + 1
if config.num_heads != config.num_gqa_groups:
num_kernels_cudnn = num_kernels_cudnn + 2
else:
num_kernels_cudnn = 0
num_kernels_flash = 4 if flash_attn_supported else 0
subprocess.call(
stats_cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
cwd=bench_dir,
)
parse_code = (
"import benchmark_attention; "
"benchmark_attention.parse_results("
f"{num_kernels_cudnn}, {num_kernels_flash}, '{model}')"
)
parse_cmd = ["python", "-c", parse_code]
subprocess.call(
parse_cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
cwd=bench_dir,
)

df_times = pd.read_csv("times.csv")
n_models = len(model_configs)
if len(df_times) != n_models:
raise RuntimeError(
f"times.csv has {len(df_times)} rows but expected {n_models}. "
"Subprocess benchmarks may have failed (check nsys availability)."
)
df_times.index = list(model_configs.keys())
a = df_times[
[
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
]
a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
# Report the module timings (from time.time(), always populated); kernel-level timings from nsys, when collected, remain in times.csv
cudnn_col = "FusedAttention Module"
flash_col = "FlashAttention Module"
a = df_times[[cudnn_col, flash_col]].copy()
a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)"]
# Speedup: flash/cudnn ratio (>1 means cuDNN faster). N/A when only one backend ran (e.g. test_2 has bias, flash not used).
cudnn_ms = df_times[cudnn_col]
flash_ms = df_times[flash_col]
speedup = np.where((cudnn_ms > 0) & (flash_ms > 0), flash_ms / cudnn_ms, np.nan)
a["cuDNN vs flash speedup"] = speedup
# Show "N/A" instead of NaN when speedup not defined (only one backend ran)
a_display = a.copy()
a_display["cuDNN vs flash speedup"] = [f"{x:.4f}" if not pd.isna(x) else "N/A" for x in speedup]
print()
print(a)
print(a_display)
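
# --- Illustrative sketch (not part of this diff): what the speedup/"N/A" logic above produces,
# with made-up module timings. A ratio above 1 means the cuDNN fused path was faster; rows
# where only one backend ran (flash time 0) show "N/A".
import numpy as np  # hypothetical standalone example
import pandas as pd
cudnn_ms = pd.Series([0.31, 1.20, 1.45, 9.80], index=["test_0", "test_1", "test_2", "test_3"])
flash_ms = pd.Series([0.35, 1.35, 0.00, 10.40], index=cudnn_ms.index)  # test_2: flash skipped
speedup = np.where((cudnn_ms > 0) & (flash_ms > 0), flash_ms / cudnn_ms, np.nan)
print([f"{x:.4f}" if not pd.isna(x) else "N/A" for x in speedup])
# -> ['1.1290', '1.1250', 'N/A', '1.0612']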


if __name__ == "__main__":