Commit 1739f51

jstac and claude committed
Improve JAX lecture content and pedagogy
- Reorganize jax_intro.md to introduce JAX features upfront with clearer structure
- Expand JAX introduction with bulleted list of key capabilities (parallelization, JIT, autodiff)
- Add explicit GPU performance notes in vmap sections
- Enhance vmap explanation with detailed function composition breakdown
- Clarify memory efficiency tradeoffs between different vmap approaches

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent a623fb7 commit 1739f51

File tree

2 files changed: +37 -19 lines changed


lectures/jax_intro.md

Lines changed: 15 additions & 8 deletions
@@ -13,6 +13,18 @@ kernelspec:

 # JAX

+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+
+JAX is a high-performance scientific computing library that provides
+
+* a NumPy-like interface that can automatically parallelize across CPUs and GPUs,
+* a just-in-time compiler for accelerating a large range of numerical
+  operations, and
+* automatic differentiation.
+
+Increasingly, JAX also maintains and provides more specialized scientific
+computing routines, such as those originally found in SciPy.
+
 In addition to what's in Anaconda, this lecture will need the following libraries:

 ```{code-cell} ipython3
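As a quick orientation for readers of this commit, here is a minimal sketch of the three capabilities named in the added list above. It is not part of the lecture; the names (`poly`, `dpoly`) and sample values are purely illustrative.

```python
import jax
import jax.numpy as jnp

# NumPy-like interface: familiar array creation and reductions
x = jnp.linspace(0.0, 1.0, 5)
print(jnp.sum(x ** 2))

# Just-in-time compilation: compile once, then reuse the compiled function
@jax.jit
def poly(z):
    return 3.0 * z ** 2 + 2.0 * z + 1.0

print(poly(x))

# Automatic differentiation: gradient of a scalar-valued function
dpoly = jax.grad(lambda z: 3.0 * z ** 2 + 2.0 * z + 1.0)
print(dpoly(0.5))   # 3 * 2 * 0.5 + 2 = 5.0
```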
@@ -33,17 +45,12 @@ Alternatively, if you have your own GPU, you can follow the [instructions](https
 If you would like to install JAX running on the `cpu` only you can use `pip install jax[cpu]`
 ```

-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
-
-JAX provides a NumPy-like interface that can leverage GPU acceleration for high-performance numerical computing.
-
-
 ## JAX as a NumPy Replacement

-One of the attractive features of JAX is that, whenever possible, it conforms to
-the NumPy API for array operations.
+One of the attractive features of JAX is that, whenever possible, its array
+processing operations conform to the NumPy API.

-This means that, to a large extent, we can use JAX is as a drop-in NumPy replacement.
+This means that, in many cases, we can use JAX as a drop-in NumPy replacement.

 Let's look at the similarities and differences between JAX and NumPy.

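The "drop-in replacement" wording in the added line can be illustrated with a short hedged sketch: the same function body runs against either array module. The helper `mean_square` is hypothetical and not taken from the lecture.

```python
import numpy as np
import jax.numpy as jnp

def mean_square(xp, n=10):
    # `xp` is whichever array module the caller passes in
    v = xp.arange(n)
    return xp.mean(v ** 2)

print(mean_square(np))    # NumPy backend
print(mean_square(jnp))   # JAX backend, same calls
```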
lectures/numpy_vs_numba_vs_jax.md

Lines changed: 22 additions & 11 deletions
@@ -382,23 +382,29 @@ with qe.Timer(precision=8):
     z_max.block_until_ready()
 ```

-The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
-we are using far less memory.
+By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.

-In addition, `vmap` allows us to break vectorization up into stages, which is
-often easier to comprehend than the traditional approach.
+When run on a CPU, its runtime is similar to that of the meshgrid version.

-This will become more obvious when we tackle larger problems.
+When run on a GPU, it is usually significantly faster.
+
+In fact, using `vmap` has another advantage: It allows us to break vectorization up into stages.
+
+This leads to code that is often easier to comprehend than traditional vectorized code.
+
+We will investigate these ideas more when we tackle larger problems.


 ### vmap version 2

 We can be still more memory efficient using vmap.

-While we avoided large input arrays in the preceding version,
+While we avoid large input arrays in the preceding version,
 we still create the large output array `f(x,y)` before we compute the max.

-Let's use a slightly different approach that takes the max to the inside.
+Let's try a slightly different approach that takes the max to the inside.
+
+Because of this change, we never compute the two-dimensional array `f(x,y)`.

 ```{code-cell} ipython3
 @jax.jit
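To make the memory point in the hunk above concrete, here is a hedged sketch comparing a meshgrid evaluation with a staged `vmap` evaluation. The function `f` and the `grid` below are stand-ins; the lecture's actual definitions are not shown in this diff.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # stand-in for the lecture's bivariate function
    return jnp.cos(x ** 2 + y ** 2) / (1 + x ** 2 + y ** 2)

grid = jnp.linspace(-3, 3, 1000)   # stand-in grid

# Meshgrid version: materializes two 1000-by-1000 input arrays
x_mesh, y_mesh = jnp.meshgrid(grid, grid)
z_mesh_max = jnp.max(f(x_mesh, y_mesh))

# Staged vmap version: vectorize over y for fixed x, then over x -- no meshgrid needed
f_vec_y = jax.vmap(f, in_axes=(None, 0))       # maps over y for a single x
f_vec = jax.vmap(f_vec_y, in_axes=(0, None))   # maps the above over x
z_vmap_max = jnp.max(f_vec(grid, grid))

print(jnp.allclose(z_mesh_max, z_vmap_max))
```

Note that this staged version still produces the full 1000-by-1000 output array before the final `max`, which is exactly the remaining inefficiency that the "vmap version 2" discussion above targets.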
@@ -411,23 +417,28 @@ def compute_max_vmap_v2(grid):
     return jnp.max(f_vec_max(grid))
 ```

-Let's try it
+Here
+
+* `f_vec_x_max` computes the max along any given row
+* `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.
+
+We apply this function to all rows and then take the max of the row maxes.
+
+Let's try it.

 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```

-
 Let's run it again to eliminate compilation time:

 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```

-We don't get much speed gain but we do save some memory.
-
+If you are running this on a GPU, as we are, you should see another nontrivial speed gain.


 ### Summary
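Since the body of `compute_max_vmap_v2` is collapsed in this diff, the following is a hedged reconstruction of the "max to the inside" idea described by the added bullets. The helper names follow those bullets (`f_vec_x_max`, `f_vec_max`), but `f` and `grid` are stand-ins and the lecture's exact code may differ.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # stand-in for the lecture's bivariate function
    return jnp.cos(x ** 2 + y ** 2) / (1 + x ** 2 + y ** 2)

grid = jnp.linspace(-3, 3, 1000)   # stand-in grid

@jax.jit
def compute_max_vmap_v2(grid):
    # max of f(x, y) over all y in grid, for one fixed x (one "row")
    def f_vec_x_max(x):
        return jnp.max(jax.vmap(f, in_axes=(None, 0))(x, grid))
    # vectorized over x: one scalar row-max per grid point, so no explicit
    # two-dimensional array f(x, y) is ever constructed in this code
    f_vec_max = jax.vmap(f_vec_x_max)
    return jnp.max(f_vec_max(grid))

z_max = compute_max_vmap_v2(grid).block_until_ready()
print(z_max)
```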
