@@ -21,20 +21,23 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 !pip install jax quantecon
 ```
 
-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and it targets JAX for GPU programming.
 
-Here we are focused on using JAX on the CPU, rather than on accelerators such as
-GPUs or TPUs.
+Free GPUs are available on Google Colab.
+To use this option, click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.
 
-This means we will only see a small amount of the possible benefits from using JAX.
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX to run on the CPU only, you can use `pip install jax[cpu]`.
+```
 
-However, JAX seamlessly handles transitions across different hardware platforms.
+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
 
-As a result, if you run this code on a machine with a GPU and a GPU-aware
-version of JAX installed, your code will be automatically accelerated and you
-will receive the full benefits.
+JAX provides a NumPy-like interface that can leverage GPU acceleration for high-performance numerical computing.
 
-For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
+For a more comprehensive discussion of JAX, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
 
 
 ## JAX as a NumPy Replacement
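
The admonition added in the hunk above covers both GPU and CPU-only installs. A minimal way to confirm which backend a given install actually exposes (a sketch, not part of the lecture source; `jax.devices()` and `jax.default_backend()` are standard JAX calls, and the sample outputs shown in the comments vary by JAX version and hardware):

```python
# List the devices JAX can see after installation.  On a GPU machine
# this typically shows a CUDA device; on a CPU-only install it shows
# a CPU device.
import jax

print(jax.devices())          # e.g. [CudaDevice(id=0)] or [CpuDevice(id=0)]
print(jax.default_backend())  # e.g. 'gpu' or 'cpu'
```
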
@@ -523,16 +526,9 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-If you are running this on a GPU the code will run much faster than its NumPy
-equivalent, which ran on the CPU.
-
-Even if you are running on a machine with many CPUs, the second JAX run should
-be substantially faster with JAX.
-
-Also, typically, the second run is faster than the first.
+On a GPU, this code runs much faster than its NumPy equivalent.
 
-(This might not be noticable on the CPU but it should definitely be noticable on
-the GPU.)
+Also, typically, the second run is faster than the first due to JIT compilation.
 
 This is because even built-in functions like `jnp.cos` are JIT-compiled --- and the
 first run includes compile time.
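
For readers skimming the hunk above: the pattern being timed calls `jnp.cos` twice, with `jax.block_until_ready` forcing JAX's asynchronous dispatch to finish before the timer stops, so the first run absorbs the one-off compile cost and the second reuses the compiled kernel. A self-contained sketch of that pattern, assuming the lecture's `with qe.Timer():` construct and an illustrative array `x` of our own choosing (the lecture's actual `x` is not shown in this diff):

```python
# First timed run pays JIT compile time for jnp.cos; the second run
# reuses the compiled kernel.  block_until_ready() makes the timings
# honest by waiting for JAX's async dispatch to complete.
import jax
import jax.numpy as jnp
import quantecon as qe

x = jnp.linspace(0.0, 10.0, 10_000_000)  # illustrative size, not from the lecture

with qe.Timer():                # first run: includes compilation
    y = jnp.cos(x)
    jax.block_until_ready(y)

with qe.Timer():                # second run: compiled code only
    y = jnp.cos(x)
    jax.block_until_ready(y)
```
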
@@ -634,8 +630,7 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially if you
-use a GPU and especially on the second run.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
 
 Moreover, with JAX, we have another trick up our sleeve:
 
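
The "trick" the lecture goes on to describe is not shown in this diff. As a hedged illustration only, a common next step with JAX is compiling a whole function with `jax.jit` rather than relying on per-operation compilation; the function `f` below is hypothetical, not taken from the lecture:

```python
# Sketch: explicitly JIT-compile a user-defined function with jax.jit.
# The first call compiles; later calls with same-shaped inputs reuse
# the compiled code.
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative elementwise computation, not from the lecture.
    return jnp.cos(2 * x) + jnp.sin(x / 2)

f_jit = jax.jit(f)

x = jnp.linspace(0.0, 10.0, 1_000_000)
y = jax.block_until_ready(f_jit(x))   # first call: compiles, then runs
y = jax.block_until_ready(f_jit(x))   # second call: compiled code only
```
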