Skip to content

Commit f8829a4

Browse files
committed
Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book)
1 parent 534142c commit f8829a4

File tree

3 files changed

+181
-1
lines changed

3 files changed

+181
-1
lines changed

.github/workflows/ci.yml

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,35 @@ jobs:
3535
- name: Display Pip Versions
3636
shell: bash -l {0}
3737
run: pip list
38-
- name: Run Hardware Benchmarks
38+
- name: Run Hardware Benchmarks (Bare Metal)
3939
shell: bash -l {0}
4040
run: |
4141
pip install jax # Install JAX for CPU
42+
echo "=== Bare Metal Python Script Execution ==="
4243
python scripts/benchmark-hardware.py
44+
- name: Run Jupyter Notebook Benchmark (via nbconvert)
45+
shell: bash -l {0}
46+
run: |
47+
echo "=== Jupyter Kernel Execution ==="
48+
jupyter nbconvert --to notebook --execute scripts/benchmark-jupyter.ipynb --output benchmark-jupyter-executed.ipynb
49+
echo "Notebook executed successfully"
50+
- name: Run Jupyter-Book Benchmark
51+
shell: bash -l {0}
52+
run: |
53+
echo "=== Jupyter-Book Execution ==="
54+
# Build just the benchmark file using jupyter-book
55+
mkdir -p benchmark_test
56+
cp scripts/benchmark-jupyterbook.md benchmark_test/
57+
# Create minimal _config.yml
58+
echo "title: Benchmark Test" > benchmark_test/_config.yml
59+
echo "execute:" >> benchmark_test/_config.yml
60+
echo " execute_notebooks: force" >> benchmark_test/_config.yml
61+
# Create minimal _toc.yml
62+
echo "format: jb-book" > benchmark_test/_toc.yml
63+
echo "root: benchmark-jupyterbook" >> benchmark_test/_toc.yml
64+
# Build
65+
jb build benchmark_test --path-output benchmark_build/
66+
echo "Jupyter-Book build completed successfully"
4367
- name: Download "build" folder (cache)
4468
uses: dawidd6/action-download-artifact@v11
4569
with:

scripts/benchmark-jupyter.ipynb

Whitespace-only changes.

