Commit 1bfbaf9

Add diagnostic mode to lax.scan profiler
- Add --diagnose flag that tests time scaling across iteration counts
- If time scales linearly with iterations (not compute), this indicates constant per-iteration overhead (CPU-GPU synchronization)
- Also add --verbose flag for CUDA/XLA logging
- Update CI to run with --diagnose flag
1 parent 8fbb9a7 · commit 1bfbaf9
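The heuristic behind the new mode, as a standalone sketch (the function name and timings here are illustrative only; the real check lives in scripts/profile_lax_scan.py and uses the same 0.5x-2.0x tolerance band):

    # Hypothetical helper mirroring the --diagnose logic: if runtime grows
    # roughly in proportion to iteration count, the cost is dominated by a
    # constant per-iteration overhead such as a CPU-GPU sync.
    def scales_linearly(t_small, t_large, n_small, n_large):
        observed_ratio = t_large / t_small
        expected_ratio = n_large / n_small  # e.g. 100x for 1k -> 100k iterations
        return 0.5 * expected_ratio < observed_ratio < 2.0 * expected_ratio

    # Illustrative numbers: 1k iterations in 0.05 s, 100k in 5.0 s -> 100x ratio.
    print(scales_linearly(0.05, 5.0, 1_000, 100_000))  # True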

File tree: 2 files changed, +70 -2 lines

- .github/workflows/ci.yml
- scripts/profile_lax_scan.py

.github/workflows/ci.yml

Lines changed: 3 additions & 2 deletions
@@ -32,9 +32,10 @@ jobs:
           echo "=== lax.scan Performance Profiling ==="
           echo "This profiles the known issue with lax.scan on GPU (JAX Issue #2491)"
           echo ""
-          python scripts/profile_lax_scan.py --iterations 100000
+          python scripts/profile_lax_scan.py --iterations 100000 --diagnose
           echo ""
-          echo "Note: GPU is expected to be much slower due to CPU-GPU sync per iteration"
+          echo "The diagnostic shows if time scales linearly with iterations,"
+          echo "which indicates constant per-iteration CPU-GPU sync overhead."
       # === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) ===
       - name: Run Hardware Benchmarks (Bare Metal)
         shell: bash -l {0}
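
To reproduce this CI step locally, the same invocation should work from the repository root (a usage note; adding --verbose also turns on the CUDA/XLA logging introduced below):

    python scripts/profile_lax_scan.py --iterations 100000 --diagnose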

scripts/profile_lax_scan.py

Lines changed: 67 additions & 0 deletions
@@ -34,6 +34,13 @@ def setup_xla_dump(dump_dir="/tmp/xla_dump"):
     os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text"
     print(f"XLA dumps will be written to: {dump_dir}")
 
+def setup_cuda_logging():
+    """Enable CUDA/XLA logging to see sync patterns."""
+    # These may help reveal synchronization behavior
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # Show all TF/XLA logs
+    os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cuda_data_dir=/usr/local/cuda"
+    print("CUDA/XLA logging enabled")
+
 def main():
     parser = argparse.ArgumentParser(description="Profile lax.scan GPU performance")
     parser.add_argument("--nsys", action="store_true",
@@ -42,6 +49,10 @@ def main():
                         help="Enable JAX profiler (view with TensorBoard)")
     parser.add_argument("--xla-dump", action="store_true",
                         help="Dump XLA HLO for analysis")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Enable verbose CUDA/XLA logging")
+    parser.add_argument("--diagnose", action="store_true",
+                        help="Run diagnostic to demonstrate sync overhead")
     parser.add_argument("-n", "--iterations", type=int, default=10_000_000,
                         dest="n", help="Number of iterations (default: 10M)")
     parser.add_argument("--profile-dir", type=str, default="/tmp/jax-trace",
@@ -51,6 +62,9 @@ def main():
     # Setup XLA dump before importing JAX
     if args.xla_dump:
         setup_xla_dump()
+
+    if args.verbose:
+        setup_cuda_logging()
 
     # Now import JAX
     import jax
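
A side note on ordering: setup_xla_dump() and setup_cuda_logging() only mutate os.environ, which is why the script calls them before JAX is imported; the script's own comments treat import time as the deadline for setting these flags. A minimal sketch of the same pattern:

    import os

    # Environment flags must be in place before JAX starts up; setting
    # them after the import below is generally too late to take effect.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

    import jax  # reads the environment during initialization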
@@ -170,5 +184,58 @@ def update(x, t):
     print("\nNsight Systems trace will be saved as lax_scan_profile.nsys-rep")
     print("View with: nsys-ui lax_scan_profile.nsys-rep")
 
+    # Diagnostic: demonstrate sync overhead by showing time scaling
+    if args.diagnose:
+        print("\n" + "=" * 60)
+        print("DIAGNOSTIC: Per-iteration Sync Overhead Analysis")
+        print("=" * 60)
+        print("\nIf there's a CPU-GPU sync per iteration, time should scale")
+        print("linearly with iteration count (not with compute work).\n")
+
+        # Test different iteration counts
+        test_ns = [1000, 5000, 10000, 50000, 100000]
+
+        print("Iteration Count | GPU Time (s) | Time/Iter (µs) | Expected if O(n)")
+        print("-" * 70)
+
+        gpu_times = []
+        for test_n in test_ns:
+            # Define fresh function for this n
+            @partial(jax.jit, static_argnums=(1,))
+            def qm_test(x0, n, α=4.0):
+                def update(x, t):
+                    return α * x * (1 - x), α * x * (1 - x)
+                _, x = lax.scan(update, x0, jnp.arange(n))
+                return jnp.concatenate([jnp.array([x0]), x])
+
+            # Compile
+            _ = qm_test(0.1, test_n).block_until_ready()
+
+            # Time
+            t0 = time.perf_counter()
+            _ = qm_test(0.1, test_n).block_until_ready()
+            elapsed = time.perf_counter() - t0
+            gpu_times.append(elapsed)
+
+            time_per_iter = (elapsed / test_n) * 1_000_000  # microseconds
+            expected = gpu_times[0] * (test_n / test_ns[0])  # O(n) extrapolation from the first measurement
+
+            print(f"{test_n:>15,} | {elapsed:>12.6f} | {time_per_iter:>14.2f} | {expected:.6f}")
+
+        # Calculate if time scales linearly (indicating per-iteration overhead)
+        ratio_1k_to_100k = gpu_times[-1] / gpu_times[0]
+        expected_ratio = test_ns[-1] / test_ns[0]  # 100x if linear
+
+        print("\nScaling analysis:")
+        print(f"  Time ratio (100k/1k iterations): {ratio_1k_to_100k:.1f}x")
+        print(f"  Expected if linear O(n): {expected_ratio:.1f}x")
+
+        if 0.5 * expected_ratio < ratio_1k_to_100k < 2.0 * expected_ratio:
+            print("\n✓ Time scales ~linearly with iterations!")
+            print("  This indicates constant per-iteration overhead (CPU-GPU sync).")
+            print(f"  Estimated sync overhead: ~{(gpu_times[0]/test_ns[0])*1e6:.1f} µs per iteration")
+        else:
+            print("\n? Scaling is not linear - may be other factors involved")
+
 if __name__ == "__main__":
     main()
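
The compile-then-time pattern inside the loop above, pulled out as a self-contained sketch (toy jitted function; only the measurement pattern matches the script). Because qm_test marks n as static via static_argnums, each iteration count triggers a fresh compile, which is why the untimed warm-up call matters:

    import time
    import jax
    import jax.numpy as jnp

    f = jax.jit(lambda x: (4.0 * x * (1 - x)).sum())
    x = jnp.linspace(0.1, 0.9, 1000)

    _ = f(x).block_until_ready()  # first call compiles; keep it out of the timing
    t0 = time.perf_counter()
    _ = f(x).block_until_ready()  # block: JAX dispatches work asynchronously,
                                  # so the timer must wait for the actual result
    print(f"elapsed: {time.perf_counter() - t0:.6f} s")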
