
Commit 458b6fe

jstac and claude committed
Fix code errors and grammar in numpy_vs_numba_vs_jax lecture
- Fix incorrect function name: compute_max_numba_parallel_nested → compute_max_numba_parallel
- Fix incorrect variable name: z_vmap → z_max
- Fix grammar: "similar to as" → "similar to"
- Fix technical description: lax.scan calls update function, not qm_jax

All fixes verified by converting to Python and running successfully with ipython.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent dff29ba commit 458b6fe

File tree

1 file changed: +6 −6 lines

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -270,12 +270,12 @@ Here's the timings.
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
+    compute_max_numba_parallel(grid)
 ```

 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
+    compute_max_numba_parallel(grid)
 ```

 If you have multiple cores, you should see at least some benefits from parallelization here.
````
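The function being renamed is not shown in this hunk. As a rough sketch of the kind of grid maximization it performs — the objective `f` and the grid construction below are assumptions, not taken from the lecture — here is a plain-NumPy serial baseline that a Numba `prange` version such as `compute_max_numba_parallel` would split across cores:

```python
import numpy as np

def f(x, y):
    # Hypothetical objective; the actual f in the lecture is not in this diff.
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

def compute_max_numpy(grid):
    # Evaluate f on the full Cartesian product of grid points via
    # broadcasting and take the maximum -- the serial computation that a
    # parallel Numba loop over rows would distribute across cores.
    x = grid[:, None]   # shape (n, 1)
    y = grid[None, :]   # shape (1, n)
    return np.max(f(x, y))

grid = np.linspace(-3, 3, 1000)
print(compute_max_numpy(grid))
```

A Numba version would replace the broadcast with an explicit double loop decorated with `@njit(parallel=True)` and an outer `prange`, which is what the commit's corrected timing cells exercise.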
````diff
@@ -361,16 +361,16 @@ Let's see the timing:
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = jnp.max(f_vec(grid))
-    z_vmap.block_until_ready()
+    z_max.block_until_ready()
 ```

 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = jnp.max(f_vec(grid))
-    z_vmap.block_until_ready()
+    z_max.block_until_ready()
 ```

-The execution time is similar to as the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
+The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
 we are using far less memory.

 In addition, `vmap` allows us to break vectorization up into stages, which is
````
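The memory claim in the corrected sentence can be illustrated without JAX: `np.meshgrid` materializes two n×n input arrays up front, while broadcasting 1-D views produces the same values with only n-sized inputs. The variable names below are illustrative, not taken from the lecture:

```python
import numpy as np

n = 1000
grid = np.linspace(-3, 3, n)

# Mesh approach: materializes two n x n input arrays before any evaluation.
x_mesh, y_mesh = np.meshgrid(grid, grid)
mesh_bytes = x_mesh.nbytes + y_mesh.nbytes       # 2 * n * n * 8 bytes

# Broadcast approach: inputs stay 1-D; only the result is n x n.
x_col, y_row = grid[:, None], grid[None, :]
broadcast_bytes = x_col.nbytes + y_row.nbytes    # 2 * n * 8 bytes

z_mesh = np.cos(x_mesh**2 + y_mesh**2)
z_bcast = np.cos(x_col**2 + y_row**2)

assert np.allclose(z_mesh, z_bcast)
print(mesh_bytes // broadcast_bytes)  # input memory ratio: n
```

`jax.vmap` achieves the same saving by mapping a scalar function over 1-D arrays, so the n×n inputs are never built.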
````diff
@@ -507,7 +507,7 @@ def qm_jax(x0, n, α=4.0):
     return jnp.concatenate([jnp.array([x0]), x])
 ```

-This code is not easy to read but, in essence, `lax.scan` repeatedly calls `qm_jax` and accumulates the returns `x_new` into an array.
+This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.

 Let's time it with the same parameters:
````
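The corrected sentence is easier to follow with `lax.scan`'s contract in mind: `scan(f, init, xs)` threads a carry through `f(carry, x) -> (carry, y)` and stacks the `y` values. A pure-Python model of that behavior, with a quadratic-map `update` that is a guess at the lecture's helper (shown only to make the accumulation pattern concrete):

```python
def scan(f, init, xs=None, length=None):
    # Pure-Python model of jax.lax.scan: thread a carry through f and
    # collect each step's output y.  (The real scan stacks ys into an array.)
    if xs is None:
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def update(x, _):
    # One quadratic-map step with α = 4.0, matching qm_jax's signature.
    x_new = 4.0 * x * (1 - x)
    return x_new, x_new   # new carry, value accumulated into the output

final, xs = scan(update, 0.1, length=5)
print(xs)
```

Here `update` plays the role of `f`: each call's `x_new` becomes both the next carry and one entry of the accumulated trajectory, which is exactly what the corrected sentence describes.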

0 commit comments
