Skip to content

Commit 0b6a567

Browse files
committed
DEBUG: Add hardware benchmark script to diagnose performance
- Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks
- Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners
- Include warm-up vs compiled timing to isolate JIT overhead
- Add system info collection (CPU model, frequency, GPU detection)
1 parent 849cf16 commit 0b6a567

File tree

2 files changed

+267
-0
lines changed

2 files changed

+267
-0
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ jobs:
2525
pip install -U "jax[cuda13]"
2626
pip install numpyro
2727
python scripts/test-jax-install.py
28+
- name: Run Hardware Benchmarks
29+
shell: bash -l {0}
30+
run: python scripts/benchmark-hardware.py
2831
- name: Install latex dependencies
2932
run: |
3033
sudo apt-get -qq update

scripts/benchmark-hardware.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""
2+
Hardware benchmark script for CI runners.
3+
Compares CPU and GPU performance to diagnose slowdowns.
4+
Works on both CPU-only (GitHub Actions) and GPU (RunsOn) runners.
5+
"""
6+
import time
7+
import platform
8+
import os
9+
10+
def get_cpu_info():
    """Print a summary of the host system to stdout.

    Reports the platform string, processor, Python version, CPU model and
    clock speed (read from ``/proc/cpuinfo`` when that file is readable),
    the logical CPU count, and the GPU name/memory as reported by
    ``nvidia-smi``. Every probe is best-effort: a missing file or tool only
    omits or downgrades the corresponding line, it never aborts the script.

    Returns:
        None. All results are printed.
    """
    import subprocess

    print("=" * 60)
    print("SYSTEM INFORMATION")
    print("=" * 60)
    print(f"Platform: {platform.platform()}")
    print(f"Processor: {platform.processor()}")
    print(f"Python: {platform.python_version()}")

    # Scan /proc/cpuinfo once, keeping the first value seen for each field.
    cpu_fields = {"model name": None, "cpu MHz": None}
    try:
        with open('/proc/cpuinfo', 'r') as f:
            for line in f:
                _, sep, value = line.partition(':')
                if not sep:
                    continue  # line has no "key: value" shape
                for field, seen in cpu_fields.items():
                    if seen is None and field in line:
                        cpu_fields[field] = value.strip()
                if all(v is not None for v in cpu_fields.values()):
                    break
    except (OSError, UnicodeError):
        # Best-effort: the file is absent or unreadable on this platform.
        pass

    if cpu_fields["model name"] is not None:
        print(f"CPU Model: {cpu_fields['model name']}")
    if cpu_fields["cpu MHz"] is not None:
        print(f"CPU MHz: {cpu_fields['cpu MHz']}")

    # CPU count
    print(f"CPU Count: {os.cpu_count()}")

    # Check for GPU
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
            capture_output=True, text=True, timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: executable missing / not runnable; SubprocessError: timeout.
        print("GPU: None detected (nvidia-smi not available)")
    else:
        if result.returncode == 0:
            print(f"GPU: {result.stdout.strip()}")
        else:
            print("GPU: None detected")

    print()
55+
56+
def benchmark_cpu_pure_python(int_iterations=10_000_000, float_iterations=1_000_000):
    """Pure Python CPU benchmark.

    Times two interpreter-bound workloads with ``time.perf_counter`` and
    prints the elapsed wall-clock seconds for each.

    Args:
        int_iterations: Number of terms in the integer sum of squares.
            Defaults to 10M, the size the script has always used.
        float_iterations: Number of square roots accumulated in the float
            loop. Defaults to 1M.

    Returns:
        None. Results are printed to stdout.
    """
    def _label(count):
        # Whole millions print compactly ("10M") so default output is unchanged.
        if count >= 1_000_000 and count % 1_000_000 == 0:
            return f"{count // 1_000_000}M"
        return f"{count:,}"

    print("=" * 60)
    print("CPU BENCHMARK: Pure Python")
    print("=" * 60)

    # Integer computation (the value itself is irrelevant; only the time is reported)
    start = time.perf_counter()
    sum(i * i for i in range(int_iterations))
    elapsed = time.perf_counter() - start
    print(f"Integer sum ({_label(int_iterations)} iterations): {elapsed:.3f} seconds")

    # Float computation
    start = time.perf_counter()
    total = 0.0
    for i in range(float_iterations):
        total += (i * 0.1) ** 0.5
    elapsed = time.perf_counter() - start
    print(f"Float sqrt ({_label(float_iterations)} iterations): {elapsed:.3f} seconds")
    print()
76+
77+
def benchmark_cpu_numpy(matrix_size=3000, num_elements=50_000_000):
    """NumPy CPU benchmark.

    Times a square matrix multiplication and a fused element-wise
    expression (``cos(x**2) + sin(x)``) on freshly generated random data,
    printing the elapsed wall-clock seconds for each.

    Args:
        matrix_size: Side length of the two square matrices that are
            multiplied. Defaults to 3000.
        num_elements: Length of the 1-D array used for the element-wise
            test. Defaults to 50M.

    Returns:
        None. Results are printed to stdout.
    """
    import numpy as np

    def _label(count):
        # Whole millions print compactly ("50M") so default output is unchanged.
        if count >= 1_000_000 and count % 1_000_000 == 0:
            return f"{count // 1_000_000}M"
        return f"{count:,}"

    print("=" * 60)
    print("CPU BENCHMARK: NumPy")
    print("=" * 60)

    # Matrix multiplication (inputs are built before the timer starts)
    n = matrix_size
    A = np.random.randn(n, n)
    B = np.random.randn(n, n)

    start = time.perf_counter()
    _ = A @ B
    elapsed = time.perf_counter() - start
    print(f"Matrix multiply ({n}x{n}): {elapsed:.3f} seconds")

    # Element-wise operations
    x = np.random.randn(num_elements)

    start = time.perf_counter()
    _ = np.cos(x**2) + np.sin(x)
    elapsed = time.perf_counter() - start
    print(f"Element-wise ops ({_label(num_elements)} elements): {elapsed:.3f} seconds")
    print()
103+
104+
def benchmark_gpu_jax(small_size=1000, large_size=3000, num_elements=50_000_000):
    """JAX benchmark (GPU if available, otherwise CPU).

    For two matrix sizes and one element-wise kernel, the jitted function is
    run twice: the first call is reported as "warm-up" (it includes JIT
    compilation), the second as the compiled time. Every timed call ends
    with ``block_until_ready()`` so the figure covers the real computation.

    Args:
        small_size: Side length of the first pair of square matrices.
            Defaults to 1000.
        large_size: Side length of the second pair, which triggers the
            "recompile" warm-up. Defaults to 3000.
        num_elements: Length of the vector for the element-wise kernel.
            Defaults to 50M.

    Returns:
        None. Results are printed to stdout. If JAX cannot be imported, or
        any step raises, a one-line message is printed instead so the rest
        of the benchmark script still runs.
    """
    def _label(count):
        # Whole millions print compactly ("50M") so default output is unchanged.
        if count >= 1_000_000 and count % 1_000_000 == 0:
            return f"{count // 1_000_000}M"
        return f"{count:,}"

    def _timed(fn, *args):
        # Run fn once and wait for the device result before stopping the clock.
        start = time.perf_counter()
        fn(*args).block_until_ready()
        return time.perf_counter() - start

    try:
        import jax
        import jax.numpy as jnp

        devices = jax.devices()
        default_backend = jax.default_backend()

        # Check if GPU is available
        has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)

        print("=" * 60)
        if has_gpu:
            print("JAX BENCHMARK: GPU")
        else:
            print("JAX BENCHMARK: CPU (no GPU detected)")
        print("=" * 60)

        print(f"JAX devices: {devices}")
        print(f"Default backend: {default_backend}")
        print(f"GPU Available: {has_gpu}")
        print()

        # Warm-up JIT compilation
        print("Warming up JIT compilation...")

        # One independent subkey per random array: reusing a single key would
        # make A and B identical and violates JAX's key-usage rules.
        key_a1, key_b1, key_a2, key_b2, key_x = jax.random.split(jax.random.PRNGKey(0), 5)

        @jax.jit
        def matmul(a, b):
            return jnp.dot(a, b)

        @jax.jit
        def elementwise_ops(x):
            return jnp.cos(x**2) + jnp.sin(x)

        # Small matrix. Inputs are materialised first so random-number
        # generation is not counted in the warm-up figure.
        n = small_size
        A = jax.random.normal(key_a1, (n, n)).block_until_ready()
        B = jax.random.normal(key_b1, (n, n)).block_until_ready()

        warmup_time = _timed(matmul, A, B)
        print(f"Warm-up (includes JIT compile, {n}x{n}): {warmup_time:.3f} seconds")

        elapsed = _timed(matmul, A, B)
        print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds")

        # Larger matrix
        n = large_size
        A = jax.random.normal(key_a2, (n, n)).block_until_ready()
        B = jax.random.normal(key_b2, (n, n)).block_until_ready()

        warmup_time = _timed(matmul, A, B)
        print(f"Warm-up (recompile for {n}x{n}): {warmup_time:.3f} seconds")

        elapsed = _timed(matmul, A, B)
        print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds")

        # Element-wise benchmark
        x = jax.random.normal(key_x, (num_elements,)).block_until_ready()

        warmup_time = _timed(elementwise_ops, x)
        print(f"Element-wise warm-up ({_label(num_elements)}): {warmup_time:.3f} seconds")

        elapsed = _timed(elementwise_ops, x)
        print(f"Element-wise compiled ({_label(num_elements)}): {elapsed:.3f} seconds")

        print()

    except ImportError as e:
        print(f"JAX not available: {e}")
    except Exception as e:
        # Deliberate catch-all: this is a diagnostic script and one failing
        # backend must not stop the remaining benchmarks from reporting.
        print(f"JAX benchmark failed: {e}")
193+
194+
def benchmark_numba(sum_iterations=10_000_000, array_size=50_000_000):
    """Numba CPU benchmark.

    Times a nopython integer loop and a ``parallel=True`` reduction over a
    random array. Each kernel is called twice: the first call is reported
    as warm-up (it includes compilation), the second as the compiled time.

    Args:
        sum_iterations: Loop count for the integer sum of squares.
            Defaults to 10M.
        array_size: Length of the random array reduced by the parallel
            kernel. Defaults to 50M.

    Returns:
        None. Results are printed to stdout. If Numba cannot be imported,
        or any step raises, a one-line message is printed instead so the
        rest of the benchmark script still runs.
    """
    def _label(count):
        # Whole millions print compactly ("10M") so default output is unchanged.
        if count >= 1_000_000 and count % 1_000_000 == 0:
            return f"{count // 1_000_000}M"
        return f"{count:,}"

    def _timed(fn, arg):
        # Wall-clock seconds for a single call; the return value is discarded.
        start = time.perf_counter()
        fn(arg)
        return time.perf_counter() - start

    try:
        import numba
        import numpy as np

        print("=" * 60)
        print("CPU BENCHMARK: Numba")
        print("=" * 60)

        # NOTE(review): with the default size the exact sum of squares exceeds
        # the int64 range, so the compiled result presumably wraps. Harmless
        # here because only the timing is reported - confirm if the value is
        # ever used.
        @numba.jit(nopython=True)
        def numba_sum(n):
            total = 0
            for i in range(n):
                total += i * i
            return total

        # Warm-up (compilation)
        warmup_time = _timed(numba_sum, sum_iterations)
        print(f"Integer sum warm-up (includes compile): {warmup_time:.3f} seconds")

        # Compiled run
        elapsed = _timed(numba_sum, sum_iterations)
        print(f"Integer sum compiled ({_label(sum_iterations)}): {elapsed:.3f} seconds")

        @numba.jit(nopython=True, parallel=True)
        def numba_parallel_sum(arr):
            total = 0.0
            for i in numba.prange(len(arr)):
                total += arr[i] ** 2
            return total

        arr = np.random.randn(array_size)

        # Warm-up
        warmup_time = _timed(numba_parallel_sum, arr)
        print(f"Parallel sum warm-up ({_label(array_size)}): {warmup_time:.3f} seconds")

        # Compiled
        elapsed = _timed(numba_parallel_sum, arr)
        print(f"Parallel sum compiled ({_label(array_size)}): {elapsed:.3f} seconds")

        print()

    except ImportError as e:
        print(f"Numba not available: {e}")
    except Exception as e:
        # Deliberate catch-all: this is a diagnostic script and one failing
        # backend must not stop the remaining benchmarks from reporting.
        print(f"Numba benchmark failed: {e}")
250+
251+
if __name__ == "__main__":
    # Script entry point: print a banner, run every benchmark in a fixed
    # order (system info first, then CPU, then the JIT back ends), and
    # finish with a closing banner.
    rule = "=" * 60

    print("\n" + rule)
    print("HARDWARE BENCHMARK FOR CI RUNNER")
    print(rule + "\n")

    steps = (
        get_cpu_info,
        benchmark_cpu_pure_python,
        benchmark_cpu_numpy,
        benchmark_numba,
        benchmark_gpu_jax,
    )
    for step in steps:
        step()

    print(rule)
    print("BENCHMARK COMPLETE")
    print(rule)

0 commit comments

Comments
 (0)