Merged
Changes from 15 of 20 commits
2e1684b
Enable RunsOn GPU support for lecture builds
mmcky Nov 27, 2025
1671eb7
Merge branch 'main' into feature/runson-gpu-support
mmcky Nov 27, 2025
849cf16
DOC: Update JAX lectures with GPU admonition and narrative
mmcky Nov 27, 2025
0b6a567
DEBUG: Add hardware benchmark script to diagnose performance
mmcky Nov 27, 2025
6d1b9c3
Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book)
mmcky Nov 27, 2025
d129f79
Fix: Add content to benchmark-jupyter.ipynb (was empty)
mmcky Nov 27, 2025
2da4e0c
Fix: Add benchmark content to benchmark-jupyter.ipynb
mmcky Nov 27, 2025
2bda114
Add JSON output to benchmarks and upload as artifacts
mmcky Nov 27, 2025
922b24c
Fix syntax errors in benchmark-hardware.py
mmcky Nov 27, 2025
10627ef
Sync benchmark scripts with CPU branch for comparable results
mmcky Nov 27, 2025
54b2b34
ENH: Force lax.scan sequential operation to run on CPU
mmcky Nov 27, 2025
6bf345a
update note
HumphreyYang Nov 27, 2025
8fbb9a7
Add lax.scan profiler to CI for GPU debugging
mmcky Nov 27, 2025
1bfbaf9
Add diagnostic mode to lax.scan profiler
mmcky Nov 27, 2025
8c32d7c
Add Nsight Systems profiling to CI
mmcky Nov 27, 2025
a623fb7
address @jstac comment
mmcky Nov 27, 2025
1739f51
Improve JAX lecture content and pedagogy
jstac Nov 27, 2025
350da37
Remove benchmark scripts (moved to QuantEcon/benchmarks)
mmcky Nov 27, 2025
e2939c2
Update lectures/numpy_vs_numba_vs_jax.md
mmcky Nov 27, 2025
56047ab
Add GPU and JAX hardware details to status page
mmcky Nov 27, 2025
12 changes: 11 additions & 1 deletion .github/workflows/cache.yml
@@ -6,7 +6,7 @@ on:
workflow_dispatch:
jobs:
cache:
runs-on: ubuntu-latest
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- uses: actions/checkout@v6
- name: Setup Anaconda
@@ -18,6 +18,16 @@ jobs:
python-version: "3.13"
environment-file: environment.yml
activate-environment: quantecon
- name: Install JAX and Numpyro
shell: bash -l {0}
run: |
pip install -U "jax[cuda13]"
pip install numpyro
python scripts/test-jax-install.py
- name: Check nvidia drivers
shell: bash -l {0}
run: |
nvidia-smi
- name: Build HTML
shell: bash -l {0}
run: |
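
This workflow (and the others below) verifies the JAX installation with `scripts/test-jax-install.py`, which is not included in this diff. A minimal sketch of what such a check might do (hypothetical, the actual script may differ) is to import JAX, list the available devices, and run a small computation on the default backend:

```python
# Hypothetical sketch of scripts/test-jax-install.py (the actual script is not
# shown in this diff): confirm JAX imports, report the available devices, and
# run a small computation on the default backend.
import jax
import jax.numpy as jnp

def main():
    print("JAX version:", jax.__version__)
    devices = jax.devices()
    print("Devices:", devices)

    # Tiny computation to confirm the default backend works end to end.
    x = jnp.arange(1_000_000, dtype=jnp.float32)
    result = float(jnp.cos(x).sum().block_until_ready())
    print("Sanity-check result:", result)

    # Warn rather than fail, since CPU-only runs are still valid.
    if not any(d.platform == "gpu" for d in devices):
        print("Warning: no GPU detected; JAX is running on CPU only.")

if __name__ == "__main__":
    main()
```
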
103 changes: 102 additions & 1 deletion .github/workflows/ci.yml
@@ -2,7 +2,7 @@ name: Build Project [using jupyter-book]
on: [pull_request]
jobs:
preview:
runs-on: ubuntu-latest
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- uses: actions/checkout@v6
with:
@@ -16,6 +16,107 @@ jobs:
python-version: "3.13"
environment-file: environment.yml
activate-environment: quantecon
- name: Check nvidia Drivers
shell: bash -l {0}
run: nvidia-smi
- name: Install JAX and Numpyro
shell: bash -l {0}
run: |
pip install -U "jax[cuda13]"
pip install numpyro
python scripts/test-jax-install.py
# === lax.scan GPU Performance Profiling ===
- name: Profile lax.scan (GPU vs CPU)
shell: bash -l {0}
run: |
echo "=== lax.scan Performance Profiling ==="
echo "This profiles the known issue with lax.scan on GPU (JAX Issue #2491)"
echo ""
python scripts/profile_lax_scan.py --iterations 100000 --diagnose
echo ""
echo "The diagnostic shows if time scales linearly with iterations,"
echo "which indicates constant per-iteration CPU-GPU sync overhead."
- name: Nsight Systems Profile (if available)
shell: bash -l {0}
continue-on-error: true
run: |
echo "=== NVIDIA Nsight Systems Profiling ==="
if command -v nsys &> /dev/null; then
echo "nsys found, running profile with 1000 iterations..."
mkdir -p nsight_profiles
nsys profile -o nsight_profiles/lax_scan_trace \
--trace=cuda,nvtx,osrt \
--cuda-memory-usage=true \
--stats=true \
python scripts/profile_lax_scan.py --nsys -n 1000
echo ""
echo "Profile saved to nsight_profiles/lax_scan_trace.nsys-rep"
echo "Download artifact and open in Nsight Systems UI to see CPU-GPU sync pattern"
else
echo "nsys not found, skipping Nsight profiling"
echo "Install NVIDIA Nsight Systems to enable this profiling"
fi
- name: Upload Nsight Profile
uses: actions/upload-artifact@v5
if: success() || failure()
continue-on-error: true
with:
name: nsight-profile
path: nsight_profiles/
if-no-files-found: ignore
# === Benchmark Tests (Bare Metal, Jupyter, Jupyter-Book) ===
- name: Run Hardware Benchmarks (Bare Metal)
shell: bash -l {0}
run: |
echo "=== Bare Metal Python Script Execution ==="
python scripts/benchmark-hardware.py
mkdir -p benchmark_results
mv benchmark_results_bare_metal.json benchmark_results/
- name: Run Jupyter Notebook Benchmark (via nbconvert)
shell: bash -l {0}
run: |
echo "=== Jupyter Kernel Execution ==="
cd scripts
jupyter nbconvert --to notebook --execute benchmark-jupyter.ipynb --output benchmark-jupyter-executed.ipynb
echo "Notebook executed successfully"
cd ..
mv scripts/benchmark_results_jupyter.json benchmark_results/
- name: Run Jupyter-Book Benchmark
shell: bash -l {0}
run: |
echo "=== Jupyter-Book Execution ==="
# Build just the benchmark file using jupyter-book
mkdir -p benchmark_test
cp scripts/benchmark-jupyterbook.md benchmark_test/
# Create minimal _config.yml
echo "title: Benchmark Test" > benchmark_test/_config.yml
echo "execute:" >> benchmark_test/_config.yml
echo " execute_notebooks: force" >> benchmark_test/_config.yml
# Create minimal _toc.yml
echo "format: jb-book" > benchmark_test/_toc.yml
echo "root: benchmark-jupyterbook" >> benchmark_test/_toc.yml
# Build (run from benchmark_test so JSON is written there)
cd benchmark_test
jb build . --path-output ../benchmark_build/
cd ..
echo "Jupyter-Book build completed successfully"
# Move JSON results if generated
cp benchmark_test/benchmark_results_jupyterbook.json benchmark_results/ 2>/dev/null || echo "No jupyterbook results"
- name: Collect and Display Benchmark Results
shell: bash -l {0}
run: |
echo "=== Benchmark Results Summary ==="
for f in benchmark_results/*.json; do
echo "--- $f ---"
cat "$f"
echo ""
done
- name: Upload Benchmark Results
uses: actions/upload-artifact@v5
with:
name: benchmark-results
path: benchmark_results/
if-no-files-found: warn
- name: Install latex dependencies
run: |
sudo apt-get -qq update
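
The `Profile lax.scan (GPU vs CPU)` step above calls `scripts/profile_lax_scan.py`, which is also not part of this diff. The diagnostic it describes, checking whether runtime grows linearly with the iteration count, can be sketched as follows (hypothetical; the real script's `--diagnose` and `--nsys` modes will do more than this):

```python
# Hypothetical sketch of the diagnostic idea behind scripts/profile_lax_scan.py
# (the real script is not shown in this diff): time a tiny sequential lax.scan at
# several iteration counts and check whether runtime grows roughly linearly.
# Roughly constant time per iteration suggests per-iteration overhead
# (kernel launches / CPU-GPU synchronization) dominates on the GPU.
import time
import jax
from jax import lax

def scan_loop(n_iter, x0=0.5, alpha=4.0):
    def update(x, _):
        x_new = alpha * x * (1 - x)
        return x_new, x_new
    _, xs = lax.scan(update, x0, None, length=n_iter)
    return xs

scan_jit = jax.jit(scan_loop, static_argnums=(0,))

for n in (1_000, 10_000, 100_000):
    scan_jit(n).block_until_ready()                 # compile for this length
    start = time.perf_counter()
    scan_jit(n).block_until_ready()
    elapsed = time.perf_counter() - start
    print(f"{n:>7} iterations: {elapsed:.4f}s ({elapsed / n * 1e6:.2f} µs per iteration)")
```
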
12 changes: 11 additions & 1 deletion .github/workflows/publish.yml
@@ -6,7 +6,7 @@ on:
jobs:
publish:
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
runs-on: ubuntu-latest
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- name: Checkout
uses: actions/checkout@v6
@@ -21,6 +21,16 @@ jobs:
python-version: "3.13"
environment-file: environment.yml
activate-environment: quantecon
- name: Install JAX and Numpyro
shell: bash -l {0}
run: |
pip install -U "jax[cuda13]"
pip install numpyro
python scripts/test-jax-install.py
- name: Check nvidia drivers
shell: bash -l {0}
run: |
nvidia-smi
- name: Install latex dependencies
run: |
sudo apt-get -qq update
35 changes: 15 additions & 20 deletions lectures/jax_intro.md
@@ -21,20 +21,23 @@ In addition to what's in Anaconda, this lecture will need the following libraries
!pip install jax quantecon
```

This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
```{admonition} GPU
:class: warning

This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and we target JAX for GPU programming.

Here we are focused on using JAX on the CPU, rather than on accelerators such as
GPUs or TPUs.
Free GPUs are available on Google Colab.
To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.

This means we will only see a small amount of the possible benefits from using JAX.
Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
If you would like to install JAX to run on the CPU only, you can use `pip install jax[cpu]`.
```

However, JAX seamlessly handles transitions across different hardware platforms.
This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).

As a result, if you run this code on a machine with a GPU and a GPU-aware
version of JAX installed, your code will be automatically accelerated and you
will receive the full benefits.
JAX provides a NumPy-like interface that can leverage GPU acceleration for high-performance numerical computing.
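
As a quick illustration of that NumPy-like interface (a minimal sketch, not one of the lecture's own code cells), array operations mirror NumPy while `jax.devices()` reports which backend they run on:

```python
# A minimal sketch, not one of the lecture's own code cells: the jax.numpy API
# mirrors NumPy, and jax.devices() shows which backend arrays are placed on.
import jax
import jax.numpy as jnp

print(jax.devices())        # e.g. [CudaDevice(id=0)] with a GPU, [CpuDevice(id=0)] without

a = jnp.linspace(0, 1, 5)
b = jnp.ones_like(a)
print(a + b)                # elementwise operations work as in NumPy
print(jnp.dot(a, b))        # so do reductions and linear algebra
```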

For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
For a more comprehensive discussion of JAX, see [our JAX lecture series](https://jax.quantecon.org/intro.html).


## JAX as a NumPy Replacement
Expand Down Expand Up @@ -523,16 +526,9 @@ with qe.Timer():
jax.block_until_ready(y);
```

If you are running this on a GPU the code will run much faster than its NumPy
equivalent, which ran on the CPU.

Even if you are running on a machine with many CPUs, the second JAX run should
be substantially faster with JAX.

Also, typically, the second run is faster than the first.
On a GPU, this code runs much faster than its NumPy equivalent.

(This might not be noticeable on the CPU but it should definitely be noticeable on
the GPU.)
Also, typically, the second run is faster than the first due to JIT compilation.

This is because even built-in functions like `jnp.cos` are JIT-compiled --- and the
first run includes compile time.
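
A small self-contained sketch (using `time.perf_counter` rather than the lecture's `qe.Timer`, and an illustrative array size) makes the compile-then-cache pattern explicit:

```python
# A self-contained sketch of the compile-then-cache pattern: the first call pays
# tracing/compilation overhead, later calls reuse the compiled kernel.
# The array size here is illustrative.
import time
import jax
import jax.numpy as jnp

x = jnp.linspace(0, 10, 10_000_000)

start = time.perf_counter()
jax.block_until_ready(jnp.cos(x))
print(f"first call:  {time.perf_counter() - start:.4f}s (includes compilation)")

start = time.perf_counter()
jax.block_until_ready(jnp.cos(x))
print(f"second call: {time.perf_counter() - start:.4f}s (compiled code is cached)")
```
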
@@ -634,8 +630,7 @@ with qe.Timer():
jax.block_until_ready(y);
```

The outcome is similar to the `cos` example --- JAX is faster, especially if you
use a GPU and especially on the second run.
The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.

Moreover, with JAX, we have another trick up our sleeve:

28 changes: 26 additions & 2 deletions lectures/numpy_vs_numba_vs_jax.md
@@ -48,6 +48,18 @@ tags: [hide-output]
!pip install quantecon jax
```

```{admonition} GPU
:class: warning

This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and we target JAX for GPU programming.

Free GPUs are available on Google Colab.
To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.

Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
If you would like to install JAX to run on the CPU only, you can use `pip install jax[cpu]`.
```

We will use the following imports.

```{code-cell} ipython3
@@ -317,7 +329,7 @@ with qe.Timer(precision=8):
z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
```

Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.

The compilation overhead is a one-time cost that pays off when the function is called repeatedly.

@@ -497,7 +509,9 @@ Now let's create a JAX version using `lax.scan`:
from jax import lax
from functools import partial

@partial(jax.jit, static_argnums=(1,))
cpu = jax.devices("cpu")[0]

@partial(jax.jit, static_argnums=(1,), device=cpu)
def qm_jax(x0, n, α=4.0):
def update(x, t):
x_new = α * x * (1 - x)
@@ -509,6 +523,16 @@ def qm_jax(x0, n, α=4.0):

This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returned values `x_new` into an array.

```{note}
Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.

The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.

As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.

Curious readers can try removing this option to see how performance changes.
```
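
For readers who want to run that comparison, here is a minimal sketch; the update rule, the length `n`, and the `make_qm` helper are illustrative rather than taken from the lecture, but the `jax.jit(..., device=...)` pattern mirrors the decorator above:

```python
# A sketch of the experiment suggested in the note above: compile the same small,
# strictly sequential scan once for the CPU and once for the default device
# (a GPU on the build machine), then time both.  The update rule, length n,
# and the make_qm helper are illustrative, not taken from the lecture.
import jax
import quantecon as qe
from functools import partial
from jax import lax

cpu = jax.devices("cpu")[0]

def make_qm(device=None):
    @partial(jax.jit, static_argnums=(1,), device=device)
    def qm(x0, n, α=4.0):
        def update(x, t):
            x_new = α * x * (1 - x)
            return x_new, x_new
        _, x_path = lax.scan(update, x0, None, length=n)
        return x_path
    return qm

qm_cpu, qm_default = make_qm(cpu), make_qm()

n = 1_000_000
for name, fn in [("CPU", qm_cpu), ("default device", qm_default)]:
    fn(0.1, n).block_until_ready()        # first call: compile
    with qe.Timer():                      # second call: timed execution
        fn(0.1, n).block_until_ready()
    print(f"(timed above: {name})")
```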

Let's time it with the same parameters:

```{code-cell} ipython3