|
| 1 | +""" |
| 2 | +Profile lax.scan performance on GPU vs CPU to investigate synchronization overhead. |
| 3 | +
|
| 4 | +This script helps diagnose why lax.scan with many lightweight iterations |
| 5 | +performs poorly on GPU (81s) compared to CPU (0.06s). |
| 6 | +
|
| 7 | +Usage: |
| 8 | + # Basic timing comparison |
| 9 | + python profile_lax_scan.py |
| 10 | +
|
| 11 | + # With NVIDIA Nsight Systems (requires nsys installed) |
| 12 | + nsys profile -o lax_scan_profile --trace=cuda,nvtx python profile_lax_scan.py --nsys |
| 13 | +
|
| 14 | + # With JAX profiler (view with TensorBoard) |
| 15 | + python profile_lax_scan.py --jax-profile |
| 16 | +
|
| 17 | + # With XLA debug dumps |
| 18 | + python profile_lax_scan.py --xla-dump |
| 19 | +
|
| 20 | +Requirements: |
| 21 | + - JAX with CUDA support: pip install jax[cuda12] |
| 22 | + - For Nsight: NVIDIA Nsight Systems (https://developer.nvidia.com/nsight-systems) |
| 23 | + - For TensorBoard: pip install tensorboard tensorboard-plugin-profile |
| 24 | +""" |
| 25 | + |
| 26 | +import argparse |
| 27 | +import os |
| 28 | +import time |
| 29 | +from functools import partial |
| 30 | + |
def setup_xla_dump(dump_dir="/tmp/xla_dump"):
    """Enable XLA HLO debug dumps by extending the XLA_FLAGS env var.

    Must be called *before* JAX is imported: XLA reads XLA_FLAGS once,
    at backend initialization.

    Args:
        dump_dir: Directory that will receive the HLO text dumps
            (created if it does not already exist).

    Returns:
        The dump directory path, for convenience.
    """
    os.makedirs(dump_dir, exist_ok=True)
    dump_flags = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text"
    # Append to any flags the user already exported (e.g.
    # --xla_force_host_platform_device_count) instead of clobbering them.
    existing = os.environ.get("XLA_FLAGS", "")
    os.environ["XLA_FLAGS"] = f"{existing} {dump_flags}".strip()
    print(f"XLA dumps will be written to: {dump_dir}")
    return dump_dir
| 36 | + |
def _time_runs(fn, x0, n, repeats=3):
    """Time `repeats` post-compilation executions of `fn(x0, n)`.

    Calls block_until_ready() so asynchronous device work is included in
    the measured wall time. Prints each run and returns the list of
    per-run durations in seconds.
    """
    times = []
    for i in range(repeats):
        t0 = time.perf_counter()
        fn(x0, n).block_until_ready()
        elapsed = time.perf_counter() - t0
        times.append(elapsed)
        print(f"  Run {i+1}: {elapsed:.6f}s")
    return times


