Skip to content

Commit ef69a1a

Browse files
committed
Fix: Add benchmark content to benchmark-jupyter.ipynb
1 parent f8829a4 commit ef69a1a

File tree

1 file changed

+207
-0
lines changed

1 file changed

+207
-0
lines changed

scripts/benchmark-jupyter.ipynb

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# JAX Performance Benchmark - Jupyter Kernel Execution\n",
8+
"\n",
9+
"This notebook tests JAX performance when executed through a Jupyter kernel.\n",
10+
"Compare results with direct script and jupyter-book execution."
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"import time\n",
20+
"import platform\n",
21+
"import os\n",
22+
"\n",
23+
"print(\"=\" * 60)\n",
24+
"print(\"JUPYTER KERNEL EXECUTION BENCHMARK\")\n",
25+
"print(\"=\" * 60)\n",
26+
"print(f\"Platform: {platform.platform()}\")\n",
27+
"print(f\"Python: {platform.python_version()}\")\n",
28+
"print(f\"CPU Count: {os.cpu_count()}\")"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": null,
34+
"metadata": {},
35+
"outputs": [],
36+
"source": [
37+
"# Import JAX and check devices\n",
38+
"import jax\n",
39+
"import jax.numpy as jnp\n",
40+
"\n",
41+
"devices = jax.devices()\n",
42+
"default_backend = jax.default_backend()\n",
43+
"has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)\n",
44+
"\n",
45+
"print(f\"JAX devices: {devices}\")\n",
46+
"print(f\"Default backend: {default_backend}\")\n",
47+
"print(f\"GPU Available: {has_gpu}\")"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"# Define JIT-compiled function\n",
57+
"@jax.jit\n",
58+
"def matmul(a, b):\n",
59+
" return jnp.dot(a, b)\n",
60+
"\n",
61+
"print(\"matmul function defined with @jax.jit\")"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {},
68+
"outputs": [],
69+
"source": [
70+
"# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation\n",
71+
"print(\"\\n\" + \"=\" * 60)\n",
72+
"print(\"BENCHMARK 1: Small Matrix (1000x1000)\")\n",
73+
"print(\"=\" * 60)\n",
74+
"\n",
75+
"n = 1000\n",
76+
"key = jax.random.PRNGKey(0)\n",
77+
"A = jax.random.normal(key, (n, n))\n",
78+
"B = jax.random.normal(key, (n, n))\n",
79+
"\n",
80+
"# Warm-up run (includes compilation)\n",
81+
"start = time.perf_counter()\n",
82+
"C = matmul(A, B).block_until_ready()\n",
83+
"warmup_time = time.perf_counter() - start\n",
84+
"print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n",
85+
"\n",
86+
"# Compiled run\n",
87+
"start = time.perf_counter()\n",
88+
"C = matmul(A, B).block_until_ready()\n",
89+
"compiled_time = time.perf_counter() - start\n",
90+
"print(f\"Compiled execution: {compiled_time:.3f} seconds\")"
91+
]
92+
},
93+
{
94+
"cell_type": "code",
95+
"execution_count": null,
96+
"metadata": {},
97+
"outputs": [],
98+
"source": [
99+
"# Benchmark 2: Large matrix (3000x3000) - triggers recompilation\n",
100+
"print(\"\\n\" + \"=\" * 60)\n",
101+
"print(\"BENCHMARK 2: Large Matrix (3000x3000)\")\n",
102+
"print(\"=\" * 60)\n",
103+
"\n",
104+
"n = 3000\n",
105+
"A = jax.random.normal(key, (n, n))\n",
106+
"B = jax.random.normal(key, (n, n))\n",
107+
"\n",
108+
"# Warm-up run (recompilation for new size)\n",
109+
"start = time.perf_counter()\n",
110+
"C = matmul(A, B).block_until_ready()\n",
111+
"warmup_time = time.perf_counter() - start\n",
112+
"print(f\"Warm-up (recompile for new size): {warmup_time:.3f} seconds\")\n",
113+
"\n",
114+
"# Compiled run\n",
115+
"start = time.perf_counter()\n",
116+
"C = matmul(A, B).block_until_ready()\n",
117+
"compiled_time = time.perf_counter() - start\n",
118+
"print(f\"Compiled execution: {compiled_time:.3f} seconds\")"
119+
]
120+
},
121+
{
122+
"cell_type": "code",
123+
"execution_count": null,
124+
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
"# Benchmark 3: Element-wise operations (50M elements)\n",
128+
"print(\"\\n\" + \"=\" * 60)\n",
129+
"print(\"BENCHMARK 3: Element-wise Operations (50M elements)\")\n",
130+
"print(\"=\" * 60)\n",
131+
"\n",
132+
"@jax.jit\n",
133+
"def elementwise_ops(x):\n",
134+
" return jnp.cos(x**2) + jnp.sin(x)\n",
135+
"\n",
136+
"x = jax.random.normal(key, (50_000_000,))\n",
137+
"\n",
138+
"# Warm-up\n",
139+
"start = time.perf_counter()\n",
140+
"y = elementwise_ops(x).block_until_ready()\n",
141+
"warmup_time = time.perf_counter() - start\n",
142+
"print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n",
143+
"\n",
144+
"# Compiled\n",
145+
"start = time.perf_counter()\n",
146+
"y = elementwise_ops(x).block_until_ready()\n",
147+
"compiled_time = time.perf_counter() - start\n",
148+
"print(f\"Compiled execution: {compiled_time:.3f} seconds\")"
149+
]
150+
},
151+
{
152+
"cell_type": "code",
153+
"execution_count": null,
154+
"metadata": {},
155+
"outputs": [],
156+
"source": [
157+
"# Benchmark 4: Multiple small operations (simulates lecture cells)\n",
158+
"print(\"\\n\" + \"=\" * 60)\n",
159+
"print(\"BENCHMARK 4: Multiple Small Operations (lecture simulation)\")\n",
160+
"print(\"=\" * 60)\n",
161+
"\n",
162+
"total_start = time.perf_counter()\n",
163+
"\n",
164+
"# Simulate multiple cell executions with different operations\n",
165+
"for i, size in enumerate([100, 500, 1000, 2000, 3000]):\n",
166+
" @jax.jit\n",
167+
" def compute(a, b):\n",
168+
" return jnp.dot(a, b) + jnp.sum(a)\n",
169+
" \n",
170+
" A = jax.random.normal(key, (size, size))\n",
171+
" B = jax.random.normal(key, (size, size))\n",
172+
" \n",
173+
" start = time.perf_counter()\n",
174+
" result = compute(A, B).block_until_ready()\n",
175+
" elapsed = time.perf_counter() - start\n",
176+
" print(f\" Size {size}x{size}: {elapsed:.3f} seconds\")\n",
177+
"\n",
178+
"total_time = time.perf_counter() - total_start\n",
179+
"print(f\"\\nTotal time for all operations: {total_time:.3f} seconds\")"
180+
]
181+
},
182+
{
183+
"cell_type": "code",
184+
"execution_count": null,
185+
"metadata": {},
186+
"outputs": [],
187+
"source": [
188+
"print(\"\\n\" + \"=\" * 60)\n",
189+
"print(\"JUPYTER KERNEL EXECUTION BENCHMARK COMPLETE\")\n",
190+
"print(\"=\" * 60)"
191+
]
192+
}
193+
],
194+
"metadata": {
195+
"kernelspec": {
196+
"display_name": "Python 3",
197+
"language": "python",
198+
"name": "python3"
199+
},
200+
"language_info": {
201+
"name": "python",
202+
"version": "3.13.0"
203+
}
204+
},
205+
"nbformat": 4,
206+
"nbformat_minor": 4
207+
}

0 commit comments

Comments
 (0)