@@ -34,6 +34,13 @@ def setup_xla_dump(dump_dir="/tmp/xla_dump"):
     os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text"
     print(f"XLA dumps will be written to: {dump_dir}")
 
+def setup_cuda_logging():
+    """Enable CUDA/XLA logging to see sync patterns."""
+    # These may help reveal synchronization behavior
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # Show all TF/XLA logs
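+    # Appending (rather than overwriting) preserves any flags set by setup_xla_dump;
+    # --xla_gpu_cuda_data_dir tells XLA where to find the CUDA toolkit (e.g. libdevice)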
+    os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cuda_data_dir=/usr/local/cuda"
+    print("CUDA/XLA logging enabled")
+
 def main():
     parser = argparse.ArgumentParser(description="Profile lax.scan GPU performance")
     parser.add_argument("--nsys", action="store_true",
@@ -42,6 +49,10 @@ def main():
                         help="Enable JAX profiler (view with TensorBoard)")
     parser.add_argument("--xla-dump", action="store_true",
                         help="Dump XLA HLO for analysis")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Enable verbose CUDA/XLA logging")
+    parser.add_argument("--diagnose", action="store_true",
+                        help="Run diagnostic to demonstrate sync overhead")
     parser.add_argument("-n", "--iterations", type=int, default=10_000_000,
                         dest="n", help="Number of iterations (default: 10M)")
     parser.add_argument("--profile-dir", type=str, default="/tmp/jax-trace",
@@ -51,6 +62,9 @@ def main():
     # Setup XLA dump before importing JAX
     if args.xla_dump:
         setup_xla_dump()
+
+    if args.verbose:
+        setup_cuda_logging()
 
     # Now import JAX
     import jax
@@ -170,5 +184,58 @@ def update(x, t):
         print("\nNsight Systems trace will be saved as lax_scan_profile.nsys-rep")
         print("View with: nsys-ui lax_scan_profile.nsys-rep")
 
+    # Diagnostic: demonstrate sync overhead by showing time scaling
+    if args.diagnose:
+        print("\n" + "=" * 60)
+        print("DIAGNOSTIC: Per-iteration Sync Overhead Analysis")
+        print("=" * 60)
+        print("\nIf there's a CPU-GPU sync per iteration, time should scale")
+        print("linearly with iteration count (not with compute work).\n")
+
+        # Test different iteration counts
+        test_ns = [1000, 5000, 10000, 50000, 100000]
+
+        print("Iteration Count | GPU Time (s) | Time/Iter (µs) | Expected if O(n)")
+        print("-" * 70)
+
+        gpu_times = []
+        for test_n in test_ns:
+            # Define a fresh function for this n
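+            # static_argnums=(1,) makes n a compile-time constant so that
+            # jnp.arange(n) has a concrete shape; each new n triggers a recompilation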
+            @partial(jax.jit, static_argnums=(1,))
+            def qm_test(x0, n, α=4.0):
+                def update(x, t):
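+                    # quadratic (logistic) map step; scan expects (new_carry, output)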
+                    return α * x * (1 - x), α * x * (1 - x)
+                _, x = lax.scan(update, x0, jnp.arange(n))
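+                # prepend the initial condition so the trajectory includes x0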
+                return jnp.concatenate([jnp.array([x0]), x])
+
+            # Warm-up: the first call compiles; block_until_ready waits for the device
+            _ = qm_test(0.1, test_n).block_until_ready()
+
+            # Time a single compiled run end-to-end
+            t0 = time.perf_counter()
+            _ = qm_test(0.1, test_n).block_until_ready()
+            elapsed = time.perf_counter() - t0
+            gpu_times.append(elapsed)
+
+            time_per_iter = (elapsed / test_n) * 1_000_000  # microseconds
+            expected = gpu_times[0] * (test_n / test_ns[0])  # linear extrapolation from smallest n
+
+            print(f"{test_n:>15,} | {elapsed:>12.6f} | {time_per_iter:>14.2f} | {expected:.6f}")
224+
225+ # Calculate if time scales linearly (indicating per-iteration overhead)
226+ ratio_1k_to_100k = gpu_times [- 1 ] / gpu_times [0 ]
227+ expected_ratio = test_ns [- 1 ] / test_ns [0 ] # 100x if linear
228+
229+ print (f"\n Scaling analysis:" )
230+ print (f" Time ratio (100k/1k iterations): { ratio_1k_to_100k :.1f} x" )
231+ print (f" Expected if linear O(n): { expected_ratio :.1f} x" )
232+
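+        # Accept anything within a factor of two of perfect linear scaling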
+        if 0.5 * expected_ratio < ratio_1k_to_100k < 2.0 * expected_ratio:
+            print("\n✓ Time scales ~linearly with iterations!")
+            print("  This indicates constant per-iteration overhead (CPU-GPU sync).")
+            print(f"  Estimated sync overhead: ~{(gpu_times[0] / test_ns[0]) * 1e6:.1f} µs per iteration")
+        else:
+            print("\n? Scaling is not linear; other factors may be involved")
+
 if __name__ == "__main__":
     main()