Commit 8fbb9a7

Add lax.scan profiler to CI for GPU debugging
- Add scripts/profile_lax_scan.py: Profiles lax.scan performance on GPU vs CPU to investigate the synchronization overhead issue (JAX Issue #2491)
- Add CI step to run profiler with 100K iterations on RunsOn GPU environment
- Script supports multiple profiling modes: basic timing, Nsight, JAX profiler, XLA dumps
1 parent 6bf345a commit 8fbb9a7
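
For context, the heart of the profiler's basic-timing mode is the pattern below. This is a condensed sketch of what scripts/profile_lax_scan.py (added in this commit) measures, not the file itself; the logistic-map update, the 100K iteration count, and the use of block_until_ready() mirror the script, while the standalone layout is for illustration only.

# Condensed sketch of the basic timing comparison (see scripts/profile_lax_scan.py below).
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=(1,))
def qm(x0, n, α=4.0):
    """Logistic-map trajectory via lax.scan: one very lightweight step per iteration."""
    def update(x, t):
        x_new = α * x * (1 - x)
        return x_new, x_new
    _, x = lax.scan(update, x0, jnp.arange(n))
    return jnp.concatenate([jnp.array([x0]), x])

n = 100_000
_ = qm(0.1, n).block_until_ready()   # warm-up run: trigger compilation first
t0 = time.perf_counter()
_ = qm(0.1, n).block_until_ready()   # block so async dispatch doesn't hide device time
print(f"default device {jax.devices()[0]}: {time.perf_counter() - t0:.6f}s")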

File tree: 2 files changed, +184 −0 lines changed


.github/workflows/ci.yml

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,16 @@ jobs:
           pip install -U "jax[cuda13]"
           pip install numpyro
           python scripts/test-jax-install.py
+      # === lax.scan GPU Performance Profiling ===
+      - name: Profile lax.scan (GPU vs CPU)
+        shell: bash -l {0}
+        run: |
+          echo "=== lax.scan Performance Profiling ==="
+          echo "This profiles the known issue with lax.scan on GPU (JAX Issue #2491)"
+          echo ""
+          python scripts/profile_lax_scan.py --iterations 100000
+          echo ""
+          echo "Note: GPU is expected to be much slower due to CPU-GPU sync per iteration"
       # === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) ===
       - name: Run Hardware Benchmarks (Bare Metal)
         shell: bash -l {0}

scripts/profile_lax_scan.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
"""
Profile lax.scan performance on GPU vs CPU to investigate synchronization overhead.

This script helps diagnose why lax.scan with many lightweight iterations
performs poorly on GPU (81s) compared to CPU (0.06s).

Usage:
    # Basic timing comparison
    python profile_lax_scan.py

    # With NVIDIA Nsight Systems (requires nsys installed)
    nsys profile -o lax_scan_profile --trace=cuda,nvtx python profile_lax_scan.py --nsys

    # With JAX profiler (view with TensorBoard)
    python profile_lax_scan.py --jax-profile

    # With XLA debug dumps
    python profile_lax_scan.py --xla-dump

Requirements:
    - JAX with CUDA support: pip install jax[cuda12]
    - For Nsight: NVIDIA Nsight Systems (https://developer.nvidia.com/nsight-systems)
    - For TensorBoard: pip install tensorboard tensorboard-plugin-profile
"""

import argparse
import os
import time
from functools import partial

def setup_xla_dump(dump_dir="/tmp/xla_dump"):
    """Enable XLA debug dumps before importing JAX."""
    os.makedirs(dump_dir, exist_ok=True)
    os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text"
    print(f"XLA dumps will be written to: {dump_dir}")

def main():
    parser = argparse.ArgumentParser(description="Profile lax.scan GPU performance")
    parser.add_argument("--nsys", action="store_true",
                        help="Run in Nsight Systems compatible mode (smaller n)")
    parser.add_argument("--jax-profile", action="store_true",
                        help="Enable JAX profiler (view with TensorBoard)")
    parser.add_argument("--xla-dump", action="store_true",
                        help="Dump XLA HLO for analysis")
    parser.add_argument("-n", "--iterations", type=int, default=10_000_000,
                        dest="n", help="Number of iterations (default: 10M)")
    parser.add_argument("--profile-dir", type=str, default="/tmp/jax-trace",
                        help="Directory for JAX profile output")
    args = parser.parse_args()

    # Setup XLA dump before importing JAX
    if args.xla_dump:
        setup_xla_dump()

    # Now import JAX
    import jax
    import jax.numpy as jnp
    from jax import lax

    print("=" * 60)
    print("lax.scan GPU Performance Profiling")
    print("=" * 60)

    # Show device info
    print(f"\nJAX version: {jax.__version__}")
    print(f"Available devices: {jax.devices()}")
    print(f"Default device: {jax.devices()[0]}")

    # Reduce n for Nsight profiling to keep trace manageable
    n = 100_000 if args.nsys else args.n
    print(f"\nIterations (n): {n:,}")

    # Define the functions
    @partial(jax.jit, static_argnums=(1,))
    def qm_jax_default(x0, n, α=4.0):
        """lax.scan on default device (GPU if available)."""
        def update(x, t):
            x_new = α * x * (1 - x)
            return x_new, x_new
        _, x = lax.scan(update, x0, jnp.arange(n))
        return jnp.concatenate([jnp.array([x0]), x])

    cpu = jax.devices("cpu")[0]

    @partial(jax.jit, static_argnums=(1,), device=cpu)
    def qm_jax_cpu(x0, n, α=4.0):
        """lax.scan forced to CPU."""
        def update(x, t):
            x_new = α * x * (1 - x)
            return x_new, x_new
        _, x = lax.scan(update, x0, jnp.arange(n))
        return jnp.concatenate([jnp.array([x0]), x])

    # Warm up (compilation)
    print("\n--- Compilation (warm-up) ---")
    print("Compiling default device version...", end=" ", flush=True)
    t0 = time.perf_counter()
    _ = qm_jax_default(0.1, n).block_until_ready()
    print(f"done ({time.perf_counter() - t0:.2f}s)")

    print("Compiling CPU version...", end=" ", flush=True)
    t0 = time.perf_counter()
    _ = qm_jax_cpu(0.1, n).block_until_ready()
    print(f"done ({time.perf_counter() - t0:.2f}s)")

    # Profile with JAX profiler if requested
    if args.jax_profile:
        print(f"\n--- JAX Profiler (output: {args.profile_dir}) ---")
        os.makedirs(args.profile_dir, exist_ok=True)

        jax.profiler.start_trace(args.profile_dir)

        # Run both versions while profiling
        print("Profiling default device version...")
        result_default = qm_jax_default(0.1, n).block_until_ready()

        print("Profiling CPU version...")
        result_cpu = qm_jax_cpu(0.1, n).block_until_ready()

        jax.profiler.stop_trace()
        print("\nProfile saved. View with:")
        print(f"  tensorboard --logdir={args.profile_dir}")

    # Timing runs
    print("\n--- Timing Runs (post-compilation) ---")

    # Default device (GPU if available)
    print(f"\nDefault device ({jax.devices()[0]}):")
    times_default = []
    for i in range(3):
        t0 = time.perf_counter()
        result = qm_jax_default(0.1, n).block_until_ready()
        elapsed = time.perf_counter() - t0
        times_default.append(elapsed)
        print(f"  Run {i+1}: {elapsed:.6f}s")

    # CPU
    print("\nCPU (forced with device=cpu):")
    times_cpu = []
    for i in range(3):
        t0 = time.perf_counter()
        result = qm_jax_cpu(0.1, n).block_until_ready()
        elapsed = time.perf_counter() - t0
        times_cpu.append(elapsed)
        print(f"  Run {i+1}: {elapsed:.6f}s")

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    avg_default = sum(times_default) / len(times_default)
    avg_cpu = sum(times_cpu) / len(times_cpu)
    print(f"Default device avg: {avg_default:.6f}s")
    print(f"CPU avg: {avg_cpu:.6f}s")
    print(f"Ratio (default/cpu): {avg_default/avg_cpu:.1f}x")

    if avg_default > avg_cpu * 10:
        print("\n⚠️ GPU is significantly slower than CPU!")
        print("  This confirms the lax.scan synchronization overhead issue.")
    elif avg_default < avg_cpu:
        print("\n✓ GPU is faster (unexpected for this workload)")
    else:
        print("\n~ Performance is similar")

    if args.xla_dump:
        print("\nXLA dumps written to /tmp/xla_dump/")
        print("Look for .txt files with HLO representation")

    if args.nsys:
        print("\nNsight Systems trace will be saved as lax_scan_profile.nsys-rep")
        print("View with: nsys-ui lax_scan_profile.nsys-rep")

if __name__ == "__main__":
    main()
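
One portability note on the CPU variant in this script: it pins execution through jax.jit's device argument. If that argument is deprecated or unavailable in the installed JAX version (an assumption about newer releases, not something this commit addresses), the same placement can be sketched with the jax.default_device context manager, which routes computations with uncommitted inputs, including jitted calls, to the given device:

# Hypothetical alternative sketch: force the scan onto the CPU without jit(device=...).
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=(1,))
def qm(x0, n, α=4.0):
    def update(x, t):
        x_new = α * x * (1 - x)
        return x_new, x_new
    _, x = lax.scan(update, x0, jnp.arange(n))
    return jnp.concatenate([jnp.array([x0]), x])

cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    # x0 is an uncommitted Python scalar, so compilation and execution happen on the CPU.
    result_cpu = qm(0.1, 100_000).block_until_ready()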
