diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 55de222b..47a2fe80 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -200,23 +200,28 @@ working with a single one-dimensional grid.
 
 ### Parallelized Numba
 
-Now let's try parallelization with Numba using `prange`:
+Now let's try parallelization with Numba using `prange`.
 
-First we parallelize just the outer loop.
+When parallelizing a reduction (like finding a maximum), we need to avoid race conditions
+where multiple threads try to update the same variable simultaneously.
+
+The solution is to compute partial results (row maximums) in parallel, then combine them.
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
 def compute_max_numba_parallel(grid):
     n = len(grid)
-    m = -np.inf
+    row_maxes = np.empty(n)
     for i in numba.prange(n):
+        row_max = -np.inf
         for j in range(n):
             x = grid[i]
             y = grid[j]
             z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
+            if z > row_max:
+                row_max = z
+        row_maxes[i] = row_max
+    return np.max(row_maxes)
 
 with qe.Timer(precision=8):
     compute_max_numba_parallel(grid)
@@ -228,32 +233,6 @@ with qe.Timer(precision=8):
     compute_max_numba_parallel(grid)
 ```
 
-Next we parallelize both loops.
-
-```{code-cell} ipython3
-@numba.jit(parallel=True)
-def compute_max_numba_parallel_nested(grid):
-    n = len(grid)
-    m = -np.inf
-    for i in numba.prange(n):
-        for j in numba.prange(n):
-            x = grid[i]
-            y = grid[j]
-            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
-
-with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
-```
-
-```{code-cell} ipython3
-with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
-```
-
-
 Depending on your machine, you might or might not see large benefits from parallelization here.
 
 If you have a small number of cores, the overhead of thread management and synchronization can
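
The change above replaces a shared running maximum with per-row partial results that are combined at the end. As a minimal sketch of that pattern outside Numba, the snippet below uses only NumPy and a standard-library thread pool and compares the combined result with a vectorized reference. The helper names (`f`, `max_by_rows`), the worker count, and the grid size are illustrative assumptions, not part of the lecture or the patch.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def f(x, y):
    # Same function as in the lecture code
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)


def max_by_rows(grid, workers=4):
    # Hypothetical helper: one partial maximum per row, combined at the end
    n = len(grid)
    row_maxes = np.empty(n)

    def task(i):
        # Each task writes only to its own slot, so no two tasks
        # ever update the same variable
        row_maxes[i] = np.max(f(grid[i], grid))

    with ThreadPoolExecutor(max_workers=workers) as pool:
        list(pool.map(task, range(n)))  # consume the iterator so every task runs

    # Combine the partial results serially
    return np.max(row_maxes)


grid = np.linspace(-3, 3, 200)  # small illustrative grid, not the lecture's
x, y = np.meshgrid(grid, grid)
reference = np.max(f(x, y))     # vectorized reference value
result = max_by_rows(grid)
```

Because every task owns a distinct entry of `row_maxes`, the final `np.max` is the only step that reads across rows, and it runs after all tasks have finished. This is the same structure as the `row_maxes[i] = row_max` line in the patched function.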