
Commit 54b2b34

mmcky and HumphreyYang committed
ENH: Force lax.scan sequential operation to run on CPU
Add device=cpu to the qm_jax function decorator to avoid the known XLA limitation where lax.scan with millions of lightweight iterations performs poorly on GPU due to CPU-GPU synchronization overhead. Added explanatory note about this pattern.

Co-authored-by: HumphreyYang <Humphrey.Yang@anu.edu.au>
1 parent 10627ef commit 54b2b34
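
The lecture text touched by this commit says that `lax.scan` repeatedly calls `update` and accumulates the returned `x_new` values. As a rough illustration of that looping pattern, here is a pure-Python sketch. The helpers `scan` and `qm_reference` are hypothetical names introduced here, not part of JAX or the lecture, and prepending `x0` to the path is an assumption, since the diff does not show how `qm_jax` assembles its result.

```python
def scan(f, init, xs):
    """Pure-Python stand-in (hypothetical helper) for the loop behind `lax.scan`."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)  # f returns (new carry, per-step output)
        ys.append(y)
    return carry, ys


def qm_reference(x0, n, α=4.0):
    """Quadratic map path x_{t+1} = α x_t (1 - x_t), built with `scan`."""
    def update(x, t):
        x_new = α * x * (1 - x)
        return x_new, x_new  # carry the state forward and record it

    _, path = scan(update, x0, range(n))
    return [x0] + path  # assumption: the initial value is kept in the path


print(qm_reference(0.1, 3))
```

Each step depends on the previous one and does only a few arithmetic operations, which is the "many lightweight sequential iterations" case that the commit message says performs poorly on GPU.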

1 file changed: +8 -1 lines changed


lectures/numpy_vs_numba_vs_jax.md

Lines changed: 8 additions & 1 deletion
@@ -509,7 +509,9 @@ Now let's create a JAX version using `lax.scan`:
 from jax import lax
 from functools import partial

-@partial(jax.jit, static_argnums=(1,))
+cpu = jax.devices("cpu")[0]
+
+@partial(jax.jit, static_argnums=(1,), device=cpu)
 def qm_jax(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
@@ -521,6 +523,11 @@ def qm_jax(x0, n, α=4.0):

 This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.

+```{note}
+We explicitly target the CPU using `device=cpu` because `lax.scan` with many
+lightweight iterations performs poorly on GPU due to synchronization overhead.
+```
+
 Let's time it with the same parameters:

 ```{code-cell} ipython3
