Skip to content

Commit 922b24c

Browse files
committed
Fix syntax errors in benchmark-hardware.py
- Remove extra triple quote at start of file - Remove stray parentheses in result assignments
1 parent 2bda114 commit 922b24c

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

scripts/benchmark-hardware.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
"""
21
"""Hardware benchmark script for CI runners.
32
Compares CPU and GPU performance to diagnose slowdowns.
43
Works on both CPU-only (GitHub Actions) and GPU (RunsOn) runners.
@@ -204,7 +203,7 @@ def matmul(a, b):
204203
C = matmul(A, B).block_until_ready()
205204
elapsed = time.perf_counter() - start
206205
print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds")
207-
results["matmul_3000x3000_compiled"] = elapsed)
206+
results["matmul_3000x3000_compiled"] = elapsed
208207

209208
# Element-wise GPU benchmark
210209
x = jax.random.normal(key, (50_000_000,))
@@ -235,7 +234,7 @@ def elementwise_ops(x):
235234
RESULTS["benchmarks"]["jax"] = {"error": str(e)}
236235
except Exception as e:
237236
print(f"JAX benchmark failed: {e}")
238-
RESULTS["benchmarks"]["jax"] = {"error": str(e)})
237+
RESULTS["benchmarks"]["jax"] = {"error": str(e)}
239238

240239
def benchmark_numba():
241240
"""Numba CPU benchmark."""
@@ -268,7 +267,7 @@ def numba_sum(n):
268267
result = numba_sum(10_000_000)
269268
elapsed = time.perf_counter() - start
270269
print(f"Integer sum compiled (10M): {elapsed:.3f} seconds")
271-
results["integer_sum_10M_compiled"] = elapsed)
270+
results["integer_sum_10M_compiled"] = elapsed
272271

273272
@numba.jit(nopython=True, parallel=True)
274273
def numba_parallel_sum(arr):
@@ -301,7 +300,7 @@ def numba_parallel_sum(arr):
301300
RESULTS["benchmarks"]["numba"] = {"error": str(e)}
302301
except Exception as e:
303302
print(f"Numba benchmark failed: {e}")
304-
RESULTS["benchmarks"]["numba"] = {"error": str(e)})
303+
RESULTS["benchmarks"]["numba"] = {"error": str(e)}
305304

306305
if __name__ == "__main__":
307306
print("\n" + "=" * 60)

0 commit comments

Comments
 (0)