From ef95c1a562c05ef008e7fc54fb043eac73f8e1db Mon Sep 17 00:00:00 2001
From: mmcky
Date: Wed, 26 Nov 2025 14:05:52 +1100
Subject: [PATCH] [numpy_vs_numba_vs_jax] Fix issue in parallel Numba
 implementation

The previous implementation had a race condition where multiple threads
simultaneously updated the shared variable 'm', causing incorrect results
(returning -inf instead of the actual maximum).

The fix computes per-row maximums in parallel (each thread writes to a
unique index), then reduces to find the global maximum.

Changes:
- Replace shared variable 'm' with thread-safe 'row_maxes' array
- Each parallel iteration computes a thread-local 'row_max'
- Final np.max(row_maxes) combines partial results
- Remove broken nested prange version (same race condition)
- Add explanatory text about avoiding race conditions in reductions
---
 lectures/numpy_vs_numba_vs_jax.md | 43 ++++++++-----------------------
 1 file changed, 11 insertions(+), 32 deletions(-)

diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 55de222b..47a2fe80 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -200,23 +200,28 @@ working with a single one-dimensional grid.
 
 ### Parallelized Numba
 
-Now let's try parallelization with Numba using `prange`:
+Now let's try parallelization with Numba using `prange`.
 
-First we parallelize just the outer loop.
+When parallelizing a reduction (like finding a maximum), we need to avoid race conditions
+where multiple threads try to update the same variable simultaneously.
+
+The solution is to compute partial results (row maximums) in parallel, then combine them.
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
 def compute_max_numba_parallel(grid):
     n = len(grid)
-    m = -np.inf
+    row_maxes = np.empty(n)
     for i in numba.prange(n):
+        row_max = -np.inf
         for j in range(n):
             x = grid[i]
             y = grid[j]
             z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
+            if z > row_max:
+                row_max = z
+        row_maxes[i] = row_max
+    return np.max(row_maxes)
 
 with qe.Timer(precision=8):
     compute_max_numba_parallel(grid)
@@ -228,32 +233,6 @@ with qe.Timer(precision=8):
     compute_max_numba_parallel(grid)
 ```
 
-Next we parallelize both loops.
-
-```{code-cell} ipython3
-@numba.jit(parallel=True)
-def compute_max_numba_parallel_nested(grid):
-    n = len(grid)
-    m = -np.inf
-    for i in numba.prange(n):
-        for j in numba.prange(n):
-            x = grid[i]
-            y = grid[j]
-            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
-
-with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
-```
-
-```{code-cell} ipython3
-with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
-```
-
-
 Depending on your machine, you might or might not see large benefits
 from parallelization here.
 
 If you have a small number of cores, the overhead of thread management
 and synchronization can
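
---

A standalone sketch (not part of the patch) of the reduction pattern the fix relies on: each row writes a private maximum into its own slot of `row_maxes`, and a single `np.max` combines the partial results. This version is plain Python/NumPy so it runs without Numba; the `grid` here is a small hypothetical example, not the lecture's grid.

```python
import numpy as np

def compute_max_rowwise(grid):
    # Same decomposition as the patched Numba version: each row's
    # maximum is accumulated in a private variable and written to a
    # unique index, so no two iterations touch the same slot.
    n = len(grid)
    row_maxes = np.empty(n)
    for i in range(n):
        row_max = -np.inf
        for j in range(n):
            x, y = grid[i], grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > row_max:
                row_max = z
        row_maxes[i] = row_max
    # Final reduction combines the per-row partial results.
    return np.max(row_maxes)

# Illustrative grid (assumption; the lecture defines its own grid).
grid = np.linspace(-3, 3, 100)
print(compute_max_rowwise(grid))
```

Because each iteration of the outer loop owns `row_maxes[i]` exclusively, the same structure is safe under `numba.prange`, unlike the removed versions that shared `m` across threads.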