
Commit 849cf16

DOC: Update JAX lectures with GPU admonition and narrative
- Add standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
- Update introduction in jax_intro.md to reflect GPU access
- Update conditional GPU language to reflect lectures now run on GPU
- Following QuantEcon style guide for JAX lectures

1 parent: 1671eb7

File tree

2 files changed: +28 -21 lines

lectures/jax_intro.md

Lines changed: 15 additions & 20 deletions
````diff
@@ -21,20 +21,23 @@ In addition to what's in Anaconda, this lecture will need the following libraries
 !pip install jax quantecon
 ```
 
-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and it targets JAX for GPU programming.
 
-Here we are focused on using JAX on the CPU, rather than on accelerators such as
-GPUs or TPUs.
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
 
-This means we will only see a small amount of the possible benefits from using JAX.
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```
 
-However, JAX seamlessly handles transitions across different hardware platforms.
+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
 
-As a result, if you run this code on a machine with a GPU and a GPU-aware
-version of JAX installed, your code will be automatically accelerated and you
-will receive the full benefits.
+JAX provides a NumPy-like interface that can leverage GPU acceleration for high-performance numerical computing.
 
-For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
+For a more comprehensive discussion of JAX, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
 
 
 ## JAX as a NumPy Replacement
````
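An editorial aside on the new admonition: the runtime it lands on can be verified directly. The snippet below is an illustrative sketch, not part of the commit, using only standard JAX calls (`jax.default_backend`, `jax.devices`):

```python
import jax

# Reports "gpu" on a Colab GPU runtime, "cpu" after `pip install jax[cpu]`.
print(jax.default_backend())

# Lists the devices JAX can see on this machine.
print(jax.devices())
```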
````diff
@@ -523,16 +526,9 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-If you are running this on a GPU the code will run much faster than its NumPy
-equivalent, which ran on the CPU.
-
-Even if you are running on a machine with many CPUs, the second JAX run should
-be substantially faster with JAX.
-
-Also, typically, the second run is faster than the first.
+On a GPU, this code runs much faster than its NumPy equivalent.
 
-(This might not be noticable on the CPU but it should definitely be noticable on
-the GPU.)
+Also, typically, the second run is faster than the first due to JIT compilation.
 
 This is because even built in functions like `jnp.cos` are JIT-compiled --- and the
 first run includes compile time.
````
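The claim in this hunk, that the first run of `jnp.cos` pays a one-time compile cost, can be checked with a sketch along these lines (illustrative, not part of the commit; the array size is an arbitrary choice):

```python
import jax
import jax.numpy as jnp
import quantecon as qe

x = jnp.linspace(0, 10, 10_000_000)

# First call: includes JIT compile time for jnp.cos at this input shape.
with qe.Timer():
    jax.block_until_ready(jnp.cos(x))

# Second call: reuses the compiled kernel, so it is typically much faster.
with qe.Timer():
    jax.block_until_ready(jnp.cos(x))
```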
````diff
@@ -634,8 +630,7 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially if you
-use a GPU and especially on the second run.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
 
 Moreover, with JAX, we have another trick up our sleeve:
````

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 13 additions & 1 deletion
````diff
@@ -48,6 +48,18 @@ tags: [hide-output]
 !pip install quantecon jax
 ```
 
+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and it targets JAX for GPU programming.
+
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
+
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```
+
 We will use the following imports.
 
 ```{code-cell} ipython3
````
````diff
@@ -317,7 +329,7 @@ with qe.Timer(precision=8):
     z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
-Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
+Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.
 
 The compilation overhead is a one-time cost that pays off when the function is called repeatedly.
````
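Editorial note on the one-time compilation cost mentioned above: a minimal sketch of the pattern, assuming a hypothetical `f` in place of the lecture's function (the grid size and formula here are invented for illustration):

```python
import jax
import jax.numpy as jnp
import quantecon as qe

# Hypothetical stand-in for the lecture's f(x, y).
@jax.jit
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1.0 + x**2 + y**2)

grid = jnp.linspace(-3.0, 3.0, 5_000)
x_mesh, y_mesh = jnp.meshgrid(grid, grid)

# First call pays the one-time compilation cost.
with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()

# Later calls reuse the compiled code and run at full speed.
with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
```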
