Commit 6bf345a

update note
1 parent 54b2b34 commit 6bf345a

1 file changed

+7
-2
lines changed

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 7 additions & 2 deletions
@@ -524,8 +524,13 @@ def qm_jax(x0, n, α=4.0):
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.
 
 ```{note}
-We explicitly target the CPU using `device=cpu` because `lax.scan` with many
-lightweight iterations performs poorly on GPU due to synchronization overhead.
+Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.
+
+The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+
+As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+
+Curious readers can try removing this option to see how performance changes.
 ```
 
 Let's time it with the same parameters:
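
To make the note's point about sequential iterations concrete, here is a minimal pure-Python sketch of what a scan does: it threads a carry through `update` and collects each returned `x_new`. The helper names `scan_reference` and `make_update` are hypothetical, and the update rule `α * x * (1 - x)` is an assumption, since the diff shows only the signature `qm_jax(x0, n, α=4.0)` and not the function body. Unlike `lax.scan`, this sketch returns a plain list rather than a stacked array.

```python
def scan_reference(f, init, length):
    """Pure-Python model of a scan with no per-step inputs (illustrative only)."""
    carry = init
    ys = []
    for _ in range(length):
        # Each step consumes the carry produced by the previous step,
        # so the iterations cannot be run in parallel.
        carry, y = f(carry, None)
        ys.append(y)
    return carry, ys


def make_update(α=4.0):
    # Assumed quadratic-map update; the real body of `update` is not in the diff.
    def update(x, _):
        x_new = α * x * (1 - x)
        return x_new, x_new  # (next carry, value to accumulate)
    return update


final, path = scan_reference(make_update(), 0.25, 3)
print(final, path)
```

Because step `t + 1` needs the result of step `t`, each iteration does only a few floating-point operations before the next one can begin, which is the situation the note describes as a poor fit for a GPU.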
