
Commit cded2f8

jstac and claude committed
Improve formatting and clarity across parallel computing lectures
- Standardize header capitalization in need_for_speed.md
- Update code cell types to ipython3 in numba.md for consistency
- Remove redundant parallelization warning section in numba.md
- Enhance explanatory text and code clarity in numpy_vs_numba_vs_jax.md
- Fix formatting and add missing validation checks

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 0678fc3 commit cded2f8

File tree

3 files changed (+40, -40 lines)


lectures/need_for_speed.md

Lines changed: 5 additions & 5 deletions
@@ -153,12 +153,12 @@ On the other hand, the standard implementation of Python (called CPython) cannot
 match the speed of compiled languages such as C or Fortran.
 
 
-### Where are the Bottlenecks?
+### Where are the bottlenecks?
 
 Why is this the case?
 
 
-#### Dynamic Typing
+#### Dynamic typing
 
 ```{index} single: Dynamic Typing
 ```
@@ -200,7 +200,7 @@ If we repeatedly execute this expression in a tight loop, the nontrivial
 overhead becomes a large overhead.
 
 
-#### Static Types
+#### Static types
 
 ```{index} single: Static Types
 ```
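For context on the two hunks above: the surrounding lecture text argues that an expression like `a + b` must be re-interpreted on every evaluation, because operand types are only known at runtime. A minimal illustration of that point (not part of this commit):

```python
# The same expression dispatches on the runtime types of its operands,
# so the interpreter must re-check types on every evaluation.
a, b = 10, 10
print(a + b)          # 20: integer addition
a, b = "foo", "bar"
print(a + b)          # 'foobar': string concatenation
a, b = [10], [20]
print(a + b)          # [10, 20]: list concatenation
```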
@@ -250,7 +250,7 @@ Such an array is stored in a single contiguous block of memory
 
 * In modern computers, memory addresses are allocated to each byte (one byte = 8 bits).
 * For example, a 64 bit integer is stored in 8 bytes of memory.
-* An array of $n$ such integers occupies $8n$ **consecutive** memory slots.
+* An array of $n$ such integers occupies $8n$ *consecutive* memory slots.
 
 Moreover, the compiler is made aware of the data type by the programmer.
 
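The *consecutive* storage claim edited above is easy to verify with NumPy; a quick check (illustrative, not part of this commit):

```python
import numpy as np

n = 1_000
a = np.zeros(n, dtype=np.int64)   # n 64-bit integers
print(a.nbytes)                   # 8000: 8n bytes in total
print(a.flags['C_CONTIGUOUS'])    # True: one contiguous block of memory
```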

@@ -336,7 +336,7 @@ for this purpose and supplied to users as part of a package.
 
 The core benefits are
 
-1. type-checking is paid per array, rather than per element, and
+1. type-checking is paid *per array*, rather than per element, and
 1. arrays containing elements with the same data type are efficient in terms of
    memory access.
 

lectures/numba.md

Lines changed: 7 additions & 20 deletions
@@ -475,7 +475,7 @@ distribution.
 
 Here's the code:
 
-```{code-cell} ipython
+```{code-cell} ipython3
 from numpy.random import randn
 from numba import njit
@@ -496,7 +496,7 @@ def h(w, r=0.1, s=0.3, v1=0.1, v2=1.0):
 
 Let's have a look at how wealth evolves under this rule.
 
-```{code-cell} ipython
+```{code-cell} ipython3
 fig, ax = plt.subplots()
 
 T = 100
@@ -540,7 +540,7 @@ Then we'll calculate median wealth at the end period.
 
 Here's the code:
 
-```{code-cell} ipython
+```{code-cell} ipython3
 @njit
 def compute_long_run_median(w0=1, T=1000, num_reps=50_000):
@@ -556,7 +556,7 @@ def compute_long_run_median(w0=1, T=1000, num_reps=50_000):
 
 Let's see how fast this runs:
 
-```{code-cell} ipython
+```{code-cell} ipython3
 with qe.Timer():
     compute_long_run_median()
 ```
@@ -565,7 +565,7 @@ To speed this up, we're going to parallelize it via multithreading.
 
 To do so, we add the `parallel=True` flag and change `range` to `prange`:
 
-```{code-cell} ipython
+```{code-cell} ipython3
 from numba import prange
 
 @njit(parallel=True)
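For reference, the flag-plus-`prange` pattern this hunk introduces has the following overall shape; the function below is an illustrative sketch, not the lecture's `compute_long_run_median_parallel`:

```python
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_mc_sketch(num_reps=50_000, T=1_000):
    out = np.empty(num_reps)
    for i in prange(num_reps):     # replications are independent: parallel
        x = 0.0
        for t in range(T):         # each step depends on the last: sequential
            x = 0.9 * x + np.random.randn()
        out[i] = x
    return np.median(out)
```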
@@ -583,26 +583,13 @@ def compute_long_run_median_parallel(w0=1, T=1000, num_reps=50_000):
 
 Let's look at the timing:
 
-```{code-cell} ipython
+```{code-cell} ipython3
 with qe.Timer():
     compute_long_run_median_parallel()
 ```
 
 The speed-up is significant.
 
-### A Warning
-
-Parallelization works well in the outer loop of the last example because the individual tasks inside the loop are independent of each other.
-
-If this independence fails then parallelization is often problematic.
-
-For example, each step inside the inner loop depends on the last step, so
-independence fails, and this is why we use ordinary `range` instead of `prange`.
-
-When you see us using `prange` in later lectures, it is because the
-independence of tasks holds true.
-
-Conversely, when you see us using ordinary `range` in a jitted function, it is either because the speed gain from parallelization is small or because independence fails.
 
 ## Exercises
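For reference, the independence condition described in the removed warning can be illustrated as follows (hypothetical example, not lecture code): the recursion below must keep ordinary `range`, because iteration `t` reads the result of iteration `t-1`.

```python
from numba import njit

@njit
def compound_sketch(x0=1.0, g=1.01, T=1_000):
    x = x0
    for t in range(T):   # NOT independent: do not replace with prange
        x = g * x        # depends on the previous iteration's x
    return x
```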

@@ -807,7 +794,7 @@ For the size of the Monte Carlo simulation, use something substantial, such as
 
 Here is one solution:
 
-```{code-cell} python3
+```{code-cell} ipython3
 from random import uniform
 
 @njit(parallel=True)
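The solution hunk above is truncated. A sketch of the standard parallel Monte Carlo pattern it gestures at, assuming the exercise targets a pi estimate (the function name and body are our assumptions, not the lecture's verbatim solution):

```python
from random import uniform
from numba import njit, prange

@njit(parallel=True)
def estimate_pi_sketch(n=10_000_000):
    hits = 0
    for i in prange(n):                   # draws are independent: parallel
        u, v = uniform(0, 1), uniform(0, 1)
        if u**2 + v**2 < 1:               # point lands in the quarter circle
            hits += 1                     # Numba treats this as a reduction
    return 4 * hits / n
```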

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 28 additions & 15 deletions
@@ -321,11 +321,12 @@ x_mesh.nbytes + y_mesh.nbytes
 
 This extra memory usage can be a big problem in actual research calculations.
 
-Fortunately, JAX admits a different approach using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html)
+Fortunately, JAX admits a different approach
+using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).
 
 #### Version 1
 
-Here's one way we can do this
+Here's one way we can apply `vmap`.
 
 ```{code-cell} ipython3
 # Set up f to compute f(x, y) at every x for any given y
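For reference, the nested `vmap` construction that the comment above alludes to looks roughly like this; `g` and `g_grid` are stand-ins, since the lecture's actual function is not visible in the hunk:

```python
import jax
import jax.numpy as jnp

g = lambda x, y: jnp.cos(x**2 + y**2)   # stand-in for the lecture's f

# Map over x for fixed y, then map that over y: this evaluates g on the
# full product grid without materializing meshgrid arrays.
g_x = jax.vmap(g, in_axes=(0, None))
g_grid = jax.vmap(g_x, in_axes=(None, 0))

grid = jnp.linspace(0.0, 1.0, 100)
z = g_grid(grid, grid)
print(z.shape)                          # (100, 100)
```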
@@ -340,8 +341,8 @@ Let's see the timing:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_vmap_1 = f_vec(grid)
-    z_vmap_1.block_until_ready()
+    z_vmap = f_vec(grid)
+    z_vmap.block_until_ready()
 ```
 
 Let's check we got the right result:
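A note on the `block_until_ready()` calls in these timing cells: JAX dispatches work asynchronously, so without blocking, the timer can stop before the device has finished. Minimal illustration (not from the lecture):

```python
import jax.numpy as jnp

x = jnp.ones((1_000, 1_000))
y = x @ x                  # returns immediately; computation may still run
y.block_until_ready()      # wait for the result before the timer stops
```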
@@ -393,6 +394,13 @@ with qe.Timer(precision=8):
     z_vmap = f_vec(x, y).block_until_ready()
 ```
 
+Let's check we got the right result:
+
+
+```{code-cell} ipython3
+jnp.allclose(z_mesh, z_vmap)
+```
+
 
 
 ### Summary
@@ -461,6 +469,8 @@ Numba's compilation is typically quite fast, and the resulting code performance
 
 Now let's create a JAX version using `lax.scan`:
 
+(We'll hold `n` static because it determines the array size, so JAX needs to specialize the compiled code on its value.)
+
 ```{code-cell} ipython3
 from jax import lax
 from functools import partial
from functools import partial
@@ -475,6 +485,8 @@ def qm_jax(x0, n, α=4.0):
475485
return jnp.concatenate([jnp.array([x0]), x])
476486
```
477487

488+
This code is not easy to read but, in essence, `lax.scan` repeatedly calls `qm_jax` and accumulates the returns `x_new` into an array.
489+
478490
Let's time it with the same parameters:
479491

480492
```{code-cell} ipython3
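The body of `qm_jax` is truncated in the hunk above. A self-contained sketch of the `lax.scan` pattern it uses, where the update rule is assumed to be the standard quadratic map `x_{t+1} = α * x_t * (1 - x_t)` (inferred from `α=4.0` in the signature, not copied from the lecture):

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

@partial(jax.jit, static_argnums=(1,))    # hold n static: it fixes array size
def qm_sketch(x0, n, α=4.0):
    def update(x, t):
        x_new = α * x * (1 - x)           # assumed quadratic-map step
        return x_new, x_new               # (new carry, value to accumulate)
    _, x = lax.scan(update, x0, None, length=n)
    return jnp.concatenate([jnp.array([x0]), x])

print(qm_sketch(0.1, 5))
```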
@@ -489,24 +501,25 @@ with qe.Timer(precision=8):
     x_jax = qm_jax(0.1, n).block_until_ready()
 ```
 
-JAX is also very efficient for this sequential operation.
+JAX is also efficient for this sequential operation.
 
-Both JAX and Numba deliver strong performance after compilation.
-
-While the raw speed is similar for this type of operation, there are notable differences in code complexity and ease of understanding, which we discuss in the next section.
+Both JAX and Numba deliver strong performance after compilation, with Numba
+typically (but not always) offering slightly better speeds on purely sequential
+operations.
 
 ### Summary
 
-While both Numba and JAX deliver excellent performance for sequential operations, there are significant differences in code readability and ease of use.
+While both Numba and JAX deliver strong performance for sequential operations,
+there are significant differences in code readability and ease of use.
 
-The Numba version is straightforward and natural to read: we simply allocate an array and fill it element by element using a standard Python loop.
+The Numba version is straightforward and natural to read: we simply allocate an
+array and fill it element by element using a standard Python loop.
 
 This is exactly how most programmers think about the algorithm.
 
-The JAX version, on the other hand, requires using `lax.scan`, which is less intuitive and has a steeper learning curve.
-
-Additionally, JAX's immutable arrays mean we cannot simply update array elements in place.
+The JAX version, on the other hand, requires using `lax.scan`, which is significantly less intuitive.
 
-Instead, we must use functional programming patterns with `lax.scan`, where we define an `update` function that returns both the new state and the value to accumulate.
+Additionally, JAX's immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.
 
-For this type of sequential operation, Numba is the clear winner in terms of code clarity and ease of implementation, while maintaining competitive performance.
+For this type of sequential operation, Numba is the clear winner in terms of
+code clarity and ease of implementation, as well as high performance.
