Skip to content

Commit 2e1684b

Browse files
committed
Enable RunsOn GPU support for lecture builds
- Add scripts/test-jax-install.py to verify JAX/GPU installation - Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration - Update cache.yml to use RunsOn g4dn.2xlarge GPU runner - Update ci.yml to use RunsOn g4dn.2xlarge GPU runner - Update publish.yml to use RunsOn g4dn.2xlarge GPU runner - Install JAX with CUDA 13 support and Numpyro on all workflows - Add nvidia-smi check to verify GPU availability This mirrors the setup used in lecture-python.myst repository.
1 parent ce3de00 commit 2e1684b

File tree

5 files changed

+59
-3
lines changed

5 files changed

+59
-3
lines changed

.github/runs-on.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
images:
2+
quantecon_ubuntu2404:
3+
platform: "linux"
4+
arch: "x64"
5+
ami: "ami-0edec81935264b6d3"
6+
region: "us-west-2"

.github/workflows/cache.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
workflow_dispatch:
77
jobs:
88
cache:
9-
runs-on: ubuntu-latest
9+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
1010
steps:
1111
- uses: actions/checkout@v6
1212
- name: Setup Anaconda
@@ -18,6 +18,16 @@ jobs:
1818
python-version: "3.13"
1919
environment-file: environment.yml
2020
activate-environment: quantecon
21+
- name: Install JAX and Numpyro
22+
shell: bash -l {0}
23+
run: |
24+
pip install -U "jax[cuda13]"
25+
pip install numpyro
26+
python scripts/test-jax-install.py
27+
- name: Check nvidia drivers
28+
shell: bash -l {0}
29+
run: |
30+
nvidia-smi
2131
- name: Build HTML
2232
shell: bash -l {0}
2333
run: |

.github/workflows/ci.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: Build Project [using jupyter-book]
22
on: [pull_request]
33
jobs:
44
preview:
5-
runs-on: ubuntu-latest
5+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
66
steps:
77
- uses: actions/checkout@v6
88
with:
@@ -16,6 +16,15 @@ jobs:
1616
python-version: "3.13"
1717
environment-file: environment.yml
1818
activate-environment: quantecon
19+
- name: Check nvidia Drivers
20+
shell: bash -l {0}
21+
run: nvidia-smi
22+
- name: Install JAX and Numpyro
23+
shell: bash -l {0}
24+
run: |
25+
pip install -U "jax[cuda13]"
26+
pip install numpyro
27+
python scripts/test-jax-install.py
1928
- name: Install latex dependencies
2029
run: |
2130
sudo apt-get -qq update

.github/workflows/publish.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
jobs:
77
publish:
88
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
9-
runs-on: ubuntu-latest
9+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
1010
steps:
1111
- name: Checkout
1212
uses: actions/checkout@v6
@@ -21,6 +21,16 @@ jobs:
2121
python-version: "3.13"
2222
environment-file: environment.yml
2323
activate-environment: quantecon
24+
- name: Install JAX and Numpyro
25+
shell: bash -l {0}
26+
run: |
27+
pip install -U "jax[cuda13]"
28+
pip install numpyro
29+
python scripts/test-jax-install.py
30+
- name: Check nvidia drivers
31+
shell: bash -l {0}
32+
run: |
33+
nvidia-smi
2434
- name: Install latex dependencies
2535
run: |
2636
sudo apt-get -qq update

scripts/test-jax-install.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
devices = jax.devices()
5+
print(f"The available devices are: {devices}")
6+
7+
@jax.jit
8+
def matrix_multiply(a, b):
9+
return jnp.dot(a, b)
10+
11+
# Example usage:
12+
key = jax.random.PRNGKey(0)
13+
x = jax.random.normal(key, (1000, 1000))
14+
y = jax.random.normal(key, (1000, 1000))
15+
z = matrix_multiply(x, y)
16+
17+
# Now the function is JIT compiled and will likely run on GPU (if available)
18+
print(z)
19+
20+
devices = jax.devices()
21+
print(f"The available devices are: {devices}")

0 commit comments

Comments
 (0)