diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index 99564490..b802ddc0 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -213,6 +213,8 @@ def data_generator(): return (lhs_device, rhs_device) # Run the benchmark + num_runs = 1 + ## Need to fix gemm timing logic to handle num_runs > 1 time_ms_list = iteration_timeit( jit_sharded_f, @@ -300,6 +302,9 @@ def data_generator(): return (lhs_device, rhs_device) + num_runs = 1 + ## Need to fix gemm timing logic to handle num_runs > 1 + # Run the benchmark time_ms_list = iteration_timeit( jit_sharded_f, @@ -402,6 +407,9 @@ def data_generator(): return (lhs_device, rhs_device, sf0_device, sf1_device) + num_runs = 1 + ## Need to fix gemm timing logic to handle num_runs > 1 + time_ms_list = iteration_timeit( jit_sharded_f, data_generator, @@ -513,6 +521,10 @@ def data_generator(): return (out_buffer_device, lhs_device, rhs_device, sf0_device, sf1_device) + + num_runs = 1 + ## Need to fix gemm timing logic to handle num_runs > 1 + time_ms_list = iteration_timeit( jit_sharded_f, data_generator,