scripts/benchmark-jupyterbook.md

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
---
2+
jupytext:
3+
text_representation:
4+
extension: .md
5+
format_name: myst
6+
format_version: 0.13
7+
kernelspec:
8+
display_name: Python 3 (ipykernel)
9+
language: python
10+
name: python3
11+
---
12+
13+
# JAX Performance Benchmark - Jupyter Book Execution
14+
15+
This file tests JAX performance when executed through Jupyter Book's notebook execution.
16+
Compare results with direct script and nbconvert execution.
17+
18+
```{code-cell} ipython3
19+
import time
20+
import platform
21+
import os
22+
23+
print("=" * 60)
24+
print("JUPYTER BOOK EXECUTION BENCHMARK")
25+
print("=" * 60)
26+
print(f"Platform: {platform.platform()}")
27+
print(f"Python: {platform.python_version()}")
28+
print(f"CPU Count: {os.cpu_count()}")
29+
```
30+
31+
```{code-cell} ipython3
32+
# Import JAX and check devices
33+
import jax
34+
import jax.numpy as jnp
35+
36+
devices = jax.devices()
37+
default_backend = jax.default_backend()
38+
has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)
39+
40+
print(f"JAX devices: {devices}")
41+
print(f"Default backend: {default_backend}")
42+
print(f"GPU Available: {has_gpu}")
43+
```
44+
45+
```{code-cell} ipython3
46+
# Define JIT-compiled function
47+
@jax.jit
48+
def matmul(a, b):
49+
return jnp.dot(a, b)
50+
51+
print("matmul function defined with @jax.jit")
52+
```
53+
54+
```{code-cell} ipython3
55+
# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation
56+
print("\n" + "=" * 60)
57+
print("BENCHMARK 1: Small Matrix (1000x1000)")
58+
print("=" * 60)
59+
60+
n = 1000
61+
key = jax.random.PRNGKey(0)
62+
A = jax.random.normal(key, (n, n))
63+
B = jax.random.normal(key, (n, n))
64+
65+
# Warm-up run (includes compilation)
66+
start = time.perf_counter()
67+
C = matmul(A, B).block_until_ready()
68+
warmup_time = time.perf_counter() - start
69+
print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds")
70+
71+
# Compiled run
72+
start = time.perf_counter()
73+
C = matmul(A, B).block_until_ready()
74+
compiled_time = time.perf_counter() - start
75+
print(f"Compiled execution: {compiled_time:.3f} seconds")
76+
```
77+
78+
```{code-cell} ipython3
79+
# Benchmark 2: Large matrix (3000x3000) - triggers recompilation
80+
print("\n" + "=" * 60)
81+
print("BENCHMARK 2: Large Matrix (3000x3000)")
82+
print("=" * 60)
83+
84+
n = 3000
85+
A = jax.random.normal(key, (n, n))
86+
B = jax.random.normal(key, (n, n))
87+
88+
# Warm-up run (recompilation for new size)
89+
start = time.perf_counter()
90+
C = matmul(A, B).block_until_ready()
91+
warmup_time = time.perf_counter() - start
92+
print(f"Warm-up (recompile for new size): {warmup_time:.3f} seconds")
93+
94+
# Compiled run
95+
start = time.perf_counter()
96+
C = matmul(A, B).block_until_ready()
97+
compiled_time = time.perf_counter() - start
98+
print(f"Compiled execution: {compiled_time:.3f} seconds")
99+
```
100+
101+
```{code-cell} ipython3
102+
# Benchmark 3: Element-wise operations (50M elements)
103+
print("\n" + "=" * 60)
104+
print("BENCHMARK 3: Element-wise Operations (50M elements)")
105+
print("=" * 60)
106+
107+
@jax.jit
108+
def elementwise_ops(x):
109+
return jnp.cos(x**2) + jnp.sin(x)
110+
111+
x = jax.random.normal(key, (50_000_000,))
112+
113+
# Warm-up
114+
start = time.perf_counter()
115+
y = elementwise_ops(x).block_until_ready()
116+
warmup_time = time.perf_counter() - start
117+
print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds")
118+
119+
# Compiled
120+
start = time.perf_counter()
121+
y = elementwise_ops(x).block_until_ready()
122+
compiled_time = time.perf_counter() - start
123+
print(f"Compiled execution: {compiled_time:.3f} seconds")
124+
```
125+
126+
```{code-cell} ipython3
127+
# Benchmark 4: Multiple small operations (simulates lecture cells)
128+
print("\n" + "=" * 60)
129+
print("BENCHMARK 4: Multiple Small Operations (lecture simulation)")
130+
print("=" * 60)
131+
132+
total_start = time.perf_counter()
133+
134+
# Simulate multiple cell executions with different operations
135+
for i, size in enumerate([100, 500, 1000, 2000, 3000]):
136+
@jax.jit
137+
def compute(a, b):
138+
return jnp.dot(a, b) + jnp.sum(a)
139+
140+
A = jax.random.normal(key, (size, size))
141+
B = jax.random.normal(key, (size, size))
142+
143+
start = time.perf_counter()
144+
result = compute(A, B).block_until_ready()
145+
elapsed = time.perf_counter() - start
146+
print(f" Size {size}x{size}: {elapsed:.3f} seconds")
147+
148+
total_time = time.perf_counter() - total_start
149+
print(f"\nTotal time for all operations: {total_time:.3f} seconds")
150+
```
151+
152+
```{code-cell} ipython3
153+
print("\n" + "=" * 60)
154+
print("JUPYTER BOOK EXECUTION BENCHMARK COMPLETE")
155+
print("=" * 60)
156+
```

0 commit comments

Comments
 (0)