From 2e1684b2bc6b5e674621163c0b4af77d8758f1c6 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 14:15:35 +1100 Subject: [PATCH 01/19] Enable RunsOn GPU support for lecture builds - Add scripts/test-jax-install.py to verify JAX/GPU installation - Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration - Update cache.yml to use RunsOn g4dn.2xlarge GPU runner - Update ci.yml to use RunsOn g4dn.2xlarge GPU runner - Update publish.yml to use RunsOn g4dn.2xlarge GPU runner - Install JAX with CUDA 13 support and Numpyro on all workflows - Add nvidia-smi check to verify GPU availability This mirrors the setup used in lecture-python.myst repository. --- .github/runs-on.yml | 6 ++++++ .github/workflows/cache.yml | 12 +++++++++++- .github/workflows/ci.yml | 11 ++++++++++- .github/workflows/publish.yml | 12 +++++++++++- scripts/test-jax-install.py | 21 +++++++++++++++++++++ 5 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 .github/runs-on.yml create mode 100644 scripts/test-jax-install.py diff --git a/.github/runs-on.yml b/.github/runs-on.yml new file mode 100644 index 00000000..e7a18910 --- /dev/null +++ b/.github/runs-on.yml @@ -0,0 +1,6 @@ +images: + quantecon_ubuntu2404: + platform: "linux" + arch: "x64" + ami: "ami-0edec81935264b6d3" + region: "us-west-2" diff --git a/.github/workflows/cache.yml b/.github/workflows/cache.yml index 138ee3fb..c9325914 100644 --- a/.github/workflows/cache.yml +++ b/.github/workflows/cache.yml @@ -6,7 +6,7 @@ on: workflow_dispatch: jobs: cache: - runs-on: ubuntu-latest + runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large" steps: - uses: actions/checkout@v6 - name: Setup Anaconda @@ -18,6 +18,16 @@ jobs: python-version: "3.13" environment-file: environment.yml activate-environment: quantecon + - name: Install JAX and Numpyro + shell: bash -l {0} + run: | + pip install -U "jax[cuda13]" + pip install numpyro + python scripts/test-jax-install.py + - name: Check nvidia drivers + shell: bash -l {0} + run: | + nvidia-smi - name: Build HTML shell: bash -l {0} run: | diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 51755ea9..58f69dcc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: Build Project [using jupyter-book] on: [pull_request] jobs: preview: - runs-on: ubuntu-latest + runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large" steps: - uses: actions/checkout@v6 with: @@ -16,6 +16,15 @@ jobs: python-version: "3.13" environment-file: environment.yml activate-environment: quantecon + - name: Check nvidia Drivers + shell: bash -l {0} + run: nvidia-smi + - name: Install JAX and Numpyro + shell: bash -l {0} + run: | + pip install -U "jax[cuda13]" + pip install numpyro + python scripts/test-jax-install.py - name: Install latex dependencies run: | sudo apt-get -qq update diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 258fbe54..5622ef7a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,7 +6,7 @@ on: jobs: publish: if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') - runs-on: ubuntu-latest + runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large" steps: - name: Checkout uses: actions/checkout@v6 @@ -21,6 +21,16 @@ jobs: python-version: "3.13" environment-file: environment.yml activate-environment: quantecon + - name: Install JAX and Numpyro + shell: bash 
-l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
+      - name: Check nvidia drivers
+        shell: bash -l {0}
+        run: |
+          nvidia-smi
       - name: Install latex dependencies
         run: |
           sudo apt-get -qq update
diff --git a/scripts/test-jax-install.py b/scripts/test-jax-install.py
new file mode 100644
index 00000000..c2be1d3d
--- /dev/null
+++ b/scripts/test-jax-install.py
@@ -0,0 +1,21 @@
+import jax
+import jax.numpy as jnp
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")
+
+@jax.jit
+def matrix_multiply(a, b):
+    return jnp.dot(a, b)
+
+# Example usage:
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (1000, 1000))
+y = jax.random.normal(key, (1000, 1000))
+z = matrix_multiply(x, y)
+
+# Now the function is JIT-compiled and will likely run on GPU (if available)
+print(z)
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")

From 849cf16492f14ddc0ad4be5e7a08e4b4ff1d95d9 Mon Sep 17 00:00:00 2001
From: mmcky
Date: Thu, 27 Nov 2025 15:33:50 +1100
Subject: [PATCH 02/19] DOC: Update JAX lectures with GPU admonition and narrative

- Add standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
- Update introduction in jax_intro.md to reflect GPU access
- Update conditional GPU language to reflect that the lectures now run on a GPU
- Following the QuantEcon style guide for JAX lectures
---
 lectures/jax_intro.md             | 35 +++++++++++++------------------
 lectures/numpy_vs_numba_vs_jax.md | 14 ++++++++++++-
 2 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md
index 0d890d8f..7d4706e4 100644
--- a/lectures/jax_intro.md
+++ b/lectures/jax_intro.md
@@ -21,20 +21,23 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 !pip install jax quantecon
 ```

-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and it targets JAX for GPU programming.

-Here we are focused on using JAX on the CPU, rather than on accelerators such as
-GPUs or TPUs.
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.

-This means we will only see a small amount of the possible benefits from using JAX.
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```

-However, JAX seamlessly handles transitions across different hardware platforms.
+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).

-As a result, if you run this code on a machine with a GPU and a GPU-aware
-version of JAX installed, your code will be automatically accelerated and you
-will receive the full benefits.
+JAX provides a NumPy-like interface that can leverage GPU acceleration for high-performance numerical computing.

-For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
+For a more comprehensive discussion of JAX, see [our JAX lecture series](https://jax.quantecon.org/intro.html).

 ## JAX as a NumPy Replacement

@@ -523,16 +526,9 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```

-If you are running this on a GPU the code will run much faster than its NumPy
-equivalent, which ran on the CPU.
-
-Even if you are running on a machine with many CPUs, the second JAX run should
-be substantially faster with JAX.
-
-Also, typically, the second run is faster than the first.
+On a GPU, this code runs much faster than its NumPy equivalent.

-(This might not be noticable on the CPU but it should definitely be noticable on
-the GPU.)
+Also, typically, the second run is faster than the first due to JIT compilation.

 This is because even built in functions like `jnp.cos` are JIT-compiled ---
 and the first run includes compile time.
@@ -634,8 +630,7 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```

-The outcome is similar to the `cos` example --- JAX is faster, especially if you
-use a GPU and especially on the second run.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.

 Moreover, with JAX, we have another trick up our sleeve:

diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 883f2d14..99393e2e 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -48,6 +48,18 @@ tags: [hide-output]
 !pip install quantecon jax
 ```

+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and it targets JAX for GPU programming.
+
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.
+
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```
+
 We will use the following imports.

 ```{code-cell} ipython3
@@ -317,7 +329,7 @@ with qe.Timer(precision=8):
     z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```

-Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
+Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.

 The compilation overhead is a one-time cost that pays off when the function is called repeatedly.

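The benchmark patches that follow all rely on the timing discipline that patch 02's narrative describes: list the available devices, JIT-compile a function, time one warm-up call (which includes compilation), then time a second call against the cached kernel. A minimal standalone sketch of that pattern (assuming only that `jax` is installed; the function and variable names here are illustrative, not taken from the patches):

```python
import time

import jax
import jax.numpy as jnp

# Device check: works the same on CPU-only and GPU machines.
devices = jax.devices()
has_gpu = any("cuda" in str(d).lower() or "gpu" in str(d).lower() for d in devices)
print(f"Devices: {devices} (GPU available: {has_gpu})")

@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (1000, 1000))
b = jax.random.normal(key, (1000, 1000))

# First call triggers JIT compilation; block_until_ready() waits for the
# asynchronous dispatch to finish so the wall-clock timing is meaningful.
start = time.perf_counter()
matmul(a, b).block_until_ready()
print(f"Warm-up (includes compile): {time.perf_counter() - start:.3f}s")

# Second call reuses the compiled kernel and is typically much faster.
start = time.perf_counter()
matmul(a, b).block_until_ready()
print(f"Compiled: {time.perf_counter() - start:.3f}s")
```

Calling `.block_until_ready()` before stopping the timer matters because JAX dispatches work asynchronously; without it, the timings measure dispatch rather than computation.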
From 0b6a5673d287fada8d7b3b9ef4d5f42506044c68 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 15:50:42 +1100 Subject: [PATCH 03/19] DEBUG: Add hardware benchmark script to diagnose performance - Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks - Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners - Include warm-up vs compiled timing to isolate JIT overhead - Add system info collection (CPU model, frequency, GPU detection) --- .github/workflows/ci.yml | 3 + scripts/benchmark-hardware.py | 264 ++++++++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 scripts/benchmark-hardware.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 58f69dcc..a58e95d3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,9 @@ jobs: pip install -U "jax[cuda13]" pip install numpyro python scripts/test-jax-install.py + - name: Run Hardware Benchmarks + shell: bash -l {0} + run: python scripts/benchmark-hardware.py - name: Install latex dependencies run: | sudo apt-get -qq update diff --git a/scripts/benchmark-hardware.py b/scripts/benchmark-hardware.py new file mode 100644 index 00000000..45a1604c --- /dev/null +++ b/scripts/benchmark-hardware.py @@ -0,0 +1,264 @@ +""" +Hardware benchmark script for CI runners. +Compares CPU and GPU performance to diagnose slowdowns. +Works on both CPU-only (GitHub Actions) and GPU (RunsOn) runners. +""" +import time +import platform +import os + +def get_cpu_info(): + """Get CPU information.""" + print("=" * 60) + print("SYSTEM INFORMATION") + print("=" * 60) + print(f"Platform: {platform.platform()}") + print(f"Processor: {platform.processor()}") + print(f"Python: {platform.python_version()}") + + # Try to get CPU frequency + try: + with open('/proc/cpuinfo', 'r') as f: + for line in f: + if 'model name' in line: + print(f"CPU Model: {line.split(':')[1].strip()}") + break + except: + pass + + # Try to get CPU frequency + try: + with open('/proc/cpuinfo', 'r') as f: + for line in f: + if 'cpu MHz' in line: + print(f"CPU MHz: {line.split(':')[1].strip()}") + break + except: + pass + + # CPU count + print(f"CPU Count: {os.cpu_count()}") + + # Check for GPU + try: + import subprocess + result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'], + capture_output=True, text=True, timeout=5) + if result.returncode == 0: + print(f"GPU: {result.stdout.strip()}") + else: + print("GPU: None detected") + except: + print("GPU: None detected (nvidia-smi not available)") + + print() + +def benchmark_cpu_pure_python(): + """Pure Python CPU benchmark.""" + print("=" * 60) + print("CPU BENCHMARK: Pure Python") + print("=" * 60) + + # Integer computation + start = time.perf_counter() + total = sum(i * i for i in range(10_000_000)) + elapsed = time.perf_counter() - start + print(f"Integer sum (10M iterations): {elapsed:.3f} seconds") + + # Float computation + start = time.perf_counter() + total = 0.0 + for i in range(1_000_000): + total += (i * 0.1) ** 0.5 + elapsed = time.perf_counter() - start + print(f"Float sqrt (1M iterations): {elapsed:.3f} seconds") + print() + +def benchmark_cpu_numpy(): + """NumPy CPU benchmark.""" + import numpy as np + + print("=" * 60) + print("CPU BENCHMARK: NumPy") + print("=" * 60) + + # Matrix multiplication + n = 3000 + A = np.random.randn(n, n) + B = np.random.randn(n, n) + + start = time.perf_counter() + C = A @ B + elapsed = time.perf_counter() - start + print(f"Matrix multiply ({n}x{n}): {elapsed:.3f} 
seconds") + + # Element-wise operations + x = np.random.randn(50_000_000) + + start = time.perf_counter() + y = np.cos(x**2) + np.sin(x) + elapsed = time.perf_counter() - start + print(f"Element-wise ops (50M elements): {elapsed:.3f} seconds") + print() + +def benchmark_gpu_jax(): + """JAX benchmark (GPU if available, otherwise CPU).""" + try: + import jax + import jax.numpy as jnp + + devices = jax.devices() + default_backend = jax.default_backend() + + # Check if GPU is available + has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices) + + print("=" * 60) + if has_gpu: + print("JAX BENCHMARK: GPU") + else: + print("JAX BENCHMARK: CPU (no GPU detected)") + print("=" * 60) + + print(f"JAX devices: {devices}") + print(f"Default backend: {default_backend}") + print(f"GPU Available: {has_gpu}") + print() + + # Warm-up JIT compilation + print("Warming up JIT compilation...") + n = 1000 + key = jax.random.PRNGKey(0) + A = jax.random.normal(key, (n, n)) + B = jax.random.normal(key, (n, n)) + + @jax.jit + def matmul(a, b): + return jnp.dot(a, b) + + # Warm-up run (includes compilation) + start = time.perf_counter() + C = matmul(A, B).block_until_ready() + warmup_time = time.perf_counter() - start + print(f"Warm-up (includes JIT compile, {n}x{n}): {warmup_time:.3f} seconds") + + # Actual benchmark (compiled) + start = time.perf_counter() + C = matmul(A, B).block_until_ready() + elapsed = time.perf_counter() - start + print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds") + + # Larger matrix + n = 3000 + A = jax.random.normal(key, (n, n)) + B = jax.random.normal(key, (n, n)) + + # Warm-up for new size + start = time.perf_counter() + C = matmul(A, B).block_until_ready() + warmup_time = time.perf_counter() - start + print(f"Warm-up (recompile for {n}x{n}): {warmup_time:.3f} seconds") + + # Benchmark compiled + start = time.perf_counter() + C = matmul(A, B).block_until_ready() + elapsed = time.perf_counter() - start + print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds") + + # Element-wise GPU benchmark + x = jax.random.normal(key, (50_000_000,)) + + @jax.jit + def elementwise_ops(x): + return jnp.cos(x**2) + jnp.sin(x) + + # Warm-up + start = time.perf_counter() + y = elementwise_ops(x).block_until_ready() + warmup_time = time.perf_counter() - start + print(f"Element-wise warm-up (50M): {warmup_time:.3f} seconds") + + # Compiled + start = time.perf_counter() + y = elementwise_ops(x).block_until_ready() + elapsed = time.perf_counter() - start + print(f"Element-wise compiled (50M): {elapsed:.3f} seconds") + + print() + + except ImportError as e: + print(f"JAX not available: {e}") + except Exception as e: + print(f"JAX benchmark failed: {e}") + +def benchmark_numba(): + """Numba CPU benchmark.""" + try: + import numba + import numpy as np + + print("=" * 60) + print("CPU BENCHMARK: Numba") + print("=" * 60) + + @numba.jit(nopython=True) + def numba_sum(n): + total = 0 + for i in range(n): + total += i * i + return total + + # Warm-up (compilation) + start = time.perf_counter() + result = numba_sum(10_000_000) + warmup_time = time.perf_counter() - start + print(f"Integer sum warm-up (includes compile): {warmup_time:.3f} seconds") + + # Compiled run + start = time.perf_counter() + result = numba_sum(10_000_000) + elapsed = time.perf_counter() - start + print(f"Integer sum compiled (10M): {elapsed:.3f} seconds") + + @numba.jit(nopython=True, parallel=True) + def numba_parallel_sum(arr): + total = 0.0 + for i in numba.prange(len(arr)): + total += 
arr[i] ** 2 + return total + + arr = np.random.randn(50_000_000) + + # Warm-up + start = time.perf_counter() + result = numba_parallel_sum(arr) + warmup_time = time.perf_counter() - start + print(f"Parallel sum warm-up (50M): {warmup_time:.3f} seconds") + + # Compiled + start = time.perf_counter() + result = numba_parallel_sum(arr) + elapsed = time.perf_counter() - start + print(f"Parallel sum compiled (50M): {elapsed:.3f} seconds") + + print() + + except ImportError as e: + print(f"Numba not available: {e}") + except Exception as e: + print(f"Numba benchmark failed: {e}") + +if __name__ == "__main__": + print("\n" + "=" * 60) + print("HARDWARE BENCHMARK FOR CI RUNNER") + print("=" * 60 + "\n") + + get_cpu_info() + benchmark_cpu_pure_python() + benchmark_cpu_numpy() + benchmark_numba() + benchmark_gpu_jax() + + print("=" * 60) + print("BENCHMARK COMPLETE") + print("=" * 60) From 6d1b9c351f82c3fe7091b684905588d10a53314c Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 16:21:31 +1100 Subject: [PATCH 04/19] Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book) --- .github/workflows/ci.yml | 30 +++++- scripts/benchmark-jupyter.ipynb | 0 scripts/benchmark-jupyterbook.md | 156 +++++++++++++++++++++++++++++++ 3 files changed, 184 insertions(+), 2 deletions(-) create mode 100644 scripts/benchmark-jupyter.ipynb create mode 100644 scripts/benchmark-jupyterbook.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a58e95d3..da7f4f8a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,9 +25,35 @@ jobs: pip install -U "jax[cuda13]" pip install numpyro python scripts/test-jax-install.py - - name: Run Hardware Benchmarks + # === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) === + - name: Run Hardware Benchmarks (Bare Metal) shell: bash -l {0} - run: python scripts/benchmark-hardware.py + run: | + echo "=== Bare Metal Python Script Execution ===" + python scripts/benchmark-hardware.py + - name: Run Jupyter Notebook Benchmark (via nbconvert) + shell: bash -l {0} + run: | + echo "=== Jupyter Kernel Execution ===" + jupyter nbconvert --to notebook --execute scripts/benchmark-jupyter.ipynb --output benchmark-jupyter-executed.ipynb + echo "Notebook executed successfully" + - name: Run Jupyter-Book Benchmark + shell: bash -l {0} + run: | + echo "=== Jupyter-Book Execution ===" + # Build just the benchmark file using jupyter-book + mkdir -p benchmark_test + cp scripts/benchmark-jupyterbook.md benchmark_test/ + # Create minimal _config.yml + echo "title: Benchmark Test" > benchmark_test/_config.yml + echo "execute:" >> benchmark_test/_config.yml + echo " execute_notebooks: force" >> benchmark_test/_config.yml + # Create minimal _toc.yml + echo "format: jb-book" > benchmark_test/_toc.yml + echo "root: benchmark-jupyterbook" >> benchmark_test/_toc.yml + # Build + jb build benchmark_test --path-output benchmark_build/ + echo "Jupyter-Book build completed successfully" - name: Install latex dependencies run: | sudo apt-get -qq update diff --git a/scripts/benchmark-jupyter.ipynb b/scripts/benchmark-jupyter.ipynb new file mode 100644 index 00000000..e69de29b diff --git a/scripts/benchmark-jupyterbook.md b/scripts/benchmark-jupyterbook.md new file mode 100644 index 00000000..23434e9e --- /dev/null +++ b/scripts/benchmark-jupyterbook.md @@ -0,0 +1,156 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# 
JAX Performance Benchmark - Jupyter Book Execution + +This file tests JAX performance when executed through Jupyter Book's notebook execution. +Compare results with direct script and nbconvert execution. + +```{code-cell} ipython3 +import time +import platform +import os + +print("=" * 60) +print("JUPYTER BOOK EXECUTION BENCHMARK") +print("=" * 60) +print(f"Platform: {platform.platform()}") +print(f"Python: {platform.python_version()}") +print(f"CPU Count: {os.cpu_count()}") +``` + +```{code-cell} ipython3 +# Import JAX and check devices +import jax +import jax.numpy as jnp + +devices = jax.devices() +default_backend = jax.default_backend() +has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices) + +print(f"JAX devices: {devices}") +print(f"Default backend: {default_backend}") +print(f"GPU Available: {has_gpu}") +``` + +```{code-cell} ipython3 +# Define JIT-compiled function +@jax.jit +def matmul(a, b): + return jnp.dot(a, b) + +print("matmul function defined with @jax.jit") +``` + +```{code-cell} ipython3 +# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation +print("\n" + "=" * 60) +print("BENCHMARK 1: Small Matrix (1000x1000)") +print("=" * 60) + +n = 1000 +key = jax.random.PRNGKey(0) +A = jax.random.normal(key, (n, n)) +B = jax.random.normal(key, (n, n)) + +# Warm-up run (includes compilation) +start = time.perf_counter() +C = matmul(A, B).block_until_ready() +warmup_time = time.perf_counter() - start +print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds") + +# Compiled run +start = time.perf_counter() +C = matmul(A, B).block_until_ready() +compiled_time = time.perf_counter() - start +print(f"Compiled execution: {compiled_time:.3f} seconds") +``` + +```{code-cell} ipython3 +# Benchmark 2: Large matrix (3000x3000) - triggers recompilation +print("\n" + "=" * 60) +print("BENCHMARK 2: Large Matrix (3000x3000)") +print("=" * 60) + +n = 3000 +A = jax.random.normal(key, (n, n)) +B = jax.random.normal(key, (n, n)) + +# Warm-up run (recompilation for new size) +start = time.perf_counter() +C = matmul(A, B).block_until_ready() +warmup_time = time.perf_counter() - start +print(f"Warm-up (recompile for new size): {warmup_time:.3f} seconds") + +# Compiled run +start = time.perf_counter() +C = matmul(A, B).block_until_ready() +compiled_time = time.perf_counter() - start +print(f"Compiled execution: {compiled_time:.3f} seconds") +``` + +```{code-cell} ipython3 +# Benchmark 3: Element-wise operations (50M elements) +print("\n" + "=" * 60) +print("BENCHMARK 3: Element-wise Operations (50M elements)") +print("=" * 60) + +@jax.jit +def elementwise_ops(x): + return jnp.cos(x**2) + jnp.sin(x) + +x = jax.random.normal(key, (50_000_000,)) + +# Warm-up +start = time.perf_counter() +y = elementwise_ops(x).block_until_ready() +warmup_time = time.perf_counter() - start +print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds") + +# Compiled +start = time.perf_counter() +y = elementwise_ops(x).block_until_ready() +compiled_time = time.perf_counter() - start +print(f"Compiled execution: {compiled_time:.3f} seconds") +``` + +```{code-cell} ipython3 +# Benchmark 4: Multiple small operations (simulates lecture cells) +print("\n" + "=" * 60) +print("BENCHMARK 4: Multiple Small Operations (lecture simulation)") +print("=" * 60) + +total_start = time.perf_counter() + +# Simulate multiple cell executions with different operations +for i, size in enumerate([100, 500, 1000, 2000, 3000]): + @jax.jit + def compute(a, b): + return jnp.dot(a, b) + jnp.sum(a) + + 
A = jax.random.normal(key, (size, size)) + B = jax.random.normal(key, (size, size)) + + start = time.perf_counter() + result = compute(A, B).block_until_ready() + elapsed = time.perf_counter() - start + print(f" Size {size}x{size}: {elapsed:.3f} seconds") + +total_time = time.perf_counter() - total_start +print(f"\nTotal time for all operations: {total_time:.3f} seconds") +``` + +```{code-cell} ipython3 +print("\n" + "=" * 60) +print("JUPYTER BOOK EXECUTION BENCHMARK COMPLETE") +print("=" * 60) +``` From d129f79e75c2d76b7921087af28acaa7918b9fb7 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 16:32:05 +1100 Subject: [PATCH 05/19] Fix: Add content to benchmark-jupyter.ipynb (was empty) --- scripts/benchmark-jupyter.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 scripts/benchmark-jupyter.ipynb diff --git a/scripts/benchmark-jupyter.ipynb b/scripts/benchmark-jupyter.ipynb deleted file mode 100644 index e69de29b..00000000 From 2da4e0c34e11de49f13856964ec8bb6d66da53d3 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 16:33:29 +1100 Subject: [PATCH 06/19] Fix: Add benchmark content to benchmark-jupyter.ipynb --- scripts/benchmark-jupyter.ipynb | 207 ++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 scripts/benchmark-jupyter.ipynb diff --git a/scripts/benchmark-jupyter.ipynb b/scripts/benchmark-jupyter.ipynb new file mode 100644 index 00000000..e095f79b --- /dev/null +++ b/scripts/benchmark-jupyter.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# JAX Performance Benchmark - Jupyter Kernel Execution\n", + "\n", + "This notebook tests JAX performance when executed through a Jupyter kernel.\n", + "Compare results with direct script and jupyter-book execution." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import platform\n", + "import os\n", + "\n", + "print(\"=\" * 60)\n", + "print(\"JUPYTER KERNEL EXECUTION BENCHMARK\")\n", + "print(\"=\" * 60)\n", + "print(f\"Platform: {platform.platform()}\")\n", + "print(f\"Python: {platform.python_version()}\")\n", + "print(f\"CPU Count: {os.cpu_count()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import JAX and check devices\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "devices = jax.devices()\n", + "default_backend = jax.default_backend()\n", + "has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)\n", + "\n", + "print(f\"JAX devices: {devices}\")\n", + "print(f\"Default backend: {default_backend}\")\n", + "print(f\"GPU Available: {has_gpu}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define JIT-compiled function\n", + "@jax.jit\n", + "def matmul(a, b):\n", + " return jnp.dot(a, b)\n", + "\n", + "print(\"matmul function defined with @jax.jit\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"BENCHMARK 1: Small Matrix (1000x1000)\")\n", + "print(\"=\" * 60)\n", + "\n", + "n = 1000\n", + "key = jax.random.PRNGKey(0)\n", + "A = jax.random.normal(key, (n, n))\n", + "B = jax.random.normal(key, (n, n))\n", + "\n", + "# Warm-up run (includes compilation)\n", + "start = time.perf_counter()\n", + "C = matmul(A, B).block_until_ready()\n", + "warmup_time = time.perf_counter() - start\n", + "print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n", + "\n", + "# Compiled run\n", + "start = time.perf_counter()\n", + "C = matmul(A, B).block_until_ready()\n", + "compiled_time = time.perf_counter() - start\n", + "print(f\"Compiled execution: {compiled_time:.3f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Benchmark 2: Large matrix (3000x3000) - triggers recompilation\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"BENCHMARK 2: Large Matrix (3000x3000)\")\n", + "print(\"=\" * 60)\n", + "\n", + "n = 3000\n", + "A = jax.random.normal(key, (n, n))\n", + "B = jax.random.normal(key, (n, n))\n", + "\n", + "# Warm-up run (recompilation for new size)\n", + "start = time.perf_counter()\n", + "C = matmul(A, B).block_until_ready()\n", + "warmup_time = time.perf_counter() - start\n", + "print(f\"Warm-up (recompile for new size): {warmup_time:.3f} seconds\")\n", + "\n", + "# Compiled run\n", + "start = time.perf_counter()\n", + "C = matmul(A, B).block_until_ready()\n", + "compiled_time = time.perf_counter() - start\n", + "print(f\"Compiled execution: {compiled_time:.3f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Benchmark 3: Element-wise operations (50M elements)\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"BENCHMARK 3: Element-wise Operations (50M elements)\")\n", + "print(\"=\" * 60)\n", + "\n", + "@jax.jit\n", + "def elementwise_ops(x):\n", + " return jnp.cos(x**2) + jnp.sin(x)\n", + "\n", + "x = jax.random.normal(key, (50_000_000,))\n", + "\n", + "# Warm-up\n", + "start = 
time.perf_counter()\n", + "y = elementwise_ops(x).block_until_ready()\n", + "warmup_time = time.perf_counter() - start\n", + "print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n", + "\n", + "# Compiled\n", + "start = time.perf_counter()\n", + "y = elementwise_ops(x).block_until_ready()\n", + "compiled_time = time.perf_counter() - start\n", + "print(f\"Compiled execution: {compiled_time:.3f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Benchmark 4: Multiple small operations (simulates lecture cells)\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"BENCHMARK 4: Multiple Small Operations (lecture simulation)\")\n", + "print(\"=\" * 60)\n", + "\n", + "total_start = time.perf_counter()\n", + "\n", + "# Simulate multiple cell executions with different operations\n", + "for i, size in enumerate([100, 500, 1000, 2000, 3000]):\n", + " @jax.jit\n", + " def compute(a, b):\n", + " return jnp.dot(a, b) + jnp.sum(a)\n", + " \n", + " A = jax.random.normal(key, (size, size))\n", + " B = jax.random.normal(key, (size, size))\n", + " \n", + " start = time.perf_counter()\n", + " result = compute(A, B).block_until_ready()\n", + " elapsed = time.perf_counter() - start\n", + " print(f\" Size {size}x{size}: {elapsed:.3f} seconds\")\n", + "\n", + "total_time = time.perf_counter() - total_start\n", + "print(f\"\\nTotal time for all operations: {total_time:.3f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"JUPYTER KERNEL EXECUTION BENCHMARK COMPLETE\")\n", + "print(\"=\" * 60)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 2bda11475479d04d16554ee4bea152a4dc89c97f Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 16:53:26 +1100 Subject: [PATCH 07/19] Add JSON output to benchmarks and upload as artifacts - Update benchmark-hardware.py to save results to JSON - Update benchmark-jupyter.ipynb to save results to JSON - Update benchmark-jupyterbook.md to save results to JSON - Add CI step to collect and display benchmark results - Add CI step to upload benchmark results as artifact --- .github/workflows/ci.yml | 27 ++++++++++++- scripts/benchmark-hardware.py | 65 +++++++++++++++++++++++++++++++- scripts/benchmark-jupyter.ipynb | 58 ++++++++++++++++++++++++---- scripts/benchmark-jupyterbook.md | 44 +++++++++++++++++++++ 4 files changed, 184 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index da7f4f8a..07f58b2b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,12 +31,17 @@ jobs: run: | echo "=== Bare Metal Python Script Execution ===" python scripts/benchmark-hardware.py + mkdir -p benchmark_results + mv benchmark_results_baremetal.json benchmark_results/ - name: Run Jupyter Notebook Benchmark (via nbconvert) shell: bash -l {0} run: | echo "=== Jupyter Kernel Execution ===" - jupyter nbconvert --to notebook --execute scripts/benchmark-jupyter.ipynb --output benchmark-jupyter-executed.ipynb + cd scripts + jupyter nbconvert --to notebook --execute benchmark-jupyter.ipynb --output benchmark-jupyter-executed.ipynb echo "Notebook executed successfully" + cd .. 
+ mv scripts/benchmark_results_jupyter.json benchmark_results/ - name: Run Jupyter-Book Benchmark shell: bash -l {0} run: | @@ -54,6 +59,26 @@ jobs: # Build jb build benchmark_test --path-output benchmark_build/ echo "Jupyter-Book build completed successfully" + # Move JSON results if generated + if [ -f benchmark_test/benchmark_results_jupyterbook.json ]; then + mv benchmark_test/benchmark_results_jupyterbook.json benchmark_results/ + elif [ -f benchmark_results_jupyterbook.json ]; then + mv benchmark_results_jupyterbook.json benchmark_results/ + fi + - name: Collect and Upload Benchmark Results + uses: actions/upload-artifact@v5 + with: + name: benchmark-results + path: benchmark_results/ + - name: Display Benchmark Results + shell: bash -l {0} + run: | + echo "=== Benchmark Results Summary ===" + for f in benchmark_results/*.json; do + echo "--- $f ---" + cat "$f" + echo "" + done - name: Install latex dependencies run: | sudo apt-get -qq update diff --git a/scripts/benchmark-hardware.py b/scripts/benchmark-hardware.py index 45a1604c..7eb13bb3 100644 --- a/scripts/benchmark-hardware.py +++ b/scripts/benchmark-hardware.py @@ -1,11 +1,21 @@ """ -Hardware benchmark script for CI runners. +"""Hardware benchmark script for CI runners. Compares CPU and GPU performance to diagnose slowdowns. Works on both CPU-only (GitHub Actions) and GPU (RunsOn) runners. """ import time import platform import os +import json +from datetime import datetime + +# Global results dictionary for JSON output +RESULTS = { + "timestamp": datetime.now().isoformat(), + "execution_method": "bare_metal", + "system": {}, + "benchmarks": {} +} def get_cpu_info(): """Get CPU information.""" @@ -16,6 +26,10 @@ def get_cpu_info(): print(f"Processor: {platform.processor()}") print(f"Python: {platform.python_version()}") + RESULTS["system"]["platform"] = platform.platform() + RESULTS["system"]["processor"] = platform.processor() + RESULTS["system"]["python_version"] = platform.python_version() + # Try to get CPU frequency try: with open('/proc/cpuinfo', 'r') as f: @@ -38,6 +52,7 @@ def get_cpu_info(): # CPU count print(f"CPU Count: {os.cpu_count()}") + RESULTS["system"]["cpu_count"] = os.cpu_count() # Check for GPU try: @@ -45,11 +60,15 @@ def get_cpu_info(): result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'], capture_output=True, text=True, timeout=5) if result.returncode == 0: - print(f"GPU: {result.stdout.strip()}") + gpu_info = result.stdout.strip() + print(f"GPU: {gpu_info}") + RESULTS["system"]["gpu"] = gpu_info else: print("GPU: None detected") + RESULTS["system"]["gpu"] = None except: print("GPU: None detected (nvidia-smi not available)") + RESULTS["system"]["gpu"] = None print() @@ -59,11 +78,14 @@ def benchmark_cpu_pure_python(): print("CPU BENCHMARK: Pure Python") print("=" * 60) + results = {} + # Integer computation start = time.perf_counter() total = sum(i * i for i in range(10_000_000)) elapsed = time.perf_counter() - start print(f"Integer sum (10M iterations): {elapsed:.3f} seconds") + results["integer_sum_10M"] = elapsed # Float computation start = time.perf_counter() @@ -72,6 +94,9 @@ def benchmark_cpu_pure_python(): total += (i * 0.1) ** 0.5 elapsed = time.perf_counter() - start print(f"Float sqrt (1M iterations): {elapsed:.3f} seconds") + results["float_sqrt_1M"] = elapsed + + RESULTS["benchmarks"]["pure_python"] = results print() def benchmark_cpu_numpy(): @@ -82,6 +107,8 @@ def benchmark_cpu_numpy(): print("CPU BENCHMARK: NumPy") print("=" * 60) + results = {} 
+ # Matrix multiplication n = 3000 A = np.random.randn(n, n) @@ -91,6 +118,7 @@ def benchmark_cpu_numpy(): C = A @ B elapsed = time.perf_counter() - start print(f"Matrix multiply ({n}x{n}): {elapsed:.3f} seconds") + results["matmul_3000x3000"] = elapsed # Element-wise operations x = np.random.randn(50_000_000) @@ -99,6 +127,9 @@ def benchmark_cpu_numpy(): y = np.cos(x**2) + np.sin(x) elapsed = time.perf_counter() - start print(f"Element-wise ops (50M elements): {elapsed:.3f} seconds") + results["elementwise_50M"] = elapsed + + RESULTS["benchmarks"]["numpy"] = results print() def benchmark_gpu_jax(): @@ -107,12 +138,18 @@ def benchmark_gpu_jax(): import jax import jax.numpy as jnp + results = {} + devices = jax.devices() default_backend = jax.default_backend() # Check if GPU is available has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices) + results["has_gpu"] = has_gpu + results["default_backend"] = default_backend + results["devices"] = [str(d) for d in devices] + print("=" * 60) if has_gpu: print("JAX BENCHMARK: GPU") @@ -141,12 +178,14 @@ def matmul(a, b): C = matmul(A, B).block_until_ready() warmup_time = time.perf_counter() - start print(f"Warm-up (includes JIT compile, {n}x{n}): {warmup_time:.3f} seconds") + results["matmul_1000x1000_warmup"] = warmup_time # Actual benchmark (compiled) start = time.perf_counter() C = matmul(A, B).block_until_ready() elapsed = time.perf_counter() - start print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds") + results["matmul_1000x1000_compiled"] = elapsed # Larger matrix n = 3000 @@ -158,12 +197,14 @@ def matmul(a, b): C = matmul(A, B).block_until_ready() warmup_time = time.perf_counter() - start print(f"Warm-up (recompile for {n}x{n}): {warmup_time:.3f} seconds") + results["matmul_3000x3000_warmup"] = warmup_time # Benchmark compiled start = time.perf_counter() C = matmul(A, B).block_until_ready() elapsed = time.perf_counter() - start print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds") + results["matmul_3000x3000_compiled"] = elapsed) # Element-wise GPU benchmark x = jax.random.normal(key, (50_000_000,)) @@ -177,19 +218,24 @@ def elementwise_ops(x): y = elementwise_ops(x).block_until_ready() warmup_time = time.perf_counter() - start print(f"Element-wise warm-up (50M): {warmup_time:.3f} seconds") + results["elementwise_50M_warmup"] = warmup_time # Compiled start = time.perf_counter() y = elementwise_ops(x).block_until_ready() elapsed = time.perf_counter() - start print(f"Element-wise compiled (50M): {elapsed:.3f} seconds") + results["elementwise_50M_compiled"] = elapsed + RESULTS["benchmarks"]["jax"] = results print() except ImportError as e: print(f"JAX not available: {e}") + RESULTS["benchmarks"]["jax"] = {"error": str(e)} except Exception as e: print(f"JAX benchmark failed: {e}") + RESULTS["benchmarks"]["jax"] = {"error": str(e)}) def benchmark_numba(): """Numba CPU benchmark.""" @@ -197,6 +243,8 @@ def benchmark_numba(): import numba import numpy as np + results = {} + print("=" * 60) print("CPU BENCHMARK: Numba") print("=" * 60) @@ -213,12 +261,14 @@ def numba_sum(n): result = numba_sum(10_000_000) warmup_time = time.perf_counter() - start print(f"Integer sum warm-up (includes compile): {warmup_time:.3f} seconds") + results["integer_sum_10M_warmup"] = warmup_time # Compiled run start = time.perf_counter() result = numba_sum(10_000_000) elapsed = time.perf_counter() - start print(f"Integer sum compiled (10M): {elapsed:.3f} seconds") + results["integer_sum_10M_compiled"] = elapsed) 
@numba.jit(nopython=True, parallel=True) def numba_parallel_sum(arr): @@ -234,19 +284,24 @@ def numba_parallel_sum(arr): result = numba_parallel_sum(arr) warmup_time = time.perf_counter() - start print(f"Parallel sum warm-up (50M): {warmup_time:.3f} seconds") + results["parallel_sum_50M_warmup"] = warmup_time # Compiled start = time.perf_counter() result = numba_parallel_sum(arr) elapsed = time.perf_counter() - start print(f"Parallel sum compiled (50M): {elapsed:.3f} seconds") + results["parallel_sum_50M_compiled"] = elapsed + RESULTS["benchmarks"]["numba"] = results print() except ImportError as e: print(f"Numba not available: {e}") + RESULTS["benchmarks"]["numba"] = {"error": str(e)} except Exception as e: print(f"Numba benchmark failed: {e}") + RESULTS["benchmarks"]["numba"] = {"error": str(e)}) if __name__ == "__main__": print("\n" + "=" * 60) @@ -262,3 +317,9 @@ def numba_parallel_sum(arr): print("=" * 60) print("BENCHMARK COMPLETE") print("=" * 60) + + # Save results to JSON + output_file = "benchmark_results_baremetal.json" + with open(output_file, 'w') as f: + json.dump(RESULTS, f, indent=2) + print(f"\nResults saved to {output_file}") diff --git a/scripts/benchmark-jupyter.ipynb b/scripts/benchmark-jupyter.ipynb index e095f79b..3804af46 100644 --- a/scripts/benchmark-jupyter.ipynb +++ b/scripts/benchmark-jupyter.ipynb @@ -19,13 +19,27 @@ "import time\n", "import platform\n", "import os\n", + "import json\n", + "from datetime import datetime\n", + "\n", + "# Global results dictionary for JSON output\n", + "RESULTS = {\n", + " \"timestamp\": datetime.now().isoformat(),\n", + " \"execution_method\": \"jupyter_kernel\",\n", + " \"system\": {},\n", + " \"benchmarks\": {}\n", + "}\n", "\n", "print(\"=\" * 60)\n", "print(\"JUPYTER KERNEL EXECUTION BENCHMARK\")\n", "print(\"=\" * 60)\n", "print(f\"Platform: {platform.platform()}\")\n", "print(f\"Python: {platform.python_version()}\")\n", - "print(f\"CPU Count: {os.cpu_count()}\")" + "print(f\"CPU Count: {os.cpu_count()}\")\n", + "\n", + "RESULTS[\"system\"][\"platform\"] = platform.platform()\n", + "RESULTS[\"system\"][\"python_version\"] = platform.python_version()\n", + "RESULTS[\"system\"][\"cpu_count\"] = os.cpu_count()" ] }, { @@ -44,7 +58,11 @@ "\n", "print(f\"JAX devices: {devices}\")\n", "print(f\"Default backend: {default_backend}\")\n", - "print(f\"GPU Available: {has_gpu}\")" + "print(f\"GPU Available: {has_gpu}\")\n", + "\n", + "RESULTS[\"system\"][\"jax_devices\"] = [str(d) for d in devices]\n", + "RESULTS[\"system\"][\"jax_backend\"] = default_backend\n", + "RESULTS[\"system\"][\"has_gpu\"] = has_gpu" ] }, { @@ -87,7 +105,12 @@ "start = time.perf_counter()\n", "C = matmul(A, B).block_until_ready()\n", "compiled_time = time.perf_counter() - start\n", - "print(f\"Compiled execution: {compiled_time:.3f} seconds\")" + "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", + "\n", + "RESULTS[\"benchmarks\"][\"matmul_1000x1000\"] = {\n", + " \"warmup\": warmup_time,\n", + " \"compiled\": compiled_time\n", + "}" ] }, { @@ -115,7 +138,12 @@ "start = time.perf_counter()\n", "C = matmul(A, B).block_until_ready()\n", "compiled_time = time.perf_counter() - start\n", - "print(f\"Compiled execution: {compiled_time:.3f} seconds\")" + "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", + "\n", + "RESULTS[\"benchmarks\"][\"matmul_3000x3000\"] = {\n", + " \"warmup\": warmup_time,\n", + " \"compiled\": compiled_time\n", + "}" ] }, { @@ -145,7 +173,12 @@ "start = time.perf_counter()\n", "y = 
elementwise_ops(x).block_until_ready()\n", "compiled_time = time.perf_counter() - start\n", - "print(f\"Compiled execution: {compiled_time:.3f} seconds\")" + "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", + "\n", + "RESULTS[\"benchmarks\"][\"elementwise_50M\"] = {\n", + " \"warmup\": warmup_time,\n", + " \"compiled\": compiled_time\n", + "}" ] }, { @@ -160,6 +193,7 @@ "print(\"=\" * 60)\n", "\n", "total_start = time.perf_counter()\n", + "multi_op_results = {}\n", "\n", "# Simulate multiple cell executions with different operations\n", "for i, size in enumerate([100, 500, 1000, 2000, 3000]):\n", @@ -174,9 +208,13 @@ " result = compute(A, B).block_until_ready()\n", " elapsed = time.perf_counter() - start\n", " print(f\" Size {size}x{size}: {elapsed:.3f} seconds\")\n", + " multi_op_results[f\"size_{size}x{size}\"] = elapsed\n", "\n", "total_time = time.perf_counter() - total_start\n", - "print(f\"\\nTotal time for all operations: {total_time:.3f} seconds\")" + "print(f\"\\nTotal time for all operations: {total_time:.3f} seconds\")\n", + "\n", + "multi_op_results[\"total_time\"] = total_time\n", + "RESULTS[\"benchmarks\"][\"multi_operations\"] = multi_op_results" ] }, { @@ -185,9 +223,15 @@ "metadata": {}, "outputs": [], "source": [ + "# Save results to JSON file\n", + "output_file = \"benchmark_results_jupyter.json\"\n", + "with open(output_file, 'w') as f:\n", + " json.dump(RESULTS, f, indent=2)\n", + "\n", "print(\"\\n\" + \"=\" * 60)\n", "print(\"JUPYTER KERNEL EXECUTION BENCHMARK COMPLETE\")\n", - "print(\"=\" * 60)" + "print(\"=\" * 60)\n", + "print(f\"\\nResults saved to {output_file}\")" ] } ], diff --git a/scripts/benchmark-jupyterbook.md b/scripts/benchmark-jupyterbook.md index 23434e9e..827c1beb 100644 --- a/scripts/benchmark-jupyterbook.md +++ b/scripts/benchmark-jupyterbook.md @@ -19,6 +19,16 @@ Compare results with direct script and nbconvert execution. 
import time import platform import os +import json +from datetime import datetime + +# Global results dictionary for JSON output +RESULTS = { + "timestamp": datetime.now().isoformat(), + "execution_method": "jupyter_book", + "system": {}, + "benchmarks": {} +} print("=" * 60) print("JUPYTER BOOK EXECUTION BENCHMARK") @@ -26,6 +36,10 @@ print("=" * 60) print(f"Platform: {platform.platform()}") print(f"Python: {platform.python_version()}") print(f"CPU Count: {os.cpu_count()}") + +RESULTS["system"]["platform"] = platform.platform() +RESULTS["system"]["python_version"] = platform.python_version() +RESULTS["system"]["cpu_count"] = os.cpu_count() ``` ```{code-cell} ipython3 @@ -40,6 +54,10 @@ has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devic print(f"JAX devices: {devices}") print(f"Default backend: {default_backend}") print(f"GPU Available: {has_gpu}") + +RESULTS["system"]["jax_devices"] = [str(d) for d in devices] +RESULTS["system"]["jax_backend"] = default_backend +RESULTS["system"]["has_gpu"] = has_gpu ``` ```{code-cell} ipython3 @@ -73,6 +91,11 @@ start = time.perf_counter() C = matmul(A, B).block_until_ready() compiled_time = time.perf_counter() - start print(f"Compiled execution: {compiled_time:.3f} seconds") + +RESULTS["benchmarks"]["matmul_1000x1000"] = { + "warmup": warmup_time, + "compiled": compiled_time +} ``` ```{code-cell} ipython3 @@ -96,6 +119,11 @@ start = time.perf_counter() C = matmul(A, B).block_until_ready() compiled_time = time.perf_counter() - start print(f"Compiled execution: {compiled_time:.3f} seconds") + +RESULTS["benchmarks"]["matmul_3000x3000"] = { + "warmup": warmup_time, + "compiled": compiled_time +} ``` ```{code-cell} ipython3 @@ -121,6 +149,11 @@ start = time.perf_counter() y = elementwise_ops(x).block_until_ready() compiled_time = time.perf_counter() - start print(f"Compiled execution: {compiled_time:.3f} seconds") + +RESULTS["benchmarks"]["elementwise_50M"] = { + "warmup": warmup_time, + "compiled": compiled_time +} ``` ```{code-cell} ipython3 @@ -130,6 +163,7 @@ print("BENCHMARK 4: Multiple Small Operations (lecture simulation)") print("=" * 60) total_start = time.perf_counter() +multi_op_results = {} # Simulate multiple cell executions with different operations for i, size in enumerate([100, 500, 1000, 2000, 3000]): @@ -144,13 +178,23 @@ for i, size in enumerate([100, 500, 1000, 2000, 3000]): result = compute(A, B).block_until_ready() elapsed = time.perf_counter() - start print(f" Size {size}x{size}: {elapsed:.3f} seconds") + multi_op_results[f"size_{size}x{size}"] = elapsed total_time = time.perf_counter() - total_start print(f"\nTotal time for all operations: {total_time:.3f} seconds") + +multi_op_results["total_time"] = total_time +RESULTS["benchmarks"]["multi_operations"] = multi_op_results ``` ```{code-cell} ipython3 +# Save results to JSON file +output_file = "benchmark_results_jupyterbook.json" +with open(output_file, 'w') as f: + json.dump(RESULTS, f, indent=2) + print("\n" + "=" * 60) print("JUPYTER BOOK EXECUTION BENCHMARK COMPLETE") print("=" * 60) +print(f"\nResults saved to {output_file}") ``` From 922b24c2ef512e17ba0d6842e4facc11e82fdf1c Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 17:00:53 +1100 Subject: [PATCH 08/19] Fix syntax errors in benchmark-hardware.py - Remove extra triple quote at start of file - Remove stray parentheses in result assignments --- scripts/benchmark-hardware.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/scripts/benchmark-hardware.py 
b/scripts/benchmark-hardware.py index 7eb13bb3..4e380641 100644 --- a/scripts/benchmark-hardware.py +++ b/scripts/benchmark-hardware.py @@ -1,4 +1,3 @@ -""" """Hardware benchmark script for CI runners. Compares CPU and GPU performance to diagnose slowdowns. Works on both CPU-only (GitHub Actions) and GPU (RunsOn) runners. @@ -204,7 +203,7 @@ def matmul(a, b): C = matmul(A, B).block_until_ready() elapsed = time.perf_counter() - start print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds") - results["matmul_3000x3000_compiled"] = elapsed) + results["matmul_3000x3000_compiled"] = elapsed # Element-wise GPU benchmark x = jax.random.normal(key, (50_000_000,)) @@ -235,7 +234,7 @@ def elementwise_ops(x): RESULTS["benchmarks"]["jax"] = {"error": str(e)} except Exception as e: print(f"JAX benchmark failed: {e}") - RESULTS["benchmarks"]["jax"] = {"error": str(e)}) + RESULTS["benchmarks"]["jax"] = {"error": str(e)} def benchmark_numba(): """Numba CPU benchmark.""" @@ -268,7 +267,7 @@ def numba_sum(n): result = numba_sum(10_000_000) elapsed = time.perf_counter() - start print(f"Integer sum compiled (10M): {elapsed:.3f} seconds") - results["integer_sum_10M_compiled"] = elapsed) + results["integer_sum_10M_compiled"] = elapsed @numba.jit(nopython=True, parallel=True) def numba_parallel_sum(arr): @@ -301,7 +300,7 @@ def numba_parallel_sum(arr): RESULTS["benchmarks"]["numba"] = {"error": str(e)} except Exception as e: print(f"Numba benchmark failed: {e}") - RESULTS["benchmarks"]["numba"] = {"error": str(e)}) + RESULTS["benchmarks"]["numba"] = {"error": str(e)} if __name__ == "__main__": print("\n" + "=" * 60) From 10627efea44b46faccb07ba6037720e731ffc09c Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 17:04:00 +1100 Subject: [PATCH 09/19] Sync benchmark scripts with CPU branch for comparable results - Copy benchmark-hardware.py from debug/benchmark-github-actions - Copy benchmark-jupyter.ipynb from debug/benchmark-github-actions - Copy benchmark-jupyterbook.md from debug/benchmark-github-actions - Update ci.yml to use matching file names The test scripts are now identical between both branches, only the CI workflow differs (runner type and JAX installation). --- .github/workflows/ci.yml | 27 +++++----- scripts/benchmark-hardware.py | 88 ++++++++++++++++++-------------- scripts/benchmark-jupyter.ipynb | 54 +++++++++----------- scripts/benchmark-jupyterbook.md | 50 +++++++++--------- 4 files changed, 111 insertions(+), 108 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 07f58b2b..1e7f59f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: echo "=== Bare Metal Python Script Execution ===" python scripts/benchmark-hardware.py mkdir -p benchmark_results - mv benchmark_results_baremetal.json benchmark_results/ + mv benchmark_results_bare_metal.json benchmark_results/ - name: Run Jupyter Notebook Benchmark (via nbconvert) shell: bash -l {0} run: | @@ -56,21 +56,14 @@ jobs: # Create minimal _toc.yml echo "format: jb-book" > benchmark_test/_toc.yml echo "root: benchmark-jupyterbook" >> benchmark_test/_toc.yml - # Build - jb build benchmark_test --path-output benchmark_build/ + # Build (run from benchmark_test so JSON is written there) + cd benchmark_test + jb build . --path-output ../benchmark_build/ + cd .. 
echo "Jupyter-Book build completed successfully" # Move JSON results if generated - if [ -f benchmark_test/benchmark_results_jupyterbook.json ]; then - mv benchmark_test/benchmark_results_jupyterbook.json benchmark_results/ - elif [ -f benchmark_results_jupyterbook.json ]; then - mv benchmark_results_jupyterbook.json benchmark_results/ - fi - - name: Collect and Upload Benchmark Results - uses: actions/upload-artifact@v5 - with: - name: benchmark-results - path: benchmark_results/ - - name: Display Benchmark Results + cp benchmark_test/benchmark_results_jupyterbook.json benchmark_results/ 2>/dev/null || echo "No jupyterbook results" + - name: Collect and Display Benchmark Results shell: bash -l {0} run: | echo "=== Benchmark Results Summary ===" @@ -79,6 +72,12 @@ jobs: cat "$f" echo "" done + - name: Upload Benchmark Results + uses: actions/upload-artifact@v5 + with: + name: benchmark-results + path: benchmark_results/ + if-no-files-found: warn - name: Install latex dependencies run: | sudo apt-get -qq update diff --git a/scripts/benchmark-hardware.py b/scripts/benchmark-hardware.py index 4e380641..12443855 100644 --- a/scripts/benchmark-hardware.py +++ b/scripts/benchmark-hardware.py @@ -1,4 +1,5 @@ -"""Hardware benchmark script for CI runners. +""" +Hardware benchmark script for CI runners. Compares CPU and GPU performance to diagnose slowdowns. Works on both CPU-only (GitHub Actions) and GPU (RunsOn) runners. """ @@ -8,10 +9,10 @@ import json from datetime import datetime -# Global results dictionary for JSON output +# Global results dictionary RESULTS = { + "pathway": "bare_metal", "timestamp": datetime.now().isoformat(), - "execution_method": "bare_metal", "system": {}, "benchmarks": {} } @@ -27,14 +28,18 @@ def get_cpu_info(): RESULTS["system"]["platform"] = platform.platform() RESULTS["system"]["processor"] = platform.processor() - RESULTS["system"]["python_version"] = platform.python_version() + RESULTS["system"]["python"] = platform.python_version() + RESULTS["system"]["cpu_count"] = os.cpu_count() - # Try to get CPU frequency + # Try to get CPU model + cpu_model = None + cpu_mhz = None try: with open('/proc/cpuinfo', 'r') as f: for line in f: if 'model name' in line: - print(f"CPU Model: {line.split(':')[1].strip()}") + cpu_model = line.split(':')[1].strip() + print(f"CPU Model: {cpu_model}") break except: pass @@ -44,16 +49,20 @@ def get_cpu_info(): with open('/proc/cpuinfo', 'r') as f: for line in f: if 'cpu MHz' in line: - print(f"CPU MHz: {line.split(':')[1].strip()}") + cpu_mhz = line.split(':')[1].strip() + print(f"CPU MHz: {cpu_mhz}") break except: pass + RESULTS["system"]["cpu_model"] = cpu_model + RESULTS["system"]["cpu_mhz"] = cpu_mhz + # CPU count print(f"CPU Count: {os.cpu_count()}") - RESULTS["system"]["cpu_count"] = os.cpu_count() # Check for GPU + gpu_info = None try: import subprocess result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'], @@ -61,14 +70,12 @@ def get_cpu_info(): if result.returncode == 0: gpu_info = result.stdout.strip() print(f"GPU: {gpu_info}") - RESULTS["system"]["gpu"] = gpu_info else: print("GPU: None detected") - RESULTS["system"]["gpu"] = None except: print("GPU: None detected (nvidia-smi not available)") - RESULTS["system"]["gpu"] = None + RESULTS["system"]["gpu"] = gpu_info print() def benchmark_cpu_pure_python(): @@ -84,7 +91,7 @@ def benchmark_cpu_pure_python(): total = sum(i * i for i in range(10_000_000)) elapsed = time.perf_counter() - start print(f"Integer sum (10M iterations): 
{elapsed:.3f} seconds") - results["integer_sum_10M"] = elapsed + results["integer_sum_10m"] = elapsed # Float computation start = time.perf_counter() @@ -93,10 +100,10 @@ def benchmark_cpu_pure_python(): total += (i * 0.1) ** 0.5 elapsed = time.perf_counter() - start print(f"Float sqrt (1M iterations): {elapsed:.3f} seconds") - results["float_sqrt_1M"] = elapsed + results["float_sqrt_1m"] = elapsed + print() RESULTS["benchmarks"]["pure_python"] = results - print() def benchmark_cpu_numpy(): """NumPy CPU benchmark.""" @@ -126,10 +133,10 @@ def benchmark_cpu_numpy(): y = np.cos(x**2) + np.sin(x) elapsed = time.perf_counter() - start print(f"Element-wise ops (50M elements): {elapsed:.3f} seconds") - results["elementwise_50M"] = elapsed + results["elementwise_50m"] = elapsed + print() RESULTS["benchmarks"]["numpy"] = results - print() def benchmark_gpu_jax(): """JAX benchmark (GPU if available, otherwise CPU).""" @@ -137,18 +144,12 @@ def benchmark_gpu_jax(): import jax import jax.numpy as jnp - results = {} - devices = jax.devices() default_backend = jax.default_backend() # Check if GPU is available has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices) - results["has_gpu"] = has_gpu - results["default_backend"] = default_backend - results["devices"] = [str(d) for d in devices] - print("=" * 60) if has_gpu: print("JAX BENCHMARK: GPU") @@ -161,6 +162,12 @@ def benchmark_gpu_jax(): print(f"GPU Available: {has_gpu}") print() + results = { + "backend": default_backend, + "has_gpu": has_gpu, + "devices": str(devices) + } + # Warm-up JIT compilation print("Warming up JIT compilation...") n = 1000 @@ -217,17 +224,17 @@ def elementwise_ops(x): y = elementwise_ops(x).block_until_ready() warmup_time = time.perf_counter() - start print(f"Element-wise warm-up (50M): {warmup_time:.3f} seconds") - results["elementwise_50M_warmup"] = warmup_time + results["elementwise_50m_warmup"] = warmup_time # Compiled start = time.perf_counter() y = elementwise_ops(x).block_until_ready() elapsed = time.perf_counter() - start print(f"Element-wise compiled (50M): {elapsed:.3f} seconds") - results["elementwise_50M_compiled"] = elapsed + results["elementwise_50m_compiled"] = elapsed - RESULTS["benchmarks"]["jax"] = results print() + RESULTS["benchmarks"]["jax"] = results except ImportError as e: print(f"JAX not available: {e}") @@ -242,12 +249,12 @@ def benchmark_numba(): import numba import numpy as np - results = {} - print("=" * 60) print("CPU BENCHMARK: Numba") print("=" * 60) + results = {} + @numba.jit(nopython=True) def numba_sum(n): total = 0 @@ -260,14 +267,14 @@ def numba_sum(n): result = numba_sum(10_000_000) warmup_time = time.perf_counter() - start print(f"Integer sum warm-up (includes compile): {warmup_time:.3f} seconds") - results["integer_sum_10M_warmup"] = warmup_time + results["integer_sum_10m_warmup"] = warmup_time # Compiled run start = time.perf_counter() result = numba_sum(10_000_000) elapsed = time.perf_counter() - start print(f"Integer sum compiled (10M): {elapsed:.3f} seconds") - results["integer_sum_10M_compiled"] = elapsed + results["integer_sum_10m_compiled"] = elapsed @numba.jit(nopython=True, parallel=True) def numba_parallel_sum(arr): @@ -283,17 +290,17 @@ def numba_parallel_sum(arr): result = numba_parallel_sum(arr) warmup_time = time.perf_counter() - start print(f"Parallel sum warm-up (50M): {warmup_time:.3f} seconds") - results["parallel_sum_50M_warmup"] = warmup_time + results["parallel_sum_50m_warmup"] = warmup_time # Compiled start = time.perf_counter() 
result = numba_parallel_sum(arr) elapsed = time.perf_counter() - start print(f"Parallel sum compiled (50M): {elapsed:.3f} seconds") - results["parallel_sum_50M_compiled"] = elapsed + results["parallel_sum_50m_compiled"] = elapsed - RESULTS["benchmarks"]["numba"] = results print() + RESULTS["benchmarks"]["numba"] = results except ImportError as e: print(f"Numba not available: {e}") @@ -302,6 +309,14 @@ def numba_parallel_sum(arr): print(f"Numba benchmark failed: {e}") RESULTS["benchmarks"]["numba"] = {"error": str(e)} + +def save_results(output_path="benchmark_results_bare_metal.json"): + """Save benchmark results to JSON file.""" + with open(output_path, 'w') as f: + json.dump(RESULTS, f, indent=2) + print(f"\nResults saved to: {output_path}") + + if __name__ == "__main__": print("\n" + "=" * 60) print("HARDWARE BENCHMARK FOR CI RUNNER") @@ -313,12 +328,9 @@ def numba_parallel_sum(arr): benchmark_numba() benchmark_gpu_jax() + # Save results to JSON + save_results("benchmark_results_bare_metal.json") + print("=" * 60) print("BENCHMARK COMPLETE") print("=" * 60) - - # Save results to JSON - output_file = "benchmark_results_baremetal.json" - with open(output_file, 'w') as f: - json.dump(RESULTS, f, indent=2) - print(f"\nResults saved to {output_file}") diff --git a/scripts/benchmark-jupyter.ipynb b/scripts/benchmark-jupyter.ipynb index 3804af46..909b8fe5 100644 --- a/scripts/benchmark-jupyter.ipynb +++ b/scripts/benchmark-jupyter.ipynb @@ -22,11 +22,15 @@ "import json\n", "from datetime import datetime\n", "\n", - "# Global results dictionary for JSON output\n", + "# Initialize results dictionary\n", "RESULTS = {\n", + " \"pathway\": \"jupyter_kernel\",\n", " \"timestamp\": datetime.now().isoformat(),\n", - " \"execution_method\": \"jupyter_kernel\",\n", - " \"system\": {},\n", + " \"system\": {\n", + " \"platform\": platform.platform(),\n", + " \"python\": platform.python_version(),\n", + " \"cpu_count\": os.cpu_count()\n", + " },\n", " \"benchmarks\": {}\n", "}\n", "\n", @@ -35,11 +39,7 @@ "print(\"=\" * 60)\n", "print(f\"Platform: {platform.platform()}\")\n", "print(f\"Python: {platform.python_version()}\")\n", - "print(f\"CPU Count: {os.cpu_count()}\")\n", - "\n", - "RESULTS[\"system\"][\"platform\"] = platform.platform()\n", - "RESULTS[\"system\"][\"python_version\"] = platform.python_version()\n", - "RESULTS[\"system\"][\"cpu_count\"] = os.cpu_count()" + "print(f\"CPU Count: {os.cpu_count()}\")" ] }, { @@ -60,9 +60,9 @@ "print(f\"Default backend: {default_backend}\")\n", "print(f\"GPU Available: {has_gpu}\")\n", "\n", - "RESULTS[\"system\"][\"jax_devices\"] = [str(d) for d in devices]\n", "RESULTS[\"system\"][\"jax_backend\"] = default_backend\n", - "RESULTS[\"system\"][\"has_gpu\"] = has_gpu" + "RESULTS[\"system\"][\"has_gpu\"] = has_gpu\n", + "RESULTS[\"system\"][\"jax_devices\"] = str(devices)" ] }, { @@ -107,10 +107,8 @@ "compiled_time = time.perf_counter() - start\n", "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", "\n", - "RESULTS[\"benchmarks\"][\"matmul_1000x1000\"] = {\n", - " \"warmup\": warmup_time,\n", - " \"compiled\": compiled_time\n", - "}" + "RESULTS[\"benchmarks\"][\"matmul_1000x1000_warmup\"] = warmup_time\n", + "RESULTS[\"benchmarks\"][\"matmul_1000x1000_compiled\"] = compiled_time" ] }, { @@ -140,10 +138,8 @@ "compiled_time = time.perf_counter() - start\n", "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", "\n", - "RESULTS[\"benchmarks\"][\"matmul_3000x3000\"] = {\n", - " \"warmup\": warmup_time,\n", - " \"compiled\": compiled_time\n", 
- "}" + "RESULTS[\"benchmarks\"][\"matmul_3000x3000_warmup\"] = warmup_time\n", + "RESULTS[\"benchmarks\"][\"matmul_3000x3000_compiled\"] = compiled_time" ] }, { @@ -175,10 +171,8 @@ "compiled_time = time.perf_counter() - start\n", "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", "\n", - "RESULTS[\"benchmarks\"][\"elementwise_50M\"] = {\n", - " \"warmup\": warmup_time,\n", - " \"compiled\": compiled_time\n", - "}" + "RESULTS[\"benchmarks\"][\"elementwise_50m_warmup\"] = warmup_time\n", + "RESULTS[\"benchmarks\"][\"elementwise_50m_compiled\"] = compiled_time" ] }, { @@ -193,7 +187,7 @@ "print(\"=\" * 60)\n", "\n", "total_start = time.perf_counter()\n", - "multi_op_results = {}\n", + "multi_results = {}\n", "\n", "# Simulate multiple cell executions with different operations\n", "for i, size in enumerate([100, 500, 1000, 2000, 3000]):\n", @@ -208,13 +202,13 @@ " result = compute(A, B).block_until_ready()\n", " elapsed = time.perf_counter() - start\n", " print(f\" Size {size}x{size}: {elapsed:.3f} seconds\")\n", - " multi_op_results[f\"size_{size}x{size}\"] = elapsed\n", + " multi_results[f\"size_{size}x{size}\"] = elapsed\n", "\n", "total_time = time.perf_counter() - total_start\n", "print(f\"\\nTotal time for all operations: {total_time:.3f} seconds\")\n", "\n", - "multi_op_results[\"total_time\"] = total_time\n", - "RESULTS[\"benchmarks\"][\"multi_operations\"] = multi_op_results" + "RESULTS[\"benchmarks\"][\"multi_ops\"] = multi_results\n", + "RESULTS[\"benchmarks\"][\"multi_ops_total\"] = total_time" ] }, { @@ -224,14 +218,16 @@ "outputs": [], "source": [ "# Save results to JSON file\n", - "output_file = \"benchmark_results_jupyter.json\"\n", - "with open(output_file, 'w') as f:\n", + "output_path = \"benchmark_results_jupyter.json\"\n", + "with open(output_path, 'w') as f:\n", " json.dump(RESULTS, f, indent=2)\n", "\n", "print(\"\\n\" + \"=\" * 60)\n", "print(\"JUPYTER KERNEL EXECUTION BENCHMARK COMPLETE\")\n", "print(\"=\" * 60)\n", - "print(f\"\\nResults saved to {output_file}\")" + "print(f\"\\nResults saved to: {output_path}\")\n", + "print(\"\\nJSON Results:\")\n", + "print(json.dumps(RESULTS, indent=2))" ] } ], diff --git a/scripts/benchmark-jupyterbook.md b/scripts/benchmark-jupyterbook.md index 827c1beb..162613c8 100644 --- a/scripts/benchmark-jupyterbook.md +++ b/scripts/benchmark-jupyterbook.md @@ -22,11 +22,15 @@ import os import json from datetime import datetime -# Global results dictionary for JSON output +# Initialize results dictionary RESULTS = { + "pathway": "jupyter_book", "timestamp": datetime.now().isoformat(), - "execution_method": "jupyter_book", - "system": {}, + "system": { + "platform": platform.platform(), + "python": platform.python_version(), + "cpu_count": os.cpu_count() + }, "benchmarks": {} } @@ -36,10 +40,6 @@ print("=" * 60) print(f"Platform: {platform.platform()}") print(f"Python: {platform.python_version()}") print(f"CPU Count: {os.cpu_count()}") - -RESULTS["system"]["platform"] = platform.platform() -RESULTS["system"]["python_version"] = platform.python_version() -RESULTS["system"]["cpu_count"] = os.cpu_count() ``` ```{code-cell} ipython3 @@ -55,9 +55,9 @@ print(f"JAX devices: {devices}") print(f"Default backend: {default_backend}") print(f"GPU Available: {has_gpu}") -RESULTS["system"]["jax_devices"] = [str(d) for d in devices] RESULTS["system"]["jax_backend"] = default_backend RESULTS["system"]["has_gpu"] = has_gpu +RESULTS["system"]["jax_devices"] = str(devices) ``` ```{code-cell} ipython3 @@ -92,10 +92,8 @@ C = matmul(A, 
B).block_until_ready() compiled_time = time.perf_counter() - start print(f"Compiled execution: {compiled_time:.3f} seconds") -RESULTS["benchmarks"]["matmul_1000x1000"] = { - "warmup": warmup_time, - "compiled": compiled_time -} +RESULTS["benchmarks"]["matmul_1000x1000_warmup"] = warmup_time +RESULTS["benchmarks"]["matmul_1000x1000_compiled"] = compiled_time ``` ```{code-cell} ipython3 @@ -120,10 +118,8 @@ C = matmul(A, B).block_until_ready() compiled_time = time.perf_counter() - start print(f"Compiled execution: {compiled_time:.3f} seconds") -RESULTS["benchmarks"]["matmul_3000x3000"] = { - "warmup": warmup_time, - "compiled": compiled_time -} +RESULTS["benchmarks"]["matmul_3000x3000_warmup"] = warmup_time +RESULTS["benchmarks"]["matmul_3000x3000_compiled"] = compiled_time ``` ```{code-cell} ipython3 @@ -150,10 +146,8 @@ y = elementwise_ops(x).block_until_ready() compiled_time = time.perf_counter() - start print(f"Compiled execution: {compiled_time:.3f} seconds") -RESULTS["benchmarks"]["elementwise_50M"] = { - "warmup": warmup_time, - "compiled": compiled_time -} +RESULTS["benchmarks"]["elementwise_50m_warmup"] = warmup_time +RESULTS["benchmarks"]["elementwise_50m_compiled"] = compiled_time ``` ```{code-cell} ipython3 @@ -163,7 +157,7 @@ print("BENCHMARK 4: Multiple Small Operations (lecture simulation)") print("=" * 60) total_start = time.perf_counter() -multi_op_results = {} +multi_results = {} # Simulate multiple cell executions with different operations for i, size in enumerate([100, 500, 1000, 2000, 3000]): @@ -178,23 +172,25 @@ for i, size in enumerate([100, 500, 1000, 2000, 3000]): result = compute(A, B).block_until_ready() elapsed = time.perf_counter() - start print(f" Size {size}x{size}: {elapsed:.3f} seconds") - multi_op_results[f"size_{size}x{size}"] = elapsed + multi_results[f"size_{size}x{size}"] = elapsed total_time = time.perf_counter() - total_start print(f"\nTotal time for all operations: {total_time:.3f} seconds") -multi_op_results["total_time"] = total_time -RESULTS["benchmarks"]["multi_operations"] = multi_op_results +RESULTS["benchmarks"]["multi_ops"] = multi_results +RESULTS["benchmarks"]["multi_ops_total"] = total_time ``` ```{code-cell} ipython3 # Save results to JSON file -output_file = "benchmark_results_jupyterbook.json" -with open(output_file, 'w') as f: +output_path = "benchmark_results_jupyterbook.json" +with open(output_path, 'w') as f: json.dump(RESULTS, f, indent=2) print("\n" + "=" * 60) print("JUPYTER BOOK EXECUTION BENCHMARK COMPLETE") print("=" * 60) -print(f"\nResults saved to {output_file}") +print(f"\nResults saved to: {output_path}") +print("\nJSON Results:") +print(json.dumps(RESULTS, indent=2)) ``` From 54b2b34a69e24c6323cfa41c0eb63a15e4d960ee Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 19:27:07 +1100 Subject: [PATCH 10/19] ENH: Force lax.scan sequential operation to run on CPU Add device=cpu to the qm_jax function decorator to avoid the known XLA limitation where lax.scan with millions of lightweight iterations performs poorly on GPU due to CPU-GPU synchronization overhead. Added explanatory note about this pattern. 
Co-authored-by: HumphreyYang
---
 lectures/numpy_vs_numba_vs_jax.md | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 99393e2e..ef9b3532 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -509,7 +509,9 @@ Now let's create a JAX version using `lax.scan`:
 from jax import lax
 from functools import partial
 
-@partial(jax.jit, static_argnums=(1,))
+cpu = jax.devices("cpu")[0]
+
+@partial(jax.jit, static_argnums=(1,), device=cpu)
 def qm_jax(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
@@ -521,6 +523,11 @@ def qm_jax(x0, n, α=4.0):
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls
 `update` and accumulates the returned values `x_new` into an array.
 
+```{note}
+We explicitly target the CPU using `device=cpu` because `lax.scan` with many
+lightweight iterations performs poorly on GPU due to synchronization overhead.
+```
+
 Let's time it with the same parameters:
 
 ```{code-cell} ipython3

From 6bf345ae75fdb5746b8af12a013c3fe86fb5be08 Mon Sep 17 00:00:00 2001
From: Humphrey Yang
Date: Thu, 27 Nov 2025 20:14:24 +1100
Subject: [PATCH 11/19] update note

---
 lectures/numpy_vs_numba_vs_jax.md | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index ef9b3532..84c79f8f 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -524,8 +524,13 @@ def qm_jax(x0, n, α=4.0):
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls
 `update` and accumulates the returned values `x_new` into an array.
 
 ```{note}
-We explicitly target the CPU using `device=cpu` because `lax.scan` with many
-lightweight iterations performs poorly on GPU due to synchronization overhead.
+Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.
+
+The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+
+As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+
+Curious readers can try removing this option to see how performance changes.
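+
+That is, drop `device=cpu` from the decorator, leaving
+`@partial(jax.jit, static_argnums=(1,))`, so the scan runs on the default device.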
``` Let's time it with the same parameters: From 8fbb9a752fddf54b25646d3c6f1fe515fc85f332 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 21:13:33 +1100 Subject: [PATCH 12/19] Add lax.scan profiler to CI for GPU debugging - Add scripts/profile_lax_scan.py: Profiles lax.scan performance on GPU vs CPU to investigate the synchronization overhead issue (JAX Issue #2491) - Add CI step to run profiler with 100K iterations on RunsOn GPU environment - Script supports multiple profiling modes: basic timing, Nsight, JAX profiler, XLA dumps --- .github/workflows/ci.yml | 10 +++ scripts/profile_lax_scan.py | 174 ++++++++++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 scripts/profile_lax_scan.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e7f59f7..4d7e367a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,16 @@ jobs: pip install -U "jax[cuda13]" pip install numpyro python scripts/test-jax-install.py + # === lax.scan GPU Performance Profiling === + - name: Profile lax.scan (GPU vs CPU) + shell: bash -l {0} + run: | + echo "=== lax.scan Performance Profiling ===" + echo "This profiles the known issue with lax.scan on GPU (JAX Issue #2491)" + echo "" + python scripts/profile_lax_scan.py --iterations 100000 + echo "" + echo "Note: GPU is expected to be much slower due to CPU-GPU sync per iteration" # === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) === - name: Run Hardware Benchmarks (Bare Metal) shell: bash -l {0} diff --git a/scripts/profile_lax_scan.py b/scripts/profile_lax_scan.py new file mode 100644 index 00000000..89ba37e5 --- /dev/null +++ b/scripts/profile_lax_scan.py @@ -0,0 +1,174 @@ +""" +Profile lax.scan performance on GPU vs CPU to investigate synchronization overhead. + +This script helps diagnose why lax.scan with many lightweight iterations +performs poorly on GPU (81s) compared to CPU (0.06s). 
+ +Usage: + # Basic timing comparison + python profile_lax_scan.py + + # With NVIDIA Nsight Systems (requires nsys installed) + nsys profile -o lax_scan_profile --trace=cuda,nvtx python profile_lax_scan.py --nsys + + # With JAX profiler (view with TensorBoard) + python profile_lax_scan.py --jax-profile + + # With XLA debug dumps + python profile_lax_scan.py --xla-dump + +Requirements: + - JAX with CUDA support: pip install jax[cuda12] + - For Nsight: NVIDIA Nsight Systems (https://developer.nvidia.com/nsight-systems) + - For TensorBoard: pip install tensorboard tensorboard-plugin-profile +""" + +import argparse +import os +import time +from functools import partial + +def setup_xla_dump(dump_dir="/tmp/xla_dump"): + """Enable XLA debug dumps before importing JAX.""" + os.makedirs(dump_dir, exist_ok=True) + os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text" + print(f"XLA dumps will be written to: {dump_dir}") + +def main(): + parser = argparse.ArgumentParser(description="Profile lax.scan GPU performance") + parser.add_argument("--nsys", action="store_true", + help="Run in Nsight Systems compatible mode (smaller n)") + parser.add_argument("--jax-profile", action="store_true", + help="Enable JAX profiler (view with TensorBoard)") + parser.add_argument("--xla-dump", action="store_true", + help="Dump XLA HLO for analysis") + parser.add_argument("-n", "--iterations", type=int, default=10_000_000, + dest="n", help="Number of iterations (default: 10M)") + parser.add_argument("--profile-dir", type=str, default="/tmp/jax-trace", + help="Directory for JAX profile output") + args = parser.parse_args() + + # Setup XLA dump before importing JAX + if args.xla_dump: + setup_xla_dump() + + # Now import JAX + import jax + import jax.numpy as jnp + from jax import lax + + print("=" * 60) + print("lax.scan GPU Performance Profiling") + print("=" * 60) + + # Show device info + print(f"\nJAX version: {jax.__version__}") + print(f"Available devices: {jax.devices()}") + print(f"Default device: {jax.devices()[0]}") + + # Reduce n for Nsight profiling to keep trace manageable + n = 100_000 if args.nsys else args.n + print(f"\nIterations (n): {n:,}") + + # Define the functions + @partial(jax.jit, static_argnums=(1,)) + def qm_jax_default(x0, n, α=4.0): + """lax.scan on default device (GPU if available).""" + def update(x, t): + x_new = α * x * (1 - x) + return x_new, x_new + _, x = lax.scan(update, x0, jnp.arange(n)) + return jnp.concatenate([jnp.array([x0]), x]) + + cpu = jax.devices("cpu")[0] + + @partial(jax.jit, static_argnums=(1,), device=cpu) + def qm_jax_cpu(x0, n, α=4.0): + """lax.scan forced to CPU.""" + def update(x, t): + x_new = α * x * (1 - x) + return x_new, x_new + _, x = lax.scan(update, x0, jnp.arange(n)) + return jnp.concatenate([jnp.array([x0]), x]) + + # Warm up (compilation) + print("\n--- Compilation (warm-up) ---") + print("Compiling default device version...", end=" ", flush=True) + t0 = time.perf_counter() + _ = qm_jax_default(0.1, n).block_until_ready() + print(f"done ({time.perf_counter() - t0:.2f}s)") + + print("Compiling CPU version...", end=" ", flush=True) + t0 = time.perf_counter() + _ = qm_jax_cpu(0.1, n).block_until_ready() + print(f"done ({time.perf_counter() - t0:.2f}s)") + + # Profile with JAX profiler if requested + if args.jax_profile: + print(f"\n--- JAX Profiler (output: {args.profile_dir}) ---") + os.makedirs(args.profile_dir, exist_ok=True) + + jax.profiler.start_trace(args.profile_dir) + + # Run both versions while profiling + print("Profiling 
default device version...") + result_default = qm_jax_default(0.1, n).block_until_ready() + + print("Profiling CPU version...") + result_cpu = qm_jax_cpu(0.1, n).block_until_ready() + + jax.profiler.stop_trace() + print(f"\nProfile saved. View with:") + print(f" tensorboard --logdir={args.profile_dir}") + + # Timing runs + print("\n--- Timing Runs (post-compilation) ---") + + # Default device (GPU if available) + print(f"\nDefault device ({jax.devices()[0]}):") + times_default = [] + for i in range(3): + t0 = time.perf_counter() + result = qm_jax_default(0.1, n).block_until_ready() + elapsed = time.perf_counter() - t0 + times_default.append(elapsed) + print(f" Run {i+1}: {elapsed:.6f}s") + + # CPU + print(f"\nCPU (forced with device=cpu):") + times_cpu = [] + for i in range(3): + t0 = time.perf_counter() + result = qm_jax_cpu(0.1, n).block_until_ready() + elapsed = time.perf_counter() - t0 + times_cpu.append(elapsed) + print(f" Run {i+1}: {elapsed:.6f}s") + + # Summary + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + avg_default = sum(times_default) / len(times_default) + avg_cpu = sum(times_cpu) / len(times_cpu) + print(f"Default device avg: {avg_default:.6f}s") + print(f"CPU avg: {avg_cpu:.6f}s") + print(f"Ratio (default/cpu): {avg_default/avg_cpu:.1f}x") + + if avg_default > avg_cpu * 10: + print("\n⚠️ GPU is significantly slower than CPU!") + print(" This confirms the lax.scan synchronization overhead issue.") + elif avg_default < avg_cpu: + print("\n✓ GPU is faster (unexpected for this workload)") + else: + print("\n~ Performance is similar") + + if args.xla_dump: + print(f"\nXLA dumps written to /tmp/xla_dump/") + print("Look for .txt files with HLO representation") + + if args.nsys: + print("\nNsight Systems trace will be saved as lax_scan_profile.nsys-rep") + print("View with: nsys-ui lax_scan_profile.nsys-rep") + +if __name__ == "__main__": + main() From 1bfbaf9a1946ac4778f41659e903e7100ab0850e Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 21:22:13 +1100 Subject: [PATCH 13/19] Add diagnostic mode to lax.scan profiler - Add --diagnose flag that tests time scaling across iteration counts - If time scales linearly with iterations (not compute), it proves constant per-iteration overhead (CPU-GPU synchronization) - Also add --verbose flag for CUDA/XLA logging - Update CI to run with --diagnose flag --- .github/workflows/ci.yml | 5 +-- scripts/profile_lax_scan.py | 67 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4d7e367a..36817cd8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,9 +32,10 @@ jobs: echo "=== lax.scan Performance Profiling ===" echo "This profiles the known issue with lax.scan on GPU (JAX Issue #2491)" echo "" - python scripts/profile_lax_scan.py --iterations 100000 + python scripts/profile_lax_scan.py --iterations 100000 --diagnose echo "" - echo "Note: GPU is expected to be much slower due to CPU-GPU sync per iteration" + echo "The diagnostic shows if time scales linearly with iterations," + echo "which indicates constant per-iteration CPU-GPU sync overhead." 
# === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) === - name: Run Hardware Benchmarks (Bare Metal) shell: bash -l {0} diff --git a/scripts/profile_lax_scan.py b/scripts/profile_lax_scan.py index 89ba37e5..9930afe5 100644 --- a/scripts/profile_lax_scan.py +++ b/scripts/profile_lax_scan.py @@ -34,6 +34,13 @@ def setup_xla_dump(dump_dir="/tmp/xla_dump"): os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text" print(f"XLA dumps will be written to: {dump_dir}") +def setup_cuda_logging(): + """Enable CUDA/XLA logging to see sync patterns.""" + # These may help reveal synchronization behavior + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" # Show all TF/XLA logs + os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cuda_data_dir=/usr/local/cuda" + print("CUDA/XLA logging enabled") + def main(): parser = argparse.ArgumentParser(description="Profile lax.scan GPU performance") parser.add_argument("--nsys", action="store_true", @@ -42,6 +49,10 @@ def main(): help="Enable JAX profiler (view with TensorBoard)") parser.add_argument("--xla-dump", action="store_true", help="Dump XLA HLO for analysis") + parser.add_argument("--verbose", action="store_true", + help="Enable verbose CUDA/XLA logging") + parser.add_argument("--diagnose", action="store_true", + help="Run diagnostic to demonstrate sync overhead") parser.add_argument("-n", "--iterations", type=int, default=10_000_000, dest="n", help="Number of iterations (default: 10M)") parser.add_argument("--profile-dir", type=str, default="/tmp/jax-trace", @@ -51,6 +62,9 @@ def main(): # Setup XLA dump before importing JAX if args.xla_dump: setup_xla_dump() + + if args.verbose: + setup_cuda_logging() # Now import JAX import jax @@ -170,5 +184,58 @@ def update(x, t): print("\nNsight Systems trace will be saved as lax_scan_profile.nsys-rep") print("View with: nsys-ui lax_scan_profile.nsys-rep") + # Diagnostic: demonstrate sync overhead by showing time scaling + if args.diagnose: + print("\n" + "=" * 60) + print("DIAGNOSTIC: Per-iteration Sync Overhead Analysis") + print("=" * 60) + print("\nIf there's a CPU-GPU sync per iteration, time should scale") + print("linearly with iteration count (not with compute work).\n") + + # Test different iteration counts + test_ns = [1000, 5000, 10000, 50000, 100000] + + print("Iteration Count | GPU Time (s) | Time/Iter (µs) | Expected if O(n)") + print("-" * 70) + + gpu_times = [] + for test_n in test_ns: + # Define fresh function for this n + @partial(jax.jit, static_argnums=(1,)) + def qm_test(x0, n, α=4.0): + def update(x, t): + return α * x * (1 - x), α * x * (1 - x) + _, x = lax.scan(update, x0, jnp.arange(n)) + return jnp.concatenate([jnp.array([x0]), x]) + + # Compile + _ = qm_test(0.1, test_n).block_until_ready() + + # Time + t0 = time.perf_counter() + _ = qm_test(0.1, test_n).block_until_ready() + elapsed = time.perf_counter() - t0 + gpu_times.append(elapsed) + + time_per_iter = (elapsed / test_n) * 1_000_000 # microseconds + expected = gpu_times[0] * (test_n / test_ns[0]) if gpu_times else elapsed + + print(f"{test_n:>15,} | {elapsed:>12.6f} | {time_per_iter:>14.2f} | {expected:.6f}") + + # Calculate if time scales linearly (indicating per-iteration overhead) + ratio_1k_to_100k = gpu_times[-1] / gpu_times[0] + expected_ratio = test_ns[-1] / test_ns[0] # 100x if linear + + print(f"\nScaling analysis:") + print(f" Time ratio (100k/1k iterations): {ratio_1k_to_100k:.1f}x") + print(f" Expected if linear O(n): {expected_ratio:.1f}x") + + if 0.5 * expected_ratio < ratio_1k_to_100k 
< 2.0 * expected_ratio: + print("\n✓ Time scales ~linearly with iterations!") + print(" This indicates constant per-iteration overhead (CPU-GPU sync).") + print(f" Estimated sync overhead: ~{(gpu_times[0]/test_ns[0])*1e6:.1f} µs per iteration") + else: + print("\n? Scaling is not linear - may be other factors involved") + if __name__ == "__main__": main() From 8c32d7c52300e457dd6c73dfb261c47dd65ae15c Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 21:26:09 +1100 Subject: [PATCH 14/19] Add Nsight Systems profiling to CI - Run nsys profile with 1000 iterations if nsys is available - Captures CUDA, NVTX, and OS runtime traces - Uploads .nsys-rep file as artifact for visual analysis - continue-on-error: true so CI doesn't fail if nsys unavailable --- .github/workflows/ci.yml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36817cd8..53eec239 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,6 +36,34 @@ jobs: echo "" echo "The diagnostic shows if time scales linearly with iterations," echo "which indicates constant per-iteration CPU-GPU sync overhead." + - name: Nsight Systems Profile (if available) + shell: bash -l {0} + continue-on-error: true + run: | + echo "=== NVIDIA Nsight Systems Profiling ===" + if command -v nsys &> /dev/null; then + echo "nsys found, running profile with 1000 iterations..." + mkdir -p nsight_profiles + nsys profile -o nsight_profiles/lax_scan_trace \ + --trace=cuda,nvtx,osrt \ + --cuda-memory-usage=true \ + --stats=true \ + python scripts/profile_lax_scan.py --nsys -n 1000 + echo "" + echo "Profile saved to nsight_profiles/lax_scan_trace.nsys-rep" + echo "Download artifact and open in Nsight Systems UI to see CPU-GPU sync pattern" + else + echo "nsys not found, skipping Nsight profiling" + echo "Install NVIDIA Nsight Systems to enable this profiling" + fi + - name: Upload Nsight Profile + uses: actions/upload-artifact@v5 + if: success() || failure() + continue-on-error: true + with: + name: nsight-profile + path: nsight_profiles/ + if-no-files-found: ignore # === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) === - name: Run Hardware Benchmarks (Bare Metal) shell: bash -l {0} From a623fb7077cb004273071edfc144ef14fd9ec294 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 27 Nov 2025 21:50:08 +1100 Subject: [PATCH 15/19] address @jstac comment --- lectures/jax_intro.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 7d4706e4..e0d1cb0f 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -37,8 +37,6 @@ This lecture provides a short introduction to [Google JAX](https://github.com/ja JAX provides a NumPy-like interface that can leverage GPU acceleration for high-performance numerical computing. -For a more comprehensive discussion of JAX, see [our JAX lecture series](https://jax.quantecon.org/intro.html). 
-
 
 ## JAX as a NumPy Replacement

From 1739f51aa0daea8b03214394d107d2affdcc250b Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Fri, 28 Nov 2025 05:04:54 +0900
Subject: [PATCH 16/19] Improve JAX lecture content and pedagogy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Reorganize jax_intro.md to introduce JAX features upfront with clearer structure
- Expand JAX introduction with bulleted list of key capabilities (parallelization, JIT, autodiff)
- Add explicit GPU performance notes in vmap sections
- Enhance vmap explanation with detailed function composition breakdown
- Clarify memory efficiency tradeoffs between different vmap approaches

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude
---
 lectures/jax_intro.md             | 23 +++++++++++++--------
 lectures/numpy_vs_numba_vs_jax.md | 33 ++++++++++++++++++++-----------
 2 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md
index e0d1cb0f..b4114630 100644
--- a/lectures/jax_intro.md
+++ b/lectures/jax_intro.md
@@ -13,6 +13,18 @@ kernelspec:
 
 # JAX
 
+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+
+JAX is a high-performance scientific computing library that provides
+
+* a NumPy-like interface that can automatically parallelize across CPUs and GPUs,
+* a just-in-time compiler for accelerating a wide range of numerical
+  operations, and
+* automatic differentiation.
+
+Increasingly, JAX also maintains and provides more specialized scientific
+computing routines, such as those originally found in SciPy.
+
 In addition to what's in Anaconda, this lecture will need the following libraries:
 
 ```{code-cell} ipython3
@@ -33,17 +45,12 @@ Alternatively, if you have your own GPU, you can follow the [instructions](https
 If you would like to install JAX running on the `cpu` only you can use `pip install jax[cpu]`
 ```
 
-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
-
-JAX provides a NumPy-like interface that can leverage GPU acceleration for high-performance numerical computing.
-
-
 ## JAX as a NumPy Replacement
 
-One of the attractive features of JAX is that, whenever possible, it conforms to
-the NumPy API for array operations.
+One of the attractive features of JAX is that, whenever possible, its array
+processing operations conform to the NumPy API.
 
-This means that, to a large extent, we can use JAX is as a drop-in NumPy replacement.
+This means that, in many cases, we can use JAX as a drop-in NumPy replacement.
 
 Let's look at the similarities and differences between JAX and NumPy.
 
diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 84c79f8f..5bf55991 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -382,23 +382,29 @@ with qe.Timer(precision=8):
     z_max.block_until_ready()
 ```
 
-The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
-we are using far less memory.
+By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.
 
-In addition, `vmap` allows us to break vectorization up into stages, which is
-often easier to comprehend than the traditional approach.
+When run on a CPU, its runtime is similar to that of the meshgrid version.
 
-This will become more obvious when we tackle larger problems.
+When run on a GPU, it is usually significantly faster.
+
+In fact, using `vmap` has another advantage: it allows us to break vectorization up into stages.
+
+This leads to code that is often easier to comprehend than traditional vectorized code.
+
+We will investigate these ideas more when we tackle larger problems.
 
 ### vmap version 2
 
 We can be still more memory efficient using vmap.
 
-While we avoided large input arrays in the preceding version,
+While we avoid large input arrays in the preceding version,
 we still create the large output array `f(x,y)` before we compute the max.
 
-Let's use a slightly different approach that takes the max to the inside.
+Let's try a slightly different approach that takes the max to the inside.
+
+Because of this change, we never compute the two-dimensional array `f(x,y)`.
 
 ```{code-cell} ipython3
 @jax.jit
@@ -411,14 +417,20 @@ def compute_max_vmap_v2(grid):
     return jnp.max(f_vec_max(grid))
 ```
 
-Let's try it
+Here
+
+* `f_vec_x_max` computes the max along any given row
+* `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.
+
+We apply this function to all rows and then take the max of the row maxes.
+
+Let's try it.
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-
 Let's run it again to eliminate compilation time:
 
 ```{code-cell} ipython3
@@ -426,8 +438,7 @@ with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-We don't get much speed gain but we do save some memory.
-
+If you are running this on a GPU, as we are, you should see another nontrivial speed gain.
 
 ### Summary

From 350da37fdde6688363a3f7c47502d1589ca508b5 Mon Sep 17 00:00:00 2001
From: mmcky
Date: Fri, 28 Nov 2025 08:23:52 +1100
Subject: [PATCH 17/19] Remove benchmark scripts (moved to QuantEcon/benchmarks)

- Remove profile_lax_scan.py, benchmark-hardware.py, benchmark-jupyter.ipynb, benchmark-jupyterbook.md
- Remove profiling/benchmarking steps from ci.yml
- Keep test-jax-install.py for JAX installation verification

Benchmark scripts are now maintained in:
https://github.com/QuantEcon/benchmarks
---
 .github/workflows/ci.yml         |  92 ---------
 scripts/benchmark-hardware.py    | 336 -------------------------------
 scripts/benchmark-jupyter.ipynb  | 247 -----------------------
 scripts/benchmark-jupyterbook.md | 196 ------------------
 scripts/profile_lax_scan.py      | 241 ----------------------
 5 files changed, 1112 deletions(-)
 delete mode 100644 scripts/benchmark-hardware.py
 delete mode 100644 scripts/benchmark-jupyter.ipynb
 delete mode 100644 scripts/benchmark-jupyterbook.md
 delete mode 100644 scripts/profile_lax_scan.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 53eec239..58f69dcc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -25,98 +25,6 @@ jobs:
         pip install -U "jax[cuda13]"
         pip install numpyro
         python scripts/test-jax-install.py
-      # === lax.scan GPU Performance Profiling ===
-      - name: Profile lax.scan (GPU vs CPU)
-        shell: bash -l {0}
-        run: |
-          echo "=== lax.scan Performance Profiling ==="
-          echo "This profiles the known issue with lax.scan on GPU (JAX Issue #2491)"
-          echo ""
-          python scripts/profile_lax_scan.py --iterations 100000 --diagnose
-          echo ""
-          echo "The diagnostic shows if time scales linearly with iterations,"
-          echo "which indicates constant per-iteration CPU-GPU sync overhead."
- - name: Nsight Systems Profile (if available) - shell: bash -l {0} - continue-on-error: true - run: | - echo "=== NVIDIA Nsight Systems Profiling ===" - if command -v nsys &> /dev/null; then - echo "nsys found, running profile with 1000 iterations..." - mkdir -p nsight_profiles - nsys profile -o nsight_profiles/lax_scan_trace \ - --trace=cuda,nvtx,osrt \ - --cuda-memory-usage=true \ - --stats=true \ - python scripts/profile_lax_scan.py --nsys -n 1000 - echo "" - echo "Profile saved to nsight_profiles/lax_scan_trace.nsys-rep" - echo "Download artifact and open in Nsight Systems UI to see CPU-GPU sync pattern" - else - echo "nsys not found, skipping Nsight profiling" - echo "Install NVIDIA Nsight Systems to enable this profiling" - fi - - name: Upload Nsight Profile - uses: actions/upload-artifact@v5 - if: success() || failure() - continue-on-error: true - with: - name: nsight-profile - path: nsight_profiles/ - if-no-files-found: ignore - # === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) === - - name: Run Hardware Benchmarks (Bare Metal) - shell: bash -l {0} - run: | - echo "=== Bare Metal Python Script Execution ===" - python scripts/benchmark-hardware.py - mkdir -p benchmark_results - mv benchmark_results_bare_metal.json benchmark_results/ - - name: Run Jupyter Notebook Benchmark (via nbconvert) - shell: bash -l {0} - run: | - echo "=== Jupyter Kernel Execution ===" - cd scripts - jupyter nbconvert --to notebook --execute benchmark-jupyter.ipynb --output benchmark-jupyter-executed.ipynb - echo "Notebook executed successfully" - cd .. - mv scripts/benchmark_results_jupyter.json benchmark_results/ - - name: Run Jupyter-Book Benchmark - shell: bash -l {0} - run: | - echo "=== Jupyter-Book Execution ===" - # Build just the benchmark file using jupyter-book - mkdir -p benchmark_test - cp scripts/benchmark-jupyterbook.md benchmark_test/ - # Create minimal _config.yml - echo "title: Benchmark Test" > benchmark_test/_config.yml - echo "execute:" >> benchmark_test/_config.yml - echo " execute_notebooks: force" >> benchmark_test/_config.yml - # Create minimal _toc.yml - echo "format: jb-book" > benchmark_test/_toc.yml - echo "root: benchmark-jupyterbook" >> benchmark_test/_toc.yml - # Build (run from benchmark_test so JSON is written there) - cd benchmark_test - jb build . --path-output ../benchmark_build/ - cd .. - echo "Jupyter-Book build completed successfully" - # Move JSON results if generated - cp benchmark_test/benchmark_results_jupyterbook.json benchmark_results/ 2>/dev/null || echo "No jupyterbook results" - - name: Collect and Display Benchmark Results - shell: bash -l {0} - run: | - echo "=== Benchmark Results Summary ===" - for f in benchmark_results/*.json; do - echo "--- $f ---" - cat "$f" - echo "" - done - - name: Upload Benchmark Results - uses: actions/upload-artifact@v5 - with: - name: benchmark-results - path: benchmark_results/ - if-no-files-found: warn - name: Install latex dependencies run: | sudo apt-get -qq update diff --git a/scripts/benchmark-hardware.py b/scripts/benchmark-hardware.py deleted file mode 100644 index 12443855..00000000 --- a/scripts/benchmark-hardware.py +++ /dev/null @@ -1,336 +0,0 @@ -""" -Hardware benchmark script for CI runners. -Compares CPU and GPU performance to diagnose slowdowns. -Works on both CPU-only (GitHub Actions) and GPU (RunsOn) runners. 
-""" -import time -import platform -import os -import json -from datetime import datetime - -# Global results dictionary -RESULTS = { - "pathway": "bare_metal", - "timestamp": datetime.now().isoformat(), - "system": {}, - "benchmarks": {} -} - -def get_cpu_info(): - """Get CPU information.""" - print("=" * 60) - print("SYSTEM INFORMATION") - print("=" * 60) - print(f"Platform: {platform.platform()}") - print(f"Processor: {platform.processor()}") - print(f"Python: {platform.python_version()}") - - RESULTS["system"]["platform"] = platform.platform() - RESULTS["system"]["processor"] = platform.processor() - RESULTS["system"]["python"] = platform.python_version() - RESULTS["system"]["cpu_count"] = os.cpu_count() - - # Try to get CPU model - cpu_model = None - cpu_mhz = None - try: - with open('/proc/cpuinfo', 'r') as f: - for line in f: - if 'model name' in line: - cpu_model = line.split(':')[1].strip() - print(f"CPU Model: {cpu_model}") - break - except: - pass - - # Try to get CPU frequency - try: - with open('/proc/cpuinfo', 'r') as f: - for line in f: - if 'cpu MHz' in line: - cpu_mhz = line.split(':')[1].strip() - print(f"CPU MHz: {cpu_mhz}") - break - except: - pass - - RESULTS["system"]["cpu_model"] = cpu_model - RESULTS["system"]["cpu_mhz"] = cpu_mhz - - # CPU count - print(f"CPU Count: {os.cpu_count()}") - - # Check for GPU - gpu_info = None - try: - import subprocess - result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'], - capture_output=True, text=True, timeout=5) - if result.returncode == 0: - gpu_info = result.stdout.strip() - print(f"GPU: {gpu_info}") - else: - print("GPU: None detected") - except: - print("GPU: None detected (nvidia-smi not available)") - - RESULTS["system"]["gpu"] = gpu_info - print() - -def benchmark_cpu_pure_python(): - """Pure Python CPU benchmark.""" - print("=" * 60) - print("CPU BENCHMARK: Pure Python") - print("=" * 60) - - results = {} - - # Integer computation - start = time.perf_counter() - total = sum(i * i for i in range(10_000_000)) - elapsed = time.perf_counter() - start - print(f"Integer sum (10M iterations): {elapsed:.3f} seconds") - results["integer_sum_10m"] = elapsed - - # Float computation - start = time.perf_counter() - total = 0.0 - for i in range(1_000_000): - total += (i * 0.1) ** 0.5 - elapsed = time.perf_counter() - start - print(f"Float sqrt (1M iterations): {elapsed:.3f} seconds") - results["float_sqrt_1m"] = elapsed - print() - - RESULTS["benchmarks"]["pure_python"] = results - -def benchmark_cpu_numpy(): - """NumPy CPU benchmark.""" - import numpy as np - - print("=" * 60) - print("CPU BENCHMARK: NumPy") - print("=" * 60) - - results = {} - - # Matrix multiplication - n = 3000 - A = np.random.randn(n, n) - B = np.random.randn(n, n) - - start = time.perf_counter() - C = A @ B - elapsed = time.perf_counter() - start - print(f"Matrix multiply ({n}x{n}): {elapsed:.3f} seconds") - results["matmul_3000x3000"] = elapsed - - # Element-wise operations - x = np.random.randn(50_000_000) - - start = time.perf_counter() - y = np.cos(x**2) + np.sin(x) - elapsed = time.perf_counter() - start - print(f"Element-wise ops (50M elements): {elapsed:.3f} seconds") - results["elementwise_50m"] = elapsed - print() - - RESULTS["benchmarks"]["numpy"] = results - -def benchmark_gpu_jax(): - """JAX benchmark (GPU if available, otherwise CPU).""" - try: - import jax - import jax.numpy as jnp - - devices = jax.devices() - default_backend = jax.default_backend() - - # Check if GPU is available - has_gpu = any('cuda' 
in str(d).lower() or 'gpu' in str(d).lower() for d in devices) - - print("=" * 60) - if has_gpu: - print("JAX BENCHMARK: GPU") - else: - print("JAX BENCHMARK: CPU (no GPU detected)") - print("=" * 60) - - print(f"JAX devices: {devices}") - print(f"Default backend: {default_backend}") - print(f"GPU Available: {has_gpu}") - print() - - results = { - "backend": default_backend, - "has_gpu": has_gpu, - "devices": str(devices) - } - - # Warm-up JIT compilation - print("Warming up JIT compilation...") - n = 1000 - key = jax.random.PRNGKey(0) - A = jax.random.normal(key, (n, n)) - B = jax.random.normal(key, (n, n)) - - @jax.jit - def matmul(a, b): - return jnp.dot(a, b) - - # Warm-up run (includes compilation) - start = time.perf_counter() - C = matmul(A, B).block_until_ready() - warmup_time = time.perf_counter() - start - print(f"Warm-up (includes JIT compile, {n}x{n}): {warmup_time:.3f} seconds") - results["matmul_1000x1000_warmup"] = warmup_time - - # Actual benchmark (compiled) - start = time.perf_counter() - C = matmul(A, B).block_until_ready() - elapsed = time.perf_counter() - start - print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds") - results["matmul_1000x1000_compiled"] = elapsed - - # Larger matrix - n = 3000 - A = jax.random.normal(key, (n, n)) - B = jax.random.normal(key, (n, n)) - - # Warm-up for new size - start = time.perf_counter() - C = matmul(A, B).block_until_ready() - warmup_time = time.perf_counter() - start - print(f"Warm-up (recompile for {n}x{n}): {warmup_time:.3f} seconds") - results["matmul_3000x3000_warmup"] = warmup_time - - # Benchmark compiled - start = time.perf_counter() - C = matmul(A, B).block_until_ready() - elapsed = time.perf_counter() - start - print(f"Matrix multiply compiled ({n}x{n}): {elapsed:.3f} seconds") - results["matmul_3000x3000_compiled"] = elapsed - - # Element-wise GPU benchmark - x = jax.random.normal(key, (50_000_000,)) - - @jax.jit - def elementwise_ops(x): - return jnp.cos(x**2) + jnp.sin(x) - - # Warm-up - start = time.perf_counter() - y = elementwise_ops(x).block_until_ready() - warmup_time = time.perf_counter() - start - print(f"Element-wise warm-up (50M): {warmup_time:.3f} seconds") - results["elementwise_50m_warmup"] = warmup_time - - # Compiled - start = time.perf_counter() - y = elementwise_ops(x).block_until_ready() - elapsed = time.perf_counter() - start - print(f"Element-wise compiled (50M): {elapsed:.3f} seconds") - results["elementwise_50m_compiled"] = elapsed - - print() - RESULTS["benchmarks"]["jax"] = results - - except ImportError as e: - print(f"JAX not available: {e}") - RESULTS["benchmarks"]["jax"] = {"error": str(e)} - except Exception as e: - print(f"JAX benchmark failed: {e}") - RESULTS["benchmarks"]["jax"] = {"error": str(e)} - -def benchmark_numba(): - """Numba CPU benchmark.""" - try: - import numba - import numpy as np - - print("=" * 60) - print("CPU BENCHMARK: Numba") - print("=" * 60) - - results = {} - - @numba.jit(nopython=True) - def numba_sum(n): - total = 0 - for i in range(n): - total += i * i - return total - - # Warm-up (compilation) - start = time.perf_counter() - result = numba_sum(10_000_000) - warmup_time = time.perf_counter() - start - print(f"Integer sum warm-up (includes compile): {warmup_time:.3f} seconds") - results["integer_sum_10m_warmup"] = warmup_time - - # Compiled run - start = time.perf_counter() - result = numba_sum(10_000_000) - elapsed = time.perf_counter() - start - print(f"Integer sum compiled (10M): {elapsed:.3f} seconds") - results["integer_sum_10m_compiled"] = 
elapsed - - @numba.jit(nopython=True, parallel=True) - def numba_parallel_sum(arr): - total = 0.0 - for i in numba.prange(len(arr)): - total += arr[i] ** 2 - return total - - arr = np.random.randn(50_000_000) - - # Warm-up - start = time.perf_counter() - result = numba_parallel_sum(arr) - warmup_time = time.perf_counter() - start - print(f"Parallel sum warm-up (50M): {warmup_time:.3f} seconds") - results["parallel_sum_50m_warmup"] = warmup_time - - # Compiled - start = time.perf_counter() - result = numba_parallel_sum(arr) - elapsed = time.perf_counter() - start - print(f"Parallel sum compiled (50M): {elapsed:.3f} seconds") - results["parallel_sum_50m_compiled"] = elapsed - - print() - RESULTS["benchmarks"]["numba"] = results - - except ImportError as e: - print(f"Numba not available: {e}") - RESULTS["benchmarks"]["numba"] = {"error": str(e)} - except Exception as e: - print(f"Numba benchmark failed: {e}") - RESULTS["benchmarks"]["numba"] = {"error": str(e)} - - -def save_results(output_path="benchmark_results_bare_metal.json"): - """Save benchmark results to JSON file.""" - with open(output_path, 'w') as f: - json.dump(RESULTS, f, indent=2) - print(f"\nResults saved to: {output_path}") - - -if __name__ == "__main__": - print("\n" + "=" * 60) - print("HARDWARE BENCHMARK FOR CI RUNNER") - print("=" * 60 + "\n") - - get_cpu_info() - benchmark_cpu_pure_python() - benchmark_cpu_numpy() - benchmark_numba() - benchmark_gpu_jax() - - # Save results to JSON - save_results("benchmark_results_bare_metal.json") - - print("=" * 60) - print("BENCHMARK COMPLETE") - print("=" * 60) diff --git a/scripts/benchmark-jupyter.ipynb b/scripts/benchmark-jupyter.ipynb deleted file mode 100644 index 909b8fe5..00000000 --- a/scripts/benchmark-jupyter.ipynb +++ /dev/null @@ -1,247 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# JAX Performance Benchmark - Jupyter Kernel Execution\n", - "\n", - "This notebook tests JAX performance when executed through a Jupyter kernel.\n", - "Compare results with direct script and jupyter-book execution." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "import platform\n", - "import os\n", - "import json\n", - "from datetime import datetime\n", - "\n", - "# Initialize results dictionary\n", - "RESULTS = {\n", - " \"pathway\": \"jupyter_kernel\",\n", - " \"timestamp\": datetime.now().isoformat(),\n", - " \"system\": {\n", - " \"platform\": platform.platform(),\n", - " \"python\": platform.python_version(),\n", - " \"cpu_count\": os.cpu_count()\n", - " },\n", - " \"benchmarks\": {}\n", - "}\n", - "\n", - "print(\"=\" * 60)\n", - "print(\"JUPYTER KERNEL EXECUTION BENCHMARK\")\n", - "print(\"=\" * 60)\n", - "print(f\"Platform: {platform.platform()}\")\n", - "print(f\"Python: {platform.python_version()}\")\n", - "print(f\"CPU Count: {os.cpu_count()}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import JAX and check devices\n", - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "devices = jax.devices()\n", - "default_backend = jax.default_backend()\n", - "has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices)\n", - "\n", - "print(f\"JAX devices: {devices}\")\n", - "print(f\"Default backend: {default_backend}\")\n", - "print(f\"GPU Available: {has_gpu}\")\n", - "\n", - "RESULTS[\"system\"][\"jax_backend\"] = default_backend\n", - "RESULTS[\"system\"][\"has_gpu\"] = has_gpu\n", - "RESULTS[\"system\"][\"jax_devices\"] = str(devices)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Define JIT-compiled function\n", - "@jax.jit\n", - "def matmul(a, b):\n", - " return jnp.dot(a, b)\n", - "\n", - "print(\"matmul function defined with @jax.jit\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation\n", - "print(\"\\n\" + \"=\" * 60)\n", - "print(\"BENCHMARK 1: Small Matrix (1000x1000)\")\n", - "print(\"=\" * 60)\n", - "\n", - "n = 1000\n", - "key = jax.random.PRNGKey(0)\n", - "A = jax.random.normal(key, (n, n))\n", - "B = jax.random.normal(key, (n, n))\n", - "\n", - "# Warm-up run (includes compilation)\n", - "start = time.perf_counter()\n", - "C = matmul(A, B).block_until_ready()\n", - "warmup_time = time.perf_counter() - start\n", - "print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n", - "\n", - "# Compiled run\n", - "start = time.perf_counter()\n", - "C = matmul(A, B).block_until_ready()\n", - "compiled_time = time.perf_counter() - start\n", - "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", - "\n", - "RESULTS[\"benchmarks\"][\"matmul_1000x1000_warmup\"] = warmup_time\n", - "RESULTS[\"benchmarks\"][\"matmul_1000x1000_compiled\"] = compiled_time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Benchmark 2: Large matrix (3000x3000) - triggers recompilation\n", - "print(\"\\n\" + \"=\" * 60)\n", - "print(\"BENCHMARK 2: Large Matrix (3000x3000)\")\n", - "print(\"=\" * 60)\n", - "\n", - "n = 3000\n", - "A = jax.random.normal(key, (n, n))\n", - "B = jax.random.normal(key, (n, n))\n", - "\n", - "# Warm-up run (recompilation for new size)\n", - "start = time.perf_counter()\n", - "C = matmul(A, B).block_until_ready()\n", - "warmup_time = time.perf_counter() - start\n", - "print(f\"Warm-up (recompile for new size): 
{warmup_time:.3f} seconds\")\n", - "\n", - "# Compiled run\n", - "start = time.perf_counter()\n", - "C = matmul(A, B).block_until_ready()\n", - "compiled_time = time.perf_counter() - start\n", - "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", - "\n", - "RESULTS[\"benchmarks\"][\"matmul_3000x3000_warmup\"] = warmup_time\n", - "RESULTS[\"benchmarks\"][\"matmul_3000x3000_compiled\"] = compiled_time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Benchmark 3: Element-wise operations (50M elements)\n", - "print(\"\\n\" + \"=\" * 60)\n", - "print(\"BENCHMARK 3: Element-wise Operations (50M elements)\")\n", - "print(\"=\" * 60)\n", - "\n", - "@jax.jit\n", - "def elementwise_ops(x):\n", - " return jnp.cos(x**2) + jnp.sin(x)\n", - "\n", - "x = jax.random.normal(key, (50_000_000,))\n", - "\n", - "# Warm-up\n", - "start = time.perf_counter()\n", - "y = elementwise_ops(x).block_until_ready()\n", - "warmup_time = time.perf_counter() - start\n", - "print(f\"Warm-up (includes JIT compile): {warmup_time:.3f} seconds\")\n", - "\n", - "# Compiled\n", - "start = time.perf_counter()\n", - "y = elementwise_ops(x).block_until_ready()\n", - "compiled_time = time.perf_counter() - start\n", - "print(f\"Compiled execution: {compiled_time:.3f} seconds\")\n", - "\n", - "RESULTS[\"benchmarks\"][\"elementwise_50m_warmup\"] = warmup_time\n", - "RESULTS[\"benchmarks\"][\"elementwise_50m_compiled\"] = compiled_time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Benchmark 4: Multiple small operations (simulates lecture cells)\n", - "print(\"\\n\" + \"=\" * 60)\n", - "print(\"BENCHMARK 4: Multiple Small Operations (lecture simulation)\")\n", - "print(\"=\" * 60)\n", - "\n", - "total_start = time.perf_counter()\n", - "multi_results = {}\n", - "\n", - "# Simulate multiple cell executions with different operations\n", - "for i, size in enumerate([100, 500, 1000, 2000, 3000]):\n", - " @jax.jit\n", - " def compute(a, b):\n", - " return jnp.dot(a, b) + jnp.sum(a)\n", - " \n", - " A = jax.random.normal(key, (size, size))\n", - " B = jax.random.normal(key, (size, size))\n", - " \n", - " start = time.perf_counter()\n", - " result = compute(A, B).block_until_ready()\n", - " elapsed = time.perf_counter() - start\n", - " print(f\" Size {size}x{size}: {elapsed:.3f} seconds\")\n", - " multi_results[f\"size_{size}x{size}\"] = elapsed\n", - "\n", - "total_time = time.perf_counter() - total_start\n", - "print(f\"\\nTotal time for all operations: {total_time:.3f} seconds\")\n", - "\n", - "RESULTS[\"benchmarks\"][\"multi_ops\"] = multi_results\n", - "RESULTS[\"benchmarks\"][\"multi_ops_total\"] = total_time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Save results to JSON file\n", - "output_path = \"benchmark_results_jupyter.json\"\n", - "with open(output_path, 'w') as f:\n", - " json.dump(RESULTS, f, indent=2)\n", - "\n", - "print(\"\\n\" + \"=\" * 60)\n", - "print(\"JUPYTER KERNEL EXECUTION BENCHMARK COMPLETE\")\n", - "print(\"=\" * 60)\n", - "print(f\"\\nResults saved to: {output_path}\")\n", - "print(\"\\nJSON Results:\")\n", - "print(json.dumps(RESULTS, indent=2))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.13.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git 
a/scripts/benchmark-jupyterbook.md b/scripts/benchmark-jupyterbook.md deleted file mode 100644 index 162613c8..00000000 --- a/scripts/benchmark-jupyterbook.md +++ /dev/null @@ -1,196 +0,0 @@ ---- -jupytext: - text_representation: - extension: .md - format_name: myst - format_version: 0.13 -kernelspec: - display_name: Python 3 (ipykernel) - language: python - name: python3 ---- - -# JAX Performance Benchmark - Jupyter Book Execution - -This file tests JAX performance when executed through Jupyter Book's notebook execution. -Compare results with direct script and nbconvert execution. - -```{code-cell} ipython3 -import time -import platform -import os -import json -from datetime import datetime - -# Initialize results dictionary -RESULTS = { - "pathway": "jupyter_book", - "timestamp": datetime.now().isoformat(), - "system": { - "platform": platform.platform(), - "python": platform.python_version(), - "cpu_count": os.cpu_count() - }, - "benchmarks": {} -} - -print("=" * 60) -print("JUPYTER BOOK EXECUTION BENCHMARK") -print("=" * 60) -print(f"Platform: {platform.platform()}") -print(f"Python: {platform.python_version()}") -print(f"CPU Count: {os.cpu_count()}") -``` - -```{code-cell} ipython3 -# Import JAX and check devices -import jax -import jax.numpy as jnp - -devices = jax.devices() -default_backend = jax.default_backend() -has_gpu = any('cuda' in str(d).lower() or 'gpu' in str(d).lower() for d in devices) - -print(f"JAX devices: {devices}") -print(f"Default backend: {default_backend}") -print(f"GPU Available: {has_gpu}") - -RESULTS["system"]["jax_backend"] = default_backend -RESULTS["system"]["has_gpu"] = has_gpu -RESULTS["system"]["jax_devices"] = str(devices) -``` - -```{code-cell} ipython3 -# Define JIT-compiled function -@jax.jit -def matmul(a, b): - return jnp.dot(a, b) - -print("matmul function defined with @jax.jit") -``` - -```{code-cell} ipython3 -# Benchmark 1: Small matrix (1000x1000) - includes JIT compilation -print("\n" + "=" * 60) -print("BENCHMARK 1: Small Matrix (1000x1000)") -print("=" * 60) - -n = 1000 -key = jax.random.PRNGKey(0) -A = jax.random.normal(key, (n, n)) -B = jax.random.normal(key, (n, n)) - -# Warm-up run (includes compilation) -start = time.perf_counter() -C = matmul(A, B).block_until_ready() -warmup_time = time.perf_counter() - start -print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds") - -# Compiled run -start = time.perf_counter() -C = matmul(A, B).block_until_ready() -compiled_time = time.perf_counter() - start -print(f"Compiled execution: {compiled_time:.3f} seconds") - -RESULTS["benchmarks"]["matmul_1000x1000_warmup"] = warmup_time -RESULTS["benchmarks"]["matmul_1000x1000_compiled"] = compiled_time -``` - -```{code-cell} ipython3 -# Benchmark 2: Large matrix (3000x3000) - triggers recompilation -print("\n" + "=" * 60) -print("BENCHMARK 2: Large Matrix (3000x3000)") -print("=" * 60) - -n = 3000 -A = jax.random.normal(key, (n, n)) -B = jax.random.normal(key, (n, n)) - -# Warm-up run (recompilation for new size) -start = time.perf_counter() -C = matmul(A, B).block_until_ready() -warmup_time = time.perf_counter() - start -print(f"Warm-up (recompile for new size): {warmup_time:.3f} seconds") - -# Compiled run -start = time.perf_counter() -C = matmul(A, B).block_until_ready() -compiled_time = time.perf_counter() - start -print(f"Compiled execution: {compiled_time:.3f} seconds") - -RESULTS["benchmarks"]["matmul_3000x3000_warmup"] = warmup_time -RESULTS["benchmarks"]["matmul_3000x3000_compiled"] = compiled_time -``` - -```{code-cell} ipython3 -# 
Benchmark 3: Element-wise operations (50M elements) -print("\n" + "=" * 60) -print("BENCHMARK 3: Element-wise Operations (50M elements)") -print("=" * 60) - -@jax.jit -def elementwise_ops(x): - return jnp.cos(x**2) + jnp.sin(x) - -x = jax.random.normal(key, (50_000_000,)) - -# Warm-up -start = time.perf_counter() -y = elementwise_ops(x).block_until_ready() -warmup_time = time.perf_counter() - start -print(f"Warm-up (includes JIT compile): {warmup_time:.3f} seconds") - -# Compiled -start = time.perf_counter() -y = elementwise_ops(x).block_until_ready() -compiled_time = time.perf_counter() - start -print(f"Compiled execution: {compiled_time:.3f} seconds") - -RESULTS["benchmarks"]["elementwise_50m_warmup"] = warmup_time -RESULTS["benchmarks"]["elementwise_50m_compiled"] = compiled_time -``` - -```{code-cell} ipython3 -# Benchmark 4: Multiple small operations (simulates lecture cells) -print("\n" + "=" * 60) -print("BENCHMARK 4: Multiple Small Operations (lecture simulation)") -print("=" * 60) - -total_start = time.perf_counter() -multi_results = {} - -# Simulate multiple cell executions with different operations -for i, size in enumerate([100, 500, 1000, 2000, 3000]): - @jax.jit - def compute(a, b): - return jnp.dot(a, b) + jnp.sum(a) - - A = jax.random.normal(key, (size, size)) - B = jax.random.normal(key, (size, size)) - - start = time.perf_counter() - result = compute(A, B).block_until_ready() - elapsed = time.perf_counter() - start - print(f" Size {size}x{size}: {elapsed:.3f} seconds") - multi_results[f"size_{size}x{size}"] = elapsed - -total_time = time.perf_counter() - total_start -print(f"\nTotal time for all operations: {total_time:.3f} seconds") - -RESULTS["benchmarks"]["multi_ops"] = multi_results -RESULTS["benchmarks"]["multi_ops_total"] = total_time -``` - -```{code-cell} ipython3 -# Save results to JSON file -output_path = "benchmark_results_jupyterbook.json" -with open(output_path, 'w') as f: - json.dump(RESULTS, f, indent=2) - -print("\n" + "=" * 60) -print("JUPYTER BOOK EXECUTION BENCHMARK COMPLETE") -print("=" * 60) -print(f"\nResults saved to: {output_path}") -print("\nJSON Results:") -print(json.dumps(RESULTS, indent=2)) -``` diff --git a/scripts/profile_lax_scan.py b/scripts/profile_lax_scan.py deleted file mode 100644 index 9930afe5..00000000 --- a/scripts/profile_lax_scan.py +++ /dev/null @@ -1,241 +0,0 @@ -""" -Profile lax.scan performance on GPU vs CPU to investigate synchronization overhead. - -This script helps diagnose why lax.scan with many lightweight iterations -performs poorly on GPU (81s) compared to CPU (0.06s). 
- -Usage: - # Basic timing comparison - python profile_lax_scan.py - - # With NVIDIA Nsight Systems (requires nsys installed) - nsys profile -o lax_scan_profile --trace=cuda,nvtx python profile_lax_scan.py --nsys - - # With JAX profiler (view with TensorBoard) - python profile_lax_scan.py --jax-profile - - # With XLA debug dumps - python profile_lax_scan.py --xla-dump - -Requirements: - - JAX with CUDA support: pip install jax[cuda12] - - For Nsight: NVIDIA Nsight Systems (https://developer.nvidia.com/nsight-systems) - - For TensorBoard: pip install tensorboard tensorboard-plugin-profile -""" - -import argparse -import os -import time -from functools import partial - -def setup_xla_dump(dump_dir="/tmp/xla_dump"): - """Enable XLA debug dumps before importing JAX.""" - os.makedirs(dump_dir, exist_ok=True) - os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text" - print(f"XLA dumps will be written to: {dump_dir}") - -def setup_cuda_logging(): - """Enable CUDA/XLA logging to see sync patterns.""" - # These may help reveal synchronization behavior - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" # Show all TF/XLA logs - os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cuda_data_dir=/usr/local/cuda" - print("CUDA/XLA logging enabled") - -def main(): - parser = argparse.ArgumentParser(description="Profile lax.scan GPU performance") - parser.add_argument("--nsys", action="store_true", - help="Run in Nsight Systems compatible mode (smaller n)") - parser.add_argument("--jax-profile", action="store_true", - help="Enable JAX profiler (view with TensorBoard)") - parser.add_argument("--xla-dump", action="store_true", - help="Dump XLA HLO for analysis") - parser.add_argument("--verbose", action="store_true", - help="Enable verbose CUDA/XLA logging") - parser.add_argument("--diagnose", action="store_true", - help="Run diagnostic to demonstrate sync overhead") - parser.add_argument("-n", "--iterations", type=int, default=10_000_000, - dest="n", help="Number of iterations (default: 10M)") - parser.add_argument("--profile-dir", type=str, default="/tmp/jax-trace", - help="Directory for JAX profile output") - args = parser.parse_args() - - # Setup XLA dump before importing JAX - if args.xla_dump: - setup_xla_dump() - - if args.verbose: - setup_cuda_logging() - - # Now import JAX - import jax - import jax.numpy as jnp - from jax import lax - - print("=" * 60) - print("lax.scan GPU Performance Profiling") - print("=" * 60) - - # Show device info - print(f"\nJAX version: {jax.__version__}") - print(f"Available devices: {jax.devices()}") - print(f"Default device: {jax.devices()[0]}") - - # Reduce n for Nsight profiling to keep trace manageable - n = 100_000 if args.nsys else args.n - print(f"\nIterations (n): {n:,}") - - # Define the functions - @partial(jax.jit, static_argnums=(1,)) - def qm_jax_default(x0, n, α=4.0): - """lax.scan on default device (GPU if available).""" - def update(x, t): - x_new = α * x * (1 - x) - return x_new, x_new - _, x = lax.scan(update, x0, jnp.arange(n)) - return jnp.concatenate([jnp.array([x0]), x]) - - cpu = jax.devices("cpu")[0] - - @partial(jax.jit, static_argnums=(1,), device=cpu) - def qm_jax_cpu(x0, n, α=4.0): - """lax.scan forced to CPU.""" - def update(x, t): - x_new = α * x * (1 - x) - return x_new, x_new - _, x = lax.scan(update, x0, jnp.arange(n)) - return jnp.concatenate([jnp.array([x0]), x]) - - # Warm up (compilation) - print("\n--- Compilation (warm-up) ---") - print("Compiling default device version...", end=" ", flush=True) - t0 
= time.perf_counter() - _ = qm_jax_default(0.1, n).block_until_ready() - print(f"done ({time.perf_counter() - t0:.2f}s)") - - print("Compiling CPU version...", end=" ", flush=True) - t0 = time.perf_counter() - _ = qm_jax_cpu(0.1, n).block_until_ready() - print(f"done ({time.perf_counter() - t0:.2f}s)") - - # Profile with JAX profiler if requested - if args.jax_profile: - print(f"\n--- JAX Profiler (output: {args.profile_dir}) ---") - os.makedirs(args.profile_dir, exist_ok=True) - - jax.profiler.start_trace(args.profile_dir) - - # Run both versions while profiling - print("Profiling default device version...") - result_default = qm_jax_default(0.1, n).block_until_ready() - - print("Profiling CPU version...") - result_cpu = qm_jax_cpu(0.1, n).block_until_ready() - - jax.profiler.stop_trace() - print(f"\nProfile saved. View with:") - print(f" tensorboard --logdir={args.profile_dir}") - - # Timing runs - print("\n--- Timing Runs (post-compilation) ---") - - # Default device (GPU if available) - print(f"\nDefault device ({jax.devices()[0]}):") - times_default = [] - for i in range(3): - t0 = time.perf_counter() - result = qm_jax_default(0.1, n).block_until_ready() - elapsed = time.perf_counter() - t0 - times_default.append(elapsed) - print(f" Run {i+1}: {elapsed:.6f}s") - - # CPU - print(f"\nCPU (forced with device=cpu):") - times_cpu = [] - for i in range(3): - t0 = time.perf_counter() - result = qm_jax_cpu(0.1, n).block_until_ready() - elapsed = time.perf_counter() - t0 - times_cpu.append(elapsed) - print(f" Run {i+1}: {elapsed:.6f}s") - - # Summary - print("\n" + "=" * 60) - print("SUMMARY") - print("=" * 60) - avg_default = sum(times_default) / len(times_default) - avg_cpu = sum(times_cpu) / len(times_cpu) - print(f"Default device avg: {avg_default:.6f}s") - print(f"CPU avg: {avg_cpu:.6f}s") - print(f"Ratio (default/cpu): {avg_default/avg_cpu:.1f}x") - - if avg_default > avg_cpu * 10: - print("\n⚠️ GPU is significantly slower than CPU!") - print(" This confirms the lax.scan synchronization overhead issue.") - elif avg_default < avg_cpu: - print("\n✓ GPU is faster (unexpected for this workload)") - else: - print("\n~ Performance is similar") - - if args.xla_dump: - print(f"\nXLA dumps written to /tmp/xla_dump/") - print("Look for .txt files with HLO representation") - - if args.nsys: - print("\nNsight Systems trace will be saved as lax_scan_profile.nsys-rep") - print("View with: nsys-ui lax_scan_profile.nsys-rep") - - # Diagnostic: demonstrate sync overhead by showing time scaling - if args.diagnose: - print("\n" + "=" * 60) - print("DIAGNOSTIC: Per-iteration Sync Overhead Analysis") - print("=" * 60) - print("\nIf there's a CPU-GPU sync per iteration, time should scale") - print("linearly with iteration count (not with compute work).\n") - - # Test different iteration counts - test_ns = [1000, 5000, 10000, 50000, 100000] - - print("Iteration Count | GPU Time (s) | Time/Iter (µs) | Expected if O(n)") - print("-" * 70) - - gpu_times = [] - for test_n in test_ns: - # Define fresh function for this n - @partial(jax.jit, static_argnums=(1,)) - def qm_test(x0, n, α=4.0): - def update(x, t): - return α * x * (1 - x), α * x * (1 - x) - _, x = lax.scan(update, x0, jnp.arange(n)) - return jnp.concatenate([jnp.array([x0]), x]) - - # Compile - _ = qm_test(0.1, test_n).block_until_ready() - - # Time - t0 = time.perf_counter() - _ = qm_test(0.1, test_n).block_until_ready() - elapsed = time.perf_counter() - t0 - gpu_times.append(elapsed) - - time_per_iter = (elapsed / test_n) * 1_000_000 # microseconds - 
expected = gpu_times[0] * (test_n / test_ns[0]) if gpu_times else elapsed
-
-            print(f"{test_n:>15,} | {elapsed:>12.6f} | {time_per_iter:>14.2f} | {expected:.6f}")
-
-        # Calculate if time scales linearly (indicating per-iteration overhead)
-        ratio_1k_to_100k = gpu_times[-1] / gpu_times[0]
-        expected_ratio = test_ns[-1] / test_ns[0]  # 100x if linear
-
-        print(f"\nScaling analysis:")
-        print(f"  Time ratio (100k/1k iterations): {ratio_1k_to_100k:.1f}x")
-        print(f"  Expected if linear O(n): {expected_ratio:.1f}x")
-
-        if 0.5 * expected_ratio < ratio_1k_to_100k < 2.0 * expected_ratio:
-            print("\n✓ Time scales ~linearly with iterations!")
-            print("  This indicates constant per-iteration overhead (CPU-GPU sync).")
-            print(f"  Estimated sync overhead: ~{(gpu_times[0]/test_ns[0])*1e6:.1f} µs per iteration")
-        else:
-            print("\n? Scaling is not linear - may be other factors involved")
-
-if __name__ == "__main__":
-    main()

From e2939c2aaeb62e8e3d7f53af38df313c6213505b Mon Sep 17 00:00:00 2001
From: Matt McKay
Date: Fri, 28 Nov 2025 08:38:16 +1100
Subject: [PATCH 18/19] Update lectures/numpy_vs_numba_vs_jax.md

---
 lectures/numpy_vs_numba_vs_jax.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 5bf55991..c7e8d4c8 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -438,7 +438,7 @@ with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-If you are running this on a GPU, as well are, you should see another nontrivial speed gain.
+If you are running this on a GPU, as we are, you should see another nontrivial speed gain.
 
 ### Summary
 

From 56047ab147a934db7e7338220adddd4bd8c92124 Mon Sep 17 00:00:00 2001
From: mmcky
Date: Fri, 28 Nov 2025 08:47:08 +1100
Subject: [PATCH 19/19] Add GPU and JAX hardware details to status page

- Add nvidia-smi output to show GPU availability
- Add JAX backend check to confirm GPU usage
- Matches format used in lecture-python.myst

---
 lectures/status.md | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/lectures/status.md b/lectures/status.md
index 3ada25f0..2ec414c4 100644
--- a/lectures/status.md
+++ b/lectures/status.md
@@ -31,4 +31,18 @@ and the following package versions
 ```{code-cell} ipython
 :tags: [hide-output]
 !conda list
+```
+
+This lecture series has access to the following GPU
+
+```{code-cell} ipython
+!nvidia-smi
+```
+
+You can check which backend JAX is using with:
+
+```{code-cell} ipython
+import jax
+# The first device's platform reveals the active backend (gpu or cpu)
+print(f"JAX backend: {jax.devices()[0].platform}")
 ```
\ No newline at end of file
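
A note on the CPU baseline in the deleted `scripts/profile_lax_scan.py`: it was pinned with `jax.jit(..., device=cpu)`, a keyword argument that newer JAX releases have removed. A minimal sketch of the same default-device-versus-CPU comparison against the current API, using the `jax.default_device` context manager (the iteration count below is illustrative, not the script's 10M default):

```python
# Sketch only: assumes a JAX version in which jax.jit no longer accepts
# a `device` argument, so CPU placement uses jax.default_device instead.
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=(1,))
def qm(x0, n, alpha=4.0):
    # Logistic map trajectory via lax.scan, as in the deleted script
    def update(x, t):
        x_new = alpha * x * (1 - x)
        return x_new, x_new
    _, xs = lax.scan(update, x0, jnp.arange(n))
    return jnp.concatenate([jnp.array([x0]), xs])

def timed(fn, *args):
    fn(*args).block_until_ready()   # warm-up run triggers compilation
    t0 = time.perf_counter()
    fn(*args).block_until_ready()   # timed run uses the compiled kernel
    return time.perf_counter() - t0

n = 100_000  # illustrative; small enough to keep traces manageable
print(f"default device: {timed(qm, 0.1, n):.6f}s")

with jax.default_device(jax.devices("cpu")[0]):
    print(f"cpu:            {timed(qm, 0.1, n):.6f}s")
```

Because each placement triggers its own compilation inside the warm-up call, the timed runs measure post-compilation execution only, which is the quantity the deleted script's summary compared.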
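
Similarly, the deleted `scripts/benchmark-jupyterbook.md` asks readers to compare its JSON output with the direct-script and nbconvert pathways but does not show that comparison step. A sketch under the assumption that each pathway writes a sibling file following the same naming pattern (only `benchmark_results_jupyterbook.json` is produced above; the other two filenames are hypothetical):

```python
# Sketch only: compares benchmark JSON files produced by different
# execution pathways; all filenames except the jupyter_book one are
# hypothetical placeholders.
import json

paths = [
    "benchmark_results_jupyterbook.json",
    "benchmark_results_script.json",      # hypothetical
    "benchmark_results_nbconvert.json",   # hypothetical
]

results = {}
for path in paths:
    try:
        with open(path) as f:
            data = json.load(f)
        results[data["pathway"]] = data["benchmarks"]
    except FileNotFoundError:
        print(f"skipping missing file: {path}")

# Report each scalar benchmark side by side across the pathways found
common = set.intersection(*(set(b) for b in results.values())) if results else set()
for name in sorted(common):
    vals = {p: results[p][name] for p in results}
    if all(isinstance(v, (int, float)) for v in vals.values()):
        row = "  ".join(f"{p}={v:.3f}s" for p, v in vals.items())
        print(f"{name}: {row}")
```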