File tree Expand file tree Collapse file tree 5 files changed +59
-3
lines changed
Expand file tree Collapse file tree 5 files changed +59
-3
lines changed Original file line number Diff line number Diff line change 1+ images :
2+ quantecon_ubuntu2404 :
3+ platform : " linux"
4+ arch : " x64"
5+ ami : " ami-0edec81935264b6d3"
6+ region : " us-west-2"
Original file line number Diff line number Diff line change 66 workflow_dispatch :
77jobs :
88 cache :
9- runs-on : ubuntu-latest
9+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large "
1010 steps :
1111 - uses : actions/checkout@v6
1212 - name : Setup Anaconda
1818 python-version : " 3.13"
1919 environment-file : environment.yml
2020 activate-environment : quantecon
21+ - name : Install JAX and Numpyro
22+ shell : bash -l {0}
23+ run : |
24+ pip install -U "jax[cuda13]"
25+ pip install numpyro
26+ python scripts/test-jax-install.py
27+ - name : Check nvidia drivers
28+ shell : bash -l {0}
29+ run : |
30+ nvidia-smi
2131 - name : Build HTML
2232 shell : bash -l {0}
2333 run : |
Original file line number Diff line number Diff line change @@ -2,7 +2,7 @@ name: Build Project [using jupyter-book]
22on : [pull_request]
33jobs :
44 preview :
5- runs-on : ubuntu-latest
5+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large "
66 steps :
77 - uses : actions/checkout@v6
88 with :
1616 python-version : " 3.13"
1717 environment-file : environment.yml
1818 activate-environment : quantecon
19+ - name : Check nvidia Drivers
20+ shell : bash -l {0}
21+ run : nvidia-smi
22+ - name : Install JAX and Numpyro
23+ shell : bash -l {0}
24+ run : |
25+ pip install -U "jax[cuda13]"
26+ pip install numpyro
27+ python scripts/test-jax-install.py
1928 - name : Install latex dependencies
2029 run : |
2130 sudo apt-get -qq update
Original file line number Diff line number Diff line change 66jobs :
77 publish :
88 if : github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
9- runs-on : ubuntu-latest
9+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large "
1010 steps :
1111 - name : Checkout
1212 uses : actions/checkout@v6
2121 python-version : " 3.13"
2222 environment-file : environment.yml
2323 activate-environment : quantecon
24+ - name : Install JAX and Numpyro
25+ shell : bash -l {0}
26+ run : |
27+ pip install -U "jax[cuda13]"
28+ pip install numpyro
29+ python scripts/test-jax-install.py
30+ - name : Check nvidia drivers
31+ shell : bash -l {0}
32+ run : |
33+ nvidia-smi
2434 - name : Install latex dependencies
2535 run : |
2636 sudo apt-get -qq update
Original file line number Diff line number Diff line change 1+ import jax
2+ import jax .numpy as jnp
3+
4+ devices = jax .devices ()
5+ print (f"The available devices are: { devices } " )
6+
7+ @jax .jit
8+ def matrix_multiply (a , b ):
9+ return jnp .dot (a , b )
10+
11+ # Example usage:
12+ key = jax .random .PRNGKey (0 )
13+ x = jax .random .normal (key , (1000 , 1000 ))
14+ y = jax .random .normal (key , (1000 , 1000 ))
15+ z = matrix_multiply (x , y )
16+
17+ # Now the function is JIT compiled and will likely run on GPU (if available)
18+ print (z )
19+
20+ devices = jax .devices ()
21+ print (f"The available devices are: { devices } " )
You can’t perform that action at this time.
0 commit comments