Commit ef95c1a
[numpy_vs_numba_vs_jax] Fix issue in parallel Numba implementation
The previous implementation had a race condition where multiple threads
simultaneously updated the shared variable 'm', causing incorrect results
(returning -inf instead of the actual maximum). The fix computes per-row
maximums in parallel (each thread writes to a unique index), then reduces
to find the global maximum.

Changes:
- Replace shared variable 'm' with thread-safe 'row_maxes' array
- Each parallel iteration computes a thread-local 'row_max'
- Final np.max(row_maxes) combines partial results
- Remove broken nested prange version (same race condition)
- Add explanatory text about avoiding race conditions in reductions
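For reference, here is a minimal standalone sketch of the race-free reduction pattern this commit adopts, runnable outside the lecture; the function name and test grid below are illustrative, not taken from the lecture itself:

```python
import numpy as np
import numba

@numba.jit(parallel=True)
def safe_parallel_max(grid):
    n = len(grid)
    row_maxes = np.empty(n)
    # Each prange iteration writes only to its own slot, row_maxes[i],
    # so no two threads ever write to the same memory location.
    for i in numba.prange(n):
        row_max = -np.inf  # thread-local partial result
        for j in range(n):
            z = np.cos(grid[i]**2 + grid[j]**2) / (1 + grid[i]**2 + grid[j]**2)
            if z > row_max:
                row_max = z
        row_maxes[i] = row_max
    # Serial reduction over the per-thread partial results.
    return np.max(row_maxes)

grid = np.linspace(-3, 3, 1_000)  # illustrative grid, not the lecture's
print(safe_parallel_max(grid))
```

The removed versions instead updated a single shared `m` inside the `prange` loop; per the commit message, threads raced on that update and the function could return -inf.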
1 parent 4f25b05 commit ef95c1a

File tree

1 file changed: +11 -32 lines changed

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 11 additions & 32 deletions
@@ -200,23 +200,28 @@ working with a single one-dimensional grid.
 
 ### Parallelized Numba
 
-Now let's try parallelization with Numba using `prange`:
+Now let's try parallelization with Numba using `prange`.
 
-First we parallelize just the outer loop.
+When parallelizing a reduction (like finding a maximum), we need to avoid race conditions
+where multiple threads try to update the same variable simultaneously.
+
+The solution is to compute partial results (row maximums) in parallel, then combine them.
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
 def compute_max_numba_parallel(grid):
     n = len(grid)
-    m = -np.inf
+    row_maxes = np.empty(n)
     for i in numba.prange(n):
+        row_max = -np.inf
         for j in range(n):
             x = grid[i]
             y = grid[j]
             z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
+            if z > row_max:
+                row_max = z
+        row_maxes[i] = row_max
+    return np.max(row_maxes)
 
 with qe.Timer(precision=8):
     compute_max_numba_parallel(grid)
@@ -228,32 +233,6 @@ with qe.Timer(precision=8):
     compute_max_numba_parallel(grid)
 ```
 
-Next we parallelize both loops.
-
-```{code-cell} ipython3
-@numba.jit(parallel=True)
-def compute_max_numba_parallel_nested(grid):
-    n = len(grid)
-    m = -np.inf
-    for i in numba.prange(n):
-        for j in numba.prange(n):
-            x = grid[i]
-            y = grid[j]
-            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
-
-with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
-```
-
-```{code-cell} ipython3
-with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
-```
-
-
 Depending on your machine, you might or might not see large benefits from parallelization here.
 
 If you have a small number of cores, the overhead of thread management and synchronization can
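As an illustrative check (not part of the commit), the fixed function can be compared against a vectorized NumPy reference; this sketch assumes the lecture's `grid` and the `compute_max_numba_parallel` defined in the diff above are in scope:

```python
import numpy as np

# Vectorized reference: evaluate the function on the full 2-D grid, take the max.
x, y = np.meshgrid(grid, grid)
reference = np.max(np.cos(x**2 + y**2) / (1 + x**2 + y**2))

# The race-free parallel version should agree with the reference;
# the old racy versions could return -inf instead.
assert np.isclose(compute_max_numba_parallel(grid), reference)
```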

0 commit comments