---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

# JAX Performance Benchmark - Jupyter Book Execution

This file measures JAX performance when executed through Jupyter Book's notebook execution. Compare the results with those from running the same benchmark as a direct script and via `nbconvert`.
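
For reference, the three execution paths might be invoked roughly as follows. This is a sketch only; the file and project names are placeholders to adjust for the actual repository layout.

```bash
# Sketch only: file and directory names are placeholders.
python benchmark.py                                          # direct script execution
jupytext --to ipynb benchmark.md                             # convert this MyST file to a notebook
jupyter nbconvert --to notebook --execute benchmark.ipynb    # nbconvert execution
jupyter-book build path/to/book/                             # Jupyter Book execution
```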

```{code-cell} ipython3
import time
import platform
import os

print("=" * 60)
print("JUPYTER BOOK EXECUTION BENCHMARK")
print("=" * 60)
print(f"Platform: {platform.platform()}")
print(f"Python: {platform.python_version()}")
print(f"CPU Count: {os.cpu_count()}")
```

```{code-cell} ipython3
# Import JAX and check devices
import jax
import jax.numpy as jnp

devices = jax.devices()
default_backend = jax.default_backend()
has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)

print(f"JAX devices: {devices}")
print(f"Default backend: {default_backend}")
print(f"GPU Available: {has_gpu}")
```
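
If the comparison across execution modes should run on identical hardware, one option is to pin computations to a specific device. This is a minimal sketch, assuming a reasonably recent JAX release that provides the `jax.default_device` context manager; it is not part of the benchmark itself.

```{code-cell} ipython3
# Sketch only: pin work to the first CPU device so timings are comparable
# across machines regardless of which backend JAX picks by default.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    small = jnp.ones((1000, 1000))
    _ = jnp.dot(small, small).block_until_ready()
print("Ran a small matmul pinned to:", cpu)
```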

```{code-cell} ipython3
# Define JIT-compiled function
@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

print("matmul function defined with @jax.jit")
```

```{code-cell} ipython3
# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation
print("\n" + "=" * 60)
print("BENCHMARK 1: Small Matrix (1000x1000)")
print("=" * 60)

n = 1000
key = jax.random.PRNGKey(0)
# Note: reusing the same key makes A and B identical; this has no effect on timing.
A = jax.random.normal(key, (n, n))
B = jax.random.normal(key, (n, n))

# Warm-up run (includes compilation)
start = time.perf_counter()
C = matmul(A, B).block_until_ready()
warmup_time = time.perf_counter() - start
print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds")

# Compiled run
start = time.perf_counter()
C = matmul(A, B).block_until_ready()
compiled_time = time.perf_counter() - start
print(f"Compiled execution: {compiled_time:.3f} seconds")
```
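
The timings above rely on `block_until_ready()`. As a brief aside (a minimal illustration, not part of the benchmark): JAX dispatches work asynchronously, so timing a call without blocking mostly measures dispatch overhead rather than the computation itself.

```{code-cell} ipython3
# Illustration only: compare timing with and without blocking on the result.
start = time.perf_counter()
_ = matmul(A, B)  # returns as soon as the work is dispatched
dispatch_only = time.perf_counter() - start

start = time.perf_counter()
_ = matmul(A, B).block_until_ready()  # waits for the computation to finish
with_block = time.perf_counter() - start

print(f"Without block_until_ready: {dispatch_only:.6f} seconds")
print(f"With block_until_ready:    {with_block:.6f} seconds")
```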

```{code-cell} ipython3
# Benchmark 2: Large matrix (3000x3000) - triggers recompilation
print("\n" + "=" * 60)
print("BENCHMARK 2: Large Matrix (3000x3000)")
print("=" * 60)

n = 3000
A = jax.random.normal(key, (n, n))
B = jax.random.normal(key, (n, n))

# Warm-up run (recompilation for new size)
start = time.perf_counter()
C = matmul(A, B).block_until_ready()
warmup_time = time.perf_counter() - start
print(f"Warm-up (recompile for new size): {warmup_time:.3f} seconds")

# Compiled run
start = time.perf_counter()
C = matmul(A, B).block_until_ready()
compiled_time = time.perf_counter() - start
print(f"Compiled execution: {compiled_time:.3f} seconds")
```

```{code-cell} ipython3
# Benchmark 3: Element-wise operations (50M elements)
print("\n" + "=" * 60)
print("BENCHMARK 3: Element-wise Operations (50M elements)")
print("=" * 60)

@jax.jit
def elementwise_ops(x):
    return jnp.cos(x**2) + jnp.sin(x)

x = jax.random.normal(key, (50_000_000,))

# Warm-up
start = time.perf_counter()
y = elementwise_ops(x).block_until_ready()
warmup_time = time.perf_counter() - start
print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds")

# Compiled
start = time.perf_counter()
y = elementwise_ops(x).block_until_ready()
compiled_time = time.perf_counter() - start
print(f"Compiled execution: {compiled_time:.3f} seconds")
```

```{code-cell} ipython3
# Benchmark 4: Multiple small operations (simulates lecture cells)
print("\n" + "=" * 60)
print("BENCHMARK 4: Multiple Small Operations (lecture simulation)")
print("=" * 60)

total_start = time.perf_counter()

# Simulate multiple cell executions with different operations.
# Redefining the jitted function inside the loop mimics a fresh cell each time,
# so every iteration pays the compilation cost again.
for size in [100, 500, 1000, 2000, 3000]:
    @jax.jit
    def compute(a, b):
        return jnp.dot(a, b) + jnp.sum(a)

    A = jax.random.normal(key, (size, size))
    B = jax.random.normal(key, (size, size))

    start = time.perf_counter()
    result = compute(A, B).block_until_ready()
    elapsed = time.perf_counter() - start
    print(f"  Size {size}x{size}: {elapsed:.3f} seconds")

total_time = time.perf_counter() - total_start
print(f"\nTotal time for all operations: {total_time:.3f} seconds")
```
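
As a follow-up sketch (not part of the benchmark; `compute_shared` is an illustrative name): if the jitted function were instead defined once outside the loop, JAX would reuse its compilation cache for shapes it has already seen, so only the first call per shape pays the compile cost.

```{code-cell} ipython3
# Sketch: a single jitted function reuses its compilation cache per shape.
@jax.jit
def compute_shared(a, b):
    return jnp.dot(a, b) + jnp.sum(a)

for size in [1000, 1000]:  # same shape twice: only the first call compiles
    A = jax.random.normal(key, (size, size))
    B = jax.random.normal(key, (size, size))
    start = time.perf_counter()
    compute_shared(A, B).block_until_ready()
    print(f"  Size {size}x{size}: {time.perf_counter() - start:.3f} seconds")
```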

```{code-cell} ipython3
print("\n" + "=" * 60)
print("JUPYTER BOOK EXECUTION BENCHMARK COMPLETE")
print("=" * 60)
```