diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 51755ea9..18d09454 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -35,6 +35,68 @@ jobs:
       - name: Display Pip Versions
         shell: bash -l {0}
         run: pip list
+      - name: Run Hardware Benchmarks (Bare Metal)
+        shell: bash -l {0}
+        run: |
+          pip install jax  # Install JAX for CPU
+          echo "=== Bare Metal Python Script Execution ==="
+          python scripts/benchmark-hardware.py
+      - name: Run Jupyter Notebook Benchmark (via nbconvert)
+        shell: bash -l {0}
+        run: |
+          echo "=== Jupyter Kernel Execution ==="
+          cd scripts
+          jupyter nbconvert --to notebook --execute benchmark-jupyter.ipynb --output benchmark-jupyter-executed.ipynb
+          echo "Notebook executed successfully"
+          cd ..
+      - name: Run Jupyter-Book Benchmark
+        shell: bash -l {0}
+        run: |
+          echo "=== Jupyter-Book Execution ==="
+          # Build just the benchmark file using jupyter-book
+          mkdir -p benchmark_test
+          cp scripts/benchmark-jupyterbook.md benchmark_test/
+          # Create minimal _config.yml
+          echo "title: Benchmark Test" > benchmark_test/_config.yml
+          echo "execute:" >> benchmark_test/_config.yml
+          echo "  execute_notebooks: force" >> benchmark_test/_config.yml
+          # Create minimal _toc.yml
+          echo "format: jb-book" > benchmark_test/_toc.yml
+          echo "root: benchmark-jupyterbook" >> benchmark_test/_toc.yml
+          # Build (run from benchmark_test so JSON is written there)
+          cd benchmark_test
+          jb build . --path-output ../benchmark_build/
+          cd ..
+          echo "Jupyter-Book build completed successfully"
+      - name: Collect and Display Benchmark Results
+        shell: bash -l {0}
+        run: |
+          echo "=== Collecting Benchmark Results ==="
+          mkdir -p benchmark_results
+
+          # Copy results from each pathway
+          cp benchmark_results_bare_metal.json benchmark_results/ 2>/dev/null || echo "No bare metal results"
+          cp scripts/benchmark_results_jupyter.json benchmark_results/ 2>/dev/null || echo "No jupyter results"
+          cp benchmark_test/benchmark_results_jupyterbook.json benchmark_results/ 2>/dev/null || echo "No jupyterbook results"
+
+          # Display summary
+          echo ""
+          echo "============================================================"
+          echo "BENCHMARK RESULTS SUMMARY"
+          echo "============================================================"
+          for f in benchmark_results/*.json; do
+            if [ -f "$f" ]; then
+              echo ""
+              echo "--- $(basename $f) ---"
+              cat "$f"
+            fi
+          done
+      - name: Upload Benchmark Results
+        uses: actions/upload-artifact@v5
+        with:
+          name: benchmark-results
+          path: benchmark_results/
+          if-no-files-found: warn
       - name: Download "build" folder (cache)
         uses: dawidd6/action-download-artifact@v11
         with:
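The "Collect and Display Benchmark Results" step above only `cat`s each JSON file; lining the three pathways up side by side is left to the reader. Below is a minimal comparison sketch, not part of this diff (the name `scripts/compare-benchmarks.py` is hypothetical), assuming only the `{"pathway": ..., "benchmarks": {...}}` schema that the three benchmark scripts in this diff write:

```python
# Hypothetical helper (e.g. scripts/compare-benchmarks.py, NOT in this diff).
# Assumes the {"pathway": ..., "benchmarks": {...}} schema written by the
# three benchmark scripts below.
import json
from pathlib import Path

def leaf_times(bench):
    """Flatten nested benchmark dicts down to leaf-name -> seconds."""
    flat = {}
    for key, value in bench.items():
        if isinstance(value, dict):
            flat.update(leaf_times(value))
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            flat[key] = value
    return flat

def compare(directory="benchmark_results"):
    """Print each benchmark name with one timing column per pathway."""
    results = {}
    for path in Path(directory).glob("*.json"):
        data = json.loads(path.read_text())
        results[data.get("pathway", path.stem)] = leaf_times(data.get("benchmarks", {}))
    pathways = sorted(results)
    names = sorted({name for times in results.values() for name in times})
    print(f"{'benchmark':<32}" + "".join(f"{p:>16}" for p in pathways))
    for name in names:
        cells = [
            f"{results[p][name]:>16.3f}" if name in results[p] else f"{'-':>16}"
            for p in pathways
        ]
        print(f"{name:<32}" + "".join(cells))

if __name__ == "__main__":
    compare()
```

Because the bare-metal script nests its benchmarks one level deeper (e.g. `jax.matmul_1000x1000_warmup`), the sketch flattens to leaf key names so the shared keys line up across pathways.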
+""" +import time +import platform +import os +import json +from datetime import datetime + +# Global results dictionary +RESULTS = { + "pathway": "bare_metal", + "timestamp": datetime.now().isoformat(), + "system": {}, + "benchmarks": {} +} + +def get_cpu_info(): + """Get CPU information.""" + print("=" * 60) + print("SYSTEM INFORMATION") + print("=" * 60) + print(f"Platform: {platform.platform()}") + print(f"Processor: {platform.processor()}") + print(f"Python: {platform.python_version()}") + + RESULTS["system"]["platform"] = platform.platform() + RESULTS["system"]["processor"] = platform.processor() + RESULTS["system"]["python"] = platform.python_version() + RESULTS["system"]["cpu_count"] = os.cpu_count() + + # Try to get CPU model + cpu_model = None + cpu_mhz = None + try: + with open('/proc/cpuinfo', 'r') as f: + for line in f: + if 'model name' in line: + cpu_model = line.split(':')[1].strip() + print(f"CPU Model: {cpu_model}") + break + except: + pass + + # Try to get CPU frequency + try: + with open('/proc/cpuinfo', 'r') as f: + for line in f: + if 'cpu MHz' in line: + cpu_mhz = line.split(':')[1].strip() + print(f"CPU MHz: {cpu_mhz}") + break + except: + pass + + RESULTS["system"]["cpu_model"] = cpu_model + RESULTS["system"]["cpu_mhz"] = cpu_mhz + + # CPU count + print(f"CPU Count: {os.cpu_count()}") + + # Check for GPU + gpu_info = None + try: + import subprocess + result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'], + capture_output=True, text=True, timeout=5) + if result.returncode == 0: + gpu_info = result.stdout.strip() + print(f"GPU: {gpu_info}") + else: + print("GPU: None detected") + except: + print("GPU: None detected (nvidia-smi not available)") + + RESULTS["system"]["gpu"] = gpu_info + print() + +def benchmark_cpu_pure_python(): + """Pure Python CPU benchmark.""" + print("=" * 60) + print("CPU BENCHMARK: Pure Python") + print("=" * 60) + + results = {} + + # Integer computation + start = time.perf_counter() + total = sum(i * i for i in range(10_000_000)) + elapsed = time.perf_counter() - start + print(f"Integer sum (10M iterations): {elapsed:.3f} seconds") + results["integer_sum_10m"] = elapsed + + # Float computation + start = time.perf_counter() + total = 0.0 + for i in range(1_000_000): + total += (i * 0.1) ** 0.5 + elapsed = time.perf_counter() - start + print(f"Float sqrt (1M iterations): {elapsed:.3f} seconds") + results["float_sqrt_1m"] = elapsed + print() + + RESULTS["benchmarks"]["pure_python"] = results + +def benchmark_cpu_numpy(): + """NumPy CPU benchmark.""" + import numpy as np + + print("=" * 60) + print("CPU BENCHMARK: NumPy") + print("=" * 60) + + results = {} + + # Matrix multiplication + n = 3000 + A = np.random.randn(n, n) + B = np.random.randn(n, n) + + start = time.perf_counter() + C = A @ B + elapsed = time.perf_counter() - start + print(f"Matrix multiply ({n}x{n}): {elapsed:.3f} seconds") + results["matmul_3000x3000"] = elapsed + + # Element-wise operations + x = np.random.randn(50_000_000) + + start = time.perf_counter() + y = np.cos(x**2) + np.sin(x) + elapsed = time.perf_counter() - start + print(f"Element-wise ops (50M elements): {elapsed:.3f} seconds") + results["elementwise_50m"] = elapsed + print() + + RESULTS["benchmarks"]["numpy"] = results + +def benchmark_gpu_jax(): + """JAX benchmark (GPU if available, otherwise CPU).""" + try: + import jax + import jax.numpy as jnp + + devices = jax.devices() + default_backend = jax.default_backend() + + # Check if GPU is available + has_gpu = any('cuda' 
+def benchmark_gpu_jax():
+    """JAX benchmark (GPU if available, otherwise CPU)."""
+    try:
+        import jax
+        import jax.numpy as jnp
+
+        devices = jax.devices()
+        default_backend = jax.default_backend()
+
+        # Check if GPU is available
+        has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)
+
+        print("=" * 60)
+        if has_gpu:
+            print("JAX BENCHMARK: GPU")
+        else:
+            print("JAX BENCHMARK: CPU (no GPU detected)")
+        print("=" * 60)
+
+        print(f"JAX devices: {devices}")
+        print(f"Default backend: {default_backend}")
+        print(f"GPU Available: {has_gpu}")
+        print()
+
+        results = {
+            "backend": default_backend,
+            "has_gpu": has_gpu,
+            "devices": str(devices)
+        }
+
+        # Warm-up JIT compilation
+        print("Warming up JIT compilation...")
+        n = 1000
+        key = jax.random.PRNGKey(0)
+        A = jax.random.normal(key, (n, n))
+        B = jax.random.normal(key, (n, n))
+
+        @jax.jit
+        def matmul(a, b):
+            return jnp.dot(a, b)
+
+        # Warm-up run (includes compilation)
+        start = time.perf_counter()
+        C = matmul(A, B).block_until_ready()
+        warmup_time = time.perf_counter() - start
+        print(f"Warm-up (includes JIT compile, {n}x{n}): {warmup_time:.3f} seconds")
+        results["matmul_1000x1000_warmup"] = warmup_time
+
+        # Actual benchmark (compiled)
+        start = time.perf_counter()
+        C = matmul(A, B).block_until_ready()
+        elapsed = time.perf_counter() - start
+        print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds")
+        results["matmul_1000x1000_compiled"] = elapsed
+
+        # Larger matrix
+        n = 3000
+        A = jax.random.normal(key, (n, n))
+        B = jax.random.normal(key, (n, n))
+
+        # Warm-up for new size
+        start = time.perf_counter()
+        C = matmul(A, B).block_until_ready()
+        warmup_time = time.perf_counter() - start
+        print(f"Warm-up (recompile for {n}x{n}): {warmup_time:.3f} seconds")
+        results["matmul_3000x3000_warmup"] = warmup_time
+
+        # Benchmark compiled
+        start = time.perf_counter()
+        C = matmul(A, B).block_until_ready()
+        elapsed = time.perf_counter() - start
+        print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds")
+        results["matmul_3000x3000_compiled"] = elapsed
+
+        # Element-wise benchmark (GPU if available)
+        x = jax.random.normal(key, (50_000_000,))
+
+        @jax.jit
+        def elementwise_ops(x):
+            return jnp.cos(x**2) + jnp.sin(x)
+
+        # Warm-up
+        start = time.perf_counter()
+        y = elementwise_ops(x).block_until_ready()
+        warmup_time = time.perf_counter() - start
+        print(f"Element-wise warm-up (50M): {warmup_time:.3f} seconds")
+        results["elementwise_50m_warmup"] = warmup_time
+
+        # Compiled
+        start = time.perf_counter()
+        y = elementwise_ops(x).block_until_ready()
+        elapsed = time.perf_counter() - start
+        print(f"Element-wise compiled (50M): {elapsed:.3f} seconds")
+        results["elementwise_50m_compiled"] = elapsed
+
+        print()
+        RESULTS["benchmarks"]["jax"] = results
+
+    except ImportError as e:
+        print(f"JAX not available: {e}")
+        RESULTS["benchmarks"]["jax"] = {"error": str(e)}
+    except Exception as e:
+        print(f"JAX benchmark failed: {e}")
+        RESULTS["benchmarks"]["jax"] = {"error": str(e)}
+
+def benchmark_numba():
+    """Numba CPU benchmark."""
+    try:
+        import numba
+        import numpy as np
+
+        print("=" * 60)
+        print("CPU BENCHMARK: Numba")
+        print("=" * 60)
+
+        results = {}
+
+        @numba.jit(nopython=True)
+        def numba_sum(n):
+            total = 0
+            for i in range(n):
+                total += i * i
+            return total
+
+        # Warm-up (compilation)
+        start = time.perf_counter()
+        result = numba_sum(10_000_000)
+        warmup_time = time.perf_counter() - start
+        print(f"Integer sum warm-up (includes compile): {warmup_time:.3f} seconds")
+        results["integer_sum_10m_warmup"] = warmup_time
+
+        # Compiled run
+        start = time.perf_counter()
+        result = numba_sum(10_000_000)
+        elapsed = time.perf_counter() - start
+        print(f"Integer sum compiled (10M): {elapsed:.3f} seconds")
+        results["integer_sum_10m_compiled"] = elapsed
+        @numba.jit(nopython=True, parallel=True)
+        def numba_parallel_sum(arr):
+            total = 0.0
+            for i in numba.prange(len(arr)):
+                total += arr[i] ** 2
+            return total
+
+        arr = np.random.randn(50_000_000)
+
+        # Warm-up
+        start = time.perf_counter()
+        result = numba_parallel_sum(arr)
+        warmup_time = time.perf_counter() - start
+        print(f"Parallel sum warm-up (50M): {warmup_time:.3f} seconds")
+        results["parallel_sum_50m_warmup"] = warmup_time
+
+        # Compiled
+        start = time.perf_counter()
+        result = numba_parallel_sum(arr)
+        elapsed = time.perf_counter() - start
+        print(f"Parallel sum compiled (50M): {elapsed:.3f} seconds")
+        results["parallel_sum_50m_compiled"] = elapsed
+
+        print()
+        RESULTS["benchmarks"]["numba"] = results
+
+    except ImportError as e:
+        print(f"Numba not available: {e}")
+        RESULTS["benchmarks"]["numba"] = {"error": str(e)}
+    except Exception as e:
+        print(f"Numba benchmark failed: {e}")
+        RESULTS["benchmarks"]["numba"] = {"error": str(e)}
+
+
+def save_results(output_path="benchmark_results_bare_metal.json"):
+    """Save benchmark results to JSON file."""
+    with open(output_path, 'w') as f:
+        json.dump(RESULTS, f, indent=2)
+    print(f"\nResults saved to: {output_path}")
+
+
+if __name__ == "__main__":
+    print("\n" + "=" * 60)
+    print("HARDWARE BENCHMARK FOR CI RUNNER")
+    print("=" * 60 + "\n")
+
+    get_cpu_info()
+    benchmark_cpu_pure_python()
+    benchmark_cpu_numpy()
+    benchmark_numba()
+    benchmark_gpu_jax()
+
+    # Save results to JSON
+    save_results("benchmark_results_bare_metal.json")
+
+    print("=" * 60)
+    print("BENCHMARK COMPLETE")
+    print("=" * 60)
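One caveat on the script above: every benchmark is a single-shot `time.perf_counter()` reading, which can be noisy on shared CI runners. A median over a few repeats is the usual remedy; here is a minimal sketch (not part of this diff) of how the compiled-run timings could be hardened, using the same call pattern the script already relies on:

```python
# Sketch only (not in this diff): take the median of several runs instead
# of a single perf_counter() reading, which is noisy on shared CI hardware.
import statistics
import time

def median_time(fn, repeats=5):
    """Return the median wall-clock seconds of fn() over `repeats` runs."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()  # for JAX workloads, fn must itself call .block_until_ready()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

# Hypothetical usage against the script's NumPy matmul benchmark:
#   elapsed = median_time(lambda: A @ B)
```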
diff --git a/scripts/benchmark-jupyter.ipynb b/scripts/benchmark-jupyter.ipynb
new file mode 100644
index 00000000..909b8fe5
--- /dev/null
+++ b/scripts/benchmark-jupyter.ipynb
@@ -0,0 +1,247 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# JAX Performance Benchmark - Jupyter Kernel Execution\n",
+    "\n",
+    "This notebook tests JAX performance when executed through a Jupyter kernel.\n",
+    "Compare results with direct script and jupyter-book execution."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "import platform\n",
+    "import os\n",
+    "import json\n",
+    "from datetime import datetime\n",
+    "\n",
+    "# Initialize results dictionary\n",
+    "RESULTS = {\n",
+    "    \"pathway\": \"jupyter_kernel\",\n",
+    "    \"timestamp\": datetime.now().isoformat(),\n",
+    "    \"system\": {\n",
+    "        \"platform\": platform.platform(),\n",
+    "        \"python\": platform.python_version(),\n",
+    "        \"cpu_count\": os.cpu_count()\n",
+    "    },\n",
+    "    \"benchmarks\": {}\n",
+    "}\n",
+    "\n",
+    "print(\"=\" * 60)\n",
+    "print(\"JUPYTER KERNEL EXECUTION BENCHMARK\")\n",
+    "print(\"=\" * 60)\n",
+    "print(f\"Platform: {platform.platform()}\")\n",
+    "print(f\"Python: {platform.python_version()}\")\n",
+    "print(f\"CPU Count: {os.cpu_count()}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Import JAX and check devices\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "devices = jax.devices()\n",
+    "default_backend = jax.default_backend()\n",
+    "has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)\n",
+    "\n",
+    "print(f\"JAX devices: {devices}\")\n",
+    "print(f\"Default backend: {default_backend}\")\n",
+    "print(f\"GPU Available: {has_gpu}\")\n",
+    "\n",
+    "RESULTS[\"system\"][\"jax_backend\"] = default_backend\n",
+    "RESULTS[\"system\"][\"has_gpu\"] = has_gpu\n",
+    "RESULTS[\"system\"][\"jax_devices\"] = str(devices)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define JIT-compiled function\n",
+    "@jax.jit\n",
+    "def matmul(a, b):\n",
+    "    return jnp.dot(a, b)\n",
+    "\n",
+    "print(\"matmul function defined with @jax.jit\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation\n",
+    "print(\"\\n\" + \"=\" * 60)\n",
+    "print(\"BENCHMARK 1: Small Matrix (1000x1000)\")\n",
+    "print(\"=\" * 60)\n",
+    "\n",
+    "n = 1000\n",
+    "key = jax.random.PRNGKey(0)\n",
+    "A = jax.random.normal(key, (n, n))\n",
+    "B = jax.random.normal(key, (n, n))\n",
+    "\n",
+    "# Warm-up run (includes compilation)\n",
+    "start = time.perf_counter()\n",
+    "C = matmul(A, B).block_until_ready()\n",
+    "warmup_time = time.perf_counter() - start\n",
+    "print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n",
+    "\n",
+    "# Compiled run\n",
+    "start = time.perf_counter()\n",
+    "C = matmul(A, B).block_until_ready()\n",
+    "compiled_time = time.perf_counter() - start\n",
+    "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n",
+    "\n",
+    "RESULTS[\"benchmarks\"][\"matmul_1000x1000_warmup\"] = warmup_time\n",
+    "RESULTS[\"benchmarks\"][\"matmul_1000x1000_compiled\"] = compiled_time"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Benchmark 2: Large matrix (3000x3000) - triggers recompilation\n",
+    "print(\"\\n\" + \"=\" * 60)\n",
+    "print(\"BENCHMARK 2: Large Matrix (3000x3000)\")\n",
+    "print(\"=\" * 60)\n",
+    "\n",
+    "n = 3000\n",
+    "A = jax.random.normal(key, (n, n))\n",
+    "B = jax.random.normal(key, (n, n))\n",
+    "\n",
+    "# Warm-up run (recompilation for new size)\n",
+    "start = time.perf_counter()\n",
+    "C = matmul(A, B).block_until_ready()\n",
+    "warmup_time = time.perf_counter() - start\n",
+    "print(f\"Warm-up (recompile for new size): {warmup_time:.3f} seconds\")\n",
+    "\n",
+    "# Compiled run\n",
+    "start = time.perf_counter()\n",
+    "C = matmul(A, B).block_until_ready()\n",
+    "compiled_time = time.perf_counter() - start\n",
+    "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n",
+    "\n",
+    "RESULTS[\"benchmarks\"][\"matmul_3000x3000_warmup\"] = warmup_time\n",
+    "RESULTS[\"benchmarks\"][\"matmul_3000x3000_compiled\"] = compiled_time"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Benchmark 3: Element-wise operations (50M elements)\n",
+    "print(\"\\n\" + \"=\" * 60)\n",
+    "print(\"BENCHMARK 3: Element-wise Operations (50M elements)\")\n",
+    "print(\"=\" * 60)\n",
+    "\n",
+    "@jax.jit\n",
+    "def elementwise_ops(x):\n",
+    "    return jnp.cos(x**2) + jnp.sin(x)\n",
+    "\n",
+    "x = jax.random.normal(key, (50_000_000,))\n",
+    "\n",
+    "# Warm-up\n",
+    "start = time.perf_counter()\n",
+    "y = elementwise_ops(x).block_until_ready()\n",
+    "warmup_time = time.perf_counter() - start\n",
+    "print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n",
+    "\n",
+    "# Compiled\n",
+    "start = time.perf_counter()\n",
+    "y = elementwise_ops(x).block_until_ready()\n",
+    "compiled_time = time.perf_counter() - start\n",
+    "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n",
+    "\n",
+    "RESULTS[\"benchmarks\"][\"elementwise_50m_warmup\"] = warmup_time\n",
+    "RESULTS[\"benchmarks\"][\"elementwise_50m_compiled\"] = compiled_time"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Benchmark 4: Multiple small operations (simulates lecture cells)\n",
+    "print(\"\\n\" + \"=\" * 60)\n",
+    "print(\"BENCHMARK 4: Multiple Small Operations (lecture simulation)\")\n",
+    "print(\"=\" * 60)\n",
+    "\n",
+    "total_start = time.perf_counter()\n",
+    "multi_results = {}\n",
+    "\n",
+    "# Simulate multiple cell executions with different operations\n",
+    "for i, size in enumerate([100, 500, 1000, 2000, 3000]):\n",
+    "    @jax.jit\n",
+    "    def compute(a, b):\n",
+    "        return jnp.dot(a, b) + jnp.sum(a)\n",
+    "\n",
+    "    A = jax.random.normal(key, (size, size))\n",
+    "    B = jax.random.normal(key, (size, size))\n",
+    "\n",
+    "    start = time.perf_counter()\n",
+    "    result = compute(A, B).block_until_ready()\n",
+    "    elapsed = time.perf_counter() - start\n",
+    "    print(f\"  Size {size}x{size}: {elapsed:.3f} seconds\")\n",
+    "    multi_results[f\"size_{size}x{size}\"] = elapsed\n",
+    "\n",
+    "total_time = time.perf_counter() - total_start\n",
+    "print(f\"\\nTotal time for all operations: {total_time:.3f} seconds\")\n",
+    "\n",
+    "RESULTS[\"benchmarks\"][\"multi_ops\"] = multi_results\n",
+    "RESULTS[\"benchmarks\"][\"multi_ops_total\"] = total_time"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Save results to JSON file\n",
+    "output_path = \"benchmark_results_jupyter.json\"\n",
+    "with open(output_path, 'w') as f:\n",
+    "    json.dump(RESULTS, f, indent=2)\n",
+    "\n",
+    "print(\"\\n\" + \"=\" * 60)\n",
+    "print(\"JUPYTER KERNEL EXECUTION BENCHMARK COMPLETE\")\n",
+    "print(\"=\" * 60)\n",
+    "print(f\"\\nResults saved to: {output_path}\")\n",
+    "print(\"\\nJSON Results:\")\n",
+    "print(json.dumps(RESULTS, indent=2))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.13.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
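The CI step runs this notebook with `jupyter nbconvert --to notebook --execute` from inside `scripts/`. For reference, a rough programmatic equivalent, a sketch assuming `nbformat` and `nbconvert` are importable (both ship with any standard Jupyter install):

```python
# Sketch: roughly what the workflow's nbconvert call does, in Python.
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

nb = nbformat.read("scripts/benchmark-jupyter.ipynb", as_version=4)
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
# Execute with scripts/ as the working directory so the notebook writes
# benchmark_results_jupyter.json where the workflow expects to copy it from.
ep.preprocess(nb, {"metadata": {"path": "scripts/"}})
nbformat.write(nb, "scripts/benchmark-jupyter-executed.ipynb")
```

This pathway adds kernel startup and per-cell message round-trips on top of the bare-metal numbers, which is exactly the overhead the comparison is meant to expose.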
diff --git a/scripts/benchmark-jupyterbook.md b/scripts/benchmark-jupyterbook.md
new file mode 100644
index 00000000..162613c8
--- /dev/null
+++ b/scripts/benchmark-jupyterbook.md
@@ -0,0 +1,196 @@
+---
+jupytext:
+  text_representation:
+    extension: .md
+    format_name: myst
+    format_version: 0.13
+kernelspec:
+  display_name: Python 3 (ipykernel)
+  language: python
+  name: python3
+---
+
+# JAX Performance Benchmark - Jupyter Book Execution
+
+This file tests JAX performance when executed through Jupyter Book's notebook execution.
+Compare results with direct script and nbconvert execution.
+
+```{code-cell} ipython3
+import time
+import platform
+import os
+import json
+from datetime import datetime
+
+# Initialize results dictionary
+RESULTS = {
+    "pathway": "jupyter_book",
+    "timestamp": datetime.now().isoformat(),
+    "system": {
+        "platform": platform.platform(),
+        "python": platform.python_version(),
+        "cpu_count": os.cpu_count()
+    },
+    "benchmarks": {}
+}
+
+print("=" * 60)
+print("JUPYTER BOOK EXECUTION BENCHMARK")
+print("=" * 60)
+print(f"Platform: {platform.platform()}")
+print(f"Python: {platform.python_version()}")
+print(f"CPU Count: {os.cpu_count()}")
+```
+
+```{code-cell} ipython3
+# Import JAX and check devices
+import jax
+import jax.numpy as jnp
+
+devices = jax.devices()
+default_backend = jax.default_backend()
+has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)
+
+print(f"JAX devices: {devices}")
+print(f"Default backend: {default_backend}")
+print(f"GPU Available: {has_gpu}")
+
+RESULTS["system"]["jax_backend"] = default_backend
+RESULTS["system"]["has_gpu"] = has_gpu
+RESULTS["system"]["jax_devices"] = str(devices)
+```
+
+```{code-cell} ipython3
+# Define JIT-compiled function
+@jax.jit
+def matmul(a, b):
+    return jnp.dot(a, b)
+
+print("matmul function defined with @jax.jit")
+```
+
+```{code-cell} ipython3
+# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation
+print("\n" + "=" * 60)
+print("BENCHMARK 1: Small Matrix (1000x1000)")
+print("=" * 60)
+
+n = 1000
+key = jax.random.PRNGKey(0)
+A = jax.random.normal(key, (n, n))
+B = jax.random.normal(key, (n, n))
+
+# Warm-up run (includes compilation)
+start = time.perf_counter()
+C = matmul(A, B).block_until_ready()
+warmup_time = time.perf_counter() - start
+print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds")
+
+# Compiled run
+start = time.perf_counter()
+C = matmul(A, B).block_until_ready()
+compiled_time = time.perf_counter() - start
+print(f"Compiled execution: {compiled_time:.3f} seconds")
+
+RESULTS["benchmarks"]["matmul_1000x1000_warmup"] = warmup_time
+RESULTS["benchmarks"]["matmul_1000x1000_compiled"] = compiled_time
+```
+
+```{code-cell} ipython3
+# Benchmark 2: Large matrix (3000x3000) - triggers recompilation
+print("\n" + "=" * 60)
+print("BENCHMARK 2: Large Matrix (3000x3000)")
+print("=" * 60)
+
+n = 3000
+A = jax.random.normal(key, (n, n))
+B = jax.random.normal(key, (n, n))
+
+# Warm-up run (recompilation for new size)
+start = time.perf_counter()
+C = matmul(A, B).block_until_ready()
+warmup_time = time.perf_counter() - start
+print(f"Warm-up (recompile for new size): {warmup_time:.3f} seconds")
+
+# Compiled run
+start = time.perf_counter()
+C = matmul(A, B).block_until_ready()
+compiled_time = time.perf_counter() - start
+print(f"Compiled execution: {compiled_time:.3f} seconds")
+
+RESULTS["benchmarks"]["matmul_3000x3000_warmup"] = warmup_time
+RESULTS["benchmarks"]["matmul_3000x3000_compiled"] = compiled_time
+```
+
+```{code-cell} ipython3
+# Benchmark 3: Element-wise operations (50M elements)
+print("\n" + "=" * 60)
+print("BENCHMARK 3: Element-wise Operations (50M elements)")
+print("=" * 60)
+
+@jax.jit
+def elementwise_ops(x):
+    return jnp.cos(x**2) + jnp.sin(x)
+
+x = jax.random.normal(key, (50_000_000,))
+
+# Warm-up
+start = time.perf_counter()
+y = elementwise_ops(x).block_until_ready()
+warmup_time = time.perf_counter() - start
+print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds")
+
+# Compiled
+start = time.perf_counter()
+y = elementwise_ops(x).block_until_ready()
+compiled_time = time.perf_counter() - start
+print(f"Compiled execution: {compiled_time:.3f} seconds")
+
+RESULTS["benchmarks"]["elementwise_50m_warmup"] = warmup_time
+RESULTS["benchmarks"]["elementwise_50m_compiled"] = compiled_time
+```
+
+```{code-cell} ipython3
+# Benchmark 4: Multiple small operations (simulates lecture cells)
+print("\n" + "=" * 60)
+print("BENCHMARK 4: Multiple Small Operations (lecture simulation)")
+print("=" * 60)
+
+total_start = time.perf_counter()
+multi_results = {}
+
+# Simulate multiple cell executions with different operations
+for i, size in enumerate([100, 500, 1000, 2000, 3000]):
+    @jax.jit
+    def compute(a, b):
+        return jnp.dot(a, b) + jnp.sum(a)
+
+    A = jax.random.normal(key, (size, size))
+    B = jax.random.normal(key, (size, size))
+
+    start = time.perf_counter()
+    result = compute(A, B).block_until_ready()
+    elapsed = time.perf_counter() - start
+    print(f"  Size {size}x{size}: {elapsed:.3f} seconds")
+    multi_results[f"size_{size}x{size}"] = elapsed
+
+total_time = time.perf_counter() - total_start
+print(f"\nTotal time for all operations: {total_time:.3f} seconds")
+
+RESULTS["benchmarks"]["multi_ops"] = multi_results
+RESULTS["benchmarks"]["multi_ops_total"] = total_time
+```
+
+```{code-cell} ipython3
+# Save results to JSON file
+output_path = "benchmark_results_jupyterbook.json"
+with open(output_path, 'w') as f:
+    json.dump(RESULTS, f, indent=2)
+
+print("\n" + "=" * 60)
+print("JUPYTER BOOK EXECUTION BENCHMARK COMPLETE")
+print("=" * 60)
+print(f"\nResults saved to: {output_path}")
+print("\nJSON Results:")
+print(json.dumps(RESULTS, indent=2))
+```
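Jupyter Book executes the MyST file above through its `jupytext` front matter. If the jupyter-book pathway turns out slower than the others, a useful isolation step, sketched here under the assumption that `jupytext` is installed alongside jupyter-book, is to convert the file to a plain notebook and re-run it through the nbconvert pathway; any remaining gap is then jupyter-book build overhead rather than kernel execution time:

```python
# Sketch, not part of this diff: round-trip the MyST file to .ipynb so it
# can be executed via nbconvert, separating jb overhead from kernel time.
import jupytext

nb = jupytext.read("scripts/benchmark-jupyterbook.md")    # parses the code-cells
jupytext.write(nb, "scripts/benchmark-jupyterbook.ipynb") # format inferred from extension
```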