Skip to content

Conversation

@mmcky
Copy link
Contributor

@mmcky mmcky commented Nov 26, 2025

Summary

The parallel Numba implementation in the NumPy vs Numba vs JAX lecture had a race condition bug that caused it to return incorrect results.

The Bug

The original code had multiple threads simultaneously updating a shared variable m:

@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    m = -np.inf
    for i in numba.prange(n):
        for j in range(n):
            ...
            if z > m:
                m = z  # Race condition - multiple threads writing!
    return m

This returned -inf instead of the correct maximum (~0.9999).

The Fix

Each thread now computes its own row maximum, stored in a thread-safe array:

@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    row_maxes = np.empty(n)  # Each thread writes to different index
    for i in numba.prange(n):
        row_max = -np.inf     # Thread-local variable
        for j in range(n):
            ...
            if z > row_max:
                row_max = z
        row_maxes[i] = row_max  # No race - each i is unique
    return np.max(row_maxes)    # Final reduction

Verification

Version Result Correct?
Original (broken) -inf
Fixed 0.9999979986680024
Serial (reference) 0.9999979986680024

Changes

  • Replace shared variable m with thread-safe row_maxes array
  • Each parallel iteration computes a thread-local row_max
  • Final np.max(row_maxes) combines partial results
  • Remove broken nested prange version (same race condition)
  • Add explanatory text about avoiding race conditions in parallel reductions

The previous implementation had a race condition where multiple threads
simultaneously updated the shared variable 'm', causing incorrect results
(returning -inf instead of the actual maximum).

The fix computes per-row maximums in parallel (each thread writes to a
unique index), then reduces to find the global maximum.

Changes:
- Replace shared variable 'm' with thread-safe 'row_maxes' array
- Each parallel iteration computes a thread-local 'row_max'
- Final np.max(row_maxes) combines partial results
- Remove broken nested prange version (same race condition)
- Add explanatory text about avoiding race conditions in reductions
@mmcky
Copy link
Contributor Author

mmcky commented Nov 26, 2025

@jstac this is available should you like the solution.

@github-actions
Copy link

@github-actions github-actions bot temporarily deployed to pull request November 26, 2025 03:14 Inactive
@jstac
Copy link
Contributor

jstac commented Nov 26, 2025

Thanks @mmcky . I'm going to close this because I want to keep the old parallel number version too, for illustration -- along with discussion and explanation.

But I'll also include your new version -- and well as the corresponding JAX version.

@jstac jstac closed this Nov 26, 2025
@mmcky
Copy link
Contributor Author

mmcky commented Nov 26, 2025

thanks @jstac that makes a lot of sense, it's a GREAT teaching example.

@mmcky mmcky deleted the fix/numba-parallel-race-condition branch November 26, 2025 05:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants