diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index f979890170b1a1..a643e7757022fd 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -776,13 +776,23 @@ def get(self, timeout=None): def _set(self, i, obj): self._success, self._value = obj if self._callback and self._success: - self._callback(self._value) + self._handle_exceptions(self._callback, self._value) if self._error_callback and not self._success: - self._error_callback(self._value) + self._handle_exceptions(self._error_callback, self._value) self._event.set() del self._cache[self._job] self._pool = None + @staticmethod + def _handle_exceptions(callback, args): + try: + return callback(args) + except Exception as e: + args = threading.ExceptHookArgs([type(e), e, e.__traceback__, + threading.current_thread()]) + threading.excepthook(args) + del args + __class_getitem__ = classmethod(types.GenericAlias) AsyncResult = ApplyResult # create alias -- see #17805 @@ -813,7 +823,7 @@ def _set(self, i, success_result): self._value[i*self._chunksize:(i+1)*self._chunksize] = result if self._number_left == 0: if self._callback: - self._callback(self._value) + self._handle_exceptions(self._callback, self._value) del self._cache[self._job] self._event.set() self._pool = None @@ -825,7 +835,7 @@ def _set(self, i, success_result): if self._number_left == 0: # only consider the result ready once all jobs are done if self._error_callback: - self._error_callback(self._value) + self._handle_exceptions(self._error_callback, self._value) del self._cache[self._job] self._event.set() self._pool = None diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 75f31d858d3306..f8744550a87191 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -3172,7 +3172,39 @@ def test_resource_warning(self): pool = None support.gc_collect() -def raising(): + def test_callback_errors(self): + if self.TYPE == 'manager': + self.skipTest("cannot intercept excepthook in manager") + + def _apply(pool, target, **kwargs): + return pool.apply_async(target, **kwargs) + + def _map(pool, target, **kwargs): + return pool.map_async(target, range(1), **kwargs) + + def record_exceptions(errs): + def record(args): + errs.append(args.exc_type) + return record + + errs = [] + for func in [_apply, _map]: + with self.subTest(func=func): + saved_hook = threading.excepthook + threading.excepthook = record_exceptions(errs) + try: + with self.Pool(1) as pool: + res = func(pool, noop, callback=raising) + res.get() + finally: + threading.excepthook = saved_hook + + self.assertEqual(errs, [KeyError, KeyError]) + +def noop(*args): + pass + +def raising(*args): raise KeyError("key") def unpickleable_result(): diff --git a/Misc/NEWS.d/next/Library/2025-05-27-01-06-25.gh-issue-83371.-oeZI3.rst b/Misc/NEWS.d/next/Library/2025-05-27-01-06-25.gh-issue-83371.-oeZI3.rst new file mode 100644 index 00000000000000..ecb61103d15c0c --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-05-27-01-06-25.gh-issue-83371.-oeZI3.rst @@ -0,0 +1,3 @@ +Handle exceptions thrown by callbacks passed to +:class:`multiprocessing.pool.Pool` ``*_async`` methods, preventing them from +breaking the pool.