def main():
    """Parse CLI options, build the lax.scan benchmark, and report timings.

    JAX is imported inside this function (not at module top) so that
    XLA_FLAGS can be configured first when --xla-dump is requested.
    """
    parser = argparse.ArgumentParser(description="Profile lax.scan GPU performance")
    parser.add_argument("--nsys", action="store_true",
                        help="Run in Nsight Systems compatible mode (smaller n)")
    parser.add_argument("--jax-profile", action="store_true",
                        help="Enable JAX profiler (view with TensorBoard)")
    parser.add_argument("--xla-dump", action="store_true",
                        help="Dump XLA HLO for analysis")
    parser.add_argument("-n", "--iterations", type=int, default=10_000_000,
                        dest="n", help="Number of iterations (default: 10M)")
    parser.add_argument("--profile-dir", type=str, default="/tmp/jax-trace",
                        help="Directory for JAX profile output")
    args = parser.parse_args()

    # XLA reads XLA_FLAGS at backend initialization, so the dump must be
    # configured before the first JAX import below.
    xla_dump_dir = "/tmp/xla_dump"
    if args.xla_dump:
        setup_xla_dump(xla_dump_dir)

    # Deferred import: see the XLA_FLAGS note above.
    import jax
    import jax.numpy as jnp
    from jax import lax

    print("=" * 60)
    print("lax.scan GPU Performance Profiling")
    print("=" * 60)

    # Show device info
    print(f"\nJAX version: {jax.__version__}")
    print(f"Available devices: {jax.devices()}")
    print(f"Default device: {jax.devices()[0]}")

    # Reduce n for Nsight profiling to keep the trace manageable
    n = 100_000 if args.nsys else args.n
    print(f"\nIterations (n): {n:,}")

    @partial(jax.jit, static_argnums=(1,))
    def qm_jax(x0, n, α=4.0):
        """Logistic-map trajectory via lax.scan.

        Runs on whichever device the carry `x0` is committed to (or the
        default device if `x0` is an uncommitted Python scalar).
        """
        def update(x, t):
            x_new = α * x * (1 - x)
            return x_new, x_new
        _, x = lax.scan(update, x0, jnp.arange(n))
        return jnp.concatenate([jnp.array([x0]), x])

    cpu = jax.devices("cpu")[0]

    # NOTE: jax.jit's `device=` argument is deprecated and removed in
    # recent JAX releases. Committing the initial carry to a device via
    # jax.device_put makes jit compile and run the whole scan there
    # instead (one executable is cached per input placement).
    x0_default = 0.1                   # uncommitted -> default device
    x0_cpu = jax.device_put(0.1, cpu)  # committed to host CPU

    # Warm up (compilation)
    print("\n--- Compilation (warm-up) ---")
    print("Compiling default device version...", end=" ", flush=True)
    t0 = time.perf_counter()
    _ = qm_jax(x0_default, n).block_until_ready()
    print(f"done ({time.perf_counter() - t0:.2f}s)")

    print("Compiling CPU version...", end=" ", flush=True)
    t0 = time.perf_counter()
    _ = qm_jax(x0_cpu, n).block_until_ready()
    print(f"done ({time.perf_counter() - t0:.2f}s)")

    # Profile with the JAX profiler if requested
    if args.jax_profile:
        print(f"\n--- JAX Profiler (output: {args.profile_dir}) ---")
        os.makedirs(args.profile_dir, exist_ok=True)

        jax.profiler.start_trace(args.profile_dir)

        # Run both placements while profiling
        print("Profiling default device version...")
        qm_jax(x0_default, n).block_until_ready()

        print("Profiling CPU version...")
        qm_jax(x0_cpu, n).block_until_ready()

        jax.profiler.stop_trace()
        print("\nProfile saved. View with:")
        print(f"  tensorboard --logdir={args.profile_dir}")

    # Timing runs
    print("\n--- Timing Runs (post-compilation) ---")

    print(f"\nDefault device ({jax.devices()[0]}):")
    times_default = _time_runs(qm_jax, x0_default, n)

    print("\nCPU (forced via jax.device_put):")
    times_cpu = _time_runs(qm_jax, x0_cpu, n)

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    avg_default = sum(times_default) / len(times_default)
    avg_cpu = sum(times_cpu) / len(times_cpu)
    print(f"Default device avg: {avg_default:.6f}s")
    print(f"CPU avg: {avg_cpu:.6f}s")
    print(f"Ratio (default/cpu): {avg_default/avg_cpu:.1f}x")

    if avg_default > avg_cpu * 10:
        print("\n⚠️  GPU is significantly slower than CPU!")
        print("   This confirms the lax.scan synchronization overhead issue.")
    elif avg_default < avg_cpu:
        print("\n✓ GPU is faster (unexpected for this workload)")
    else:
        print("\n~ Performance is similar")

    if args.xla_dump:
        print(f"\nXLA dumps written to {xla_dump_dir}/")
        print("Look for .txt files with HLO representation")

    if args.nsys:
        print("\nNsight Systems trace will be saved as lax_scan_profile.nsys-rep")
        print("View with: nsys-ui lax_scan_profile.nsys-rep")


if __name__ == "__main__":
    main()
0 commit comments