Skip to content

Commit 70c05d8

Browse files
committed
Refactor set_forkserver_preload to use _handle_preload helper
Extract preload logic into a separate _handle_preload() function to enable targeted unit testing. Add comprehensive unit tests for both module and __main__ preload with all three on_error modes. New tests: - test_handle_preload_main_on_error_{fail,warn,ignore} - test_handle_preload_module_on_error_{fail,warn,ignore} - test_handle_preload_main_valid - test_handle_preload_combined Total test count increased from 6 to 14 tests.
1 parent 9c3ba84 commit 70c05d8

File tree

2 files changed

+148
-46
lines changed

2 files changed

+148
-46
lines changed

Lib/multiprocessing/forkserver.py

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,67 @@ def ensure_running(self):
219219
#
220220
#
221221

222+
def _handle_preload(preload, main_path=None, sys_path=None, on_error='ignore'):
223+
"""Handle module preloading with configurable error handling.
224+
225+
Args:
226+
preload: List of module names to preload.
227+
main_path: Path to __main__ module if '__main__' is in preload.
228+
sys_path: sys.path to use for imports (None means use current).
229+
on_error: How to handle import errors ('ignore', 'warn', or 'fail').
230+
"""
231+
if not preload:
232+
return
233+
234+
if sys_path is not None:
235+
sys.path[:] = sys_path
236+
237+
if '__main__' in preload and main_path is not None:
238+
process.current_process()._inheriting = True
239+
try:
240+
spawn.import_main_path(main_path)
241+
except Exception as e:
242+
match on_error:
243+
case 'fail':
244+
raise
245+
case 'warn':
246+
import warnings
247+
with warnings.catch_warnings():
248+
warnings.simplefilter('always', ImportWarning)
249+
warnings.warn(
250+
f"Failed to preload __main__ from {main_path!r}: {e}",
251+
ImportWarning,
252+
stacklevel=2
253+
)
254+
case 'ignore':
255+
pass
256+
finally:
257+
del process.current_process()._inheriting
258+
259+
for modname in preload:
260+
try:
261+
__import__(modname)
262+
except ImportError as e:
263+
match on_error:
264+
case 'fail':
265+
raise
266+
case 'warn':
267+
import warnings
268+
with warnings.catch_warnings():
269+
warnings.simplefilter('always', ImportWarning)
270+
warnings.warn(
271+
f"Failed to preload module {modname!r}: {e}",
272+
ImportWarning,
273+
stacklevel=2
274+
)
275+
case 'ignore':
276+
pass
277+
278+
# gh-135335: flush stdout/stderr in case any of the preloaded modules
279+
# wrote to them, otherwise children might inherit buffered data
280+
util._flush_std_streams()
281+
282+
222283
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
223284
*, authkey_r=None, on_error='ignore'):
224285
"""Run forkserver."""
@@ -231,52 +292,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
231292
else:
232293
authkey = b''
233294

234-
if preload:
235-
if sys_path is not None:
236-
sys.path[:] = sys_path
237-
if '__main__' in preload and main_path is not None:
238-
process.current_process()._inheriting = True
239-
try:
240-
spawn.import_main_path(main_path)
241-
except Exception as e:
242-
match on_error:
243-
case 'fail':
244-
raise
245-
case 'warn':
246-
import warnings
247-
with warnings.catch_warnings():
248-
warnings.simplefilter('always', ImportWarning)
249-
warnings.warn(
250-
f"Failed to preload __main__ from {main_path!r}: {e}",
251-
ImportWarning,
252-
stacklevel=2
253-
)
254-
case 'ignore':
255-
pass
256-
finally:
257-
del process.current_process()._inheriting
258-
for modname in preload:
259-
try:
260-
__import__(modname)
261-
except ImportError as e:
262-
match on_error:
263-
case 'fail':
264-
raise
265-
case 'warn':
266-
import warnings
267-
with warnings.catch_warnings():
268-
warnings.simplefilter('always', ImportWarning)
269-
warnings.warn(
270-
f"Failed to preload module {modname!r}: {e}",
271-
ImportWarning,
272-
stacklevel=2
273-
)
274-
case 'ignore':
275-
pass
276-
277-
# gh-135335: flush stdout/stderr in case any of the preloaded modules
278-
# wrote to them, otherwise children might inherit buffered data
279-
util._flush_std_streams()
295+
_handle_preload(preload, main_path, sys_path, on_error)
280296

281297
util._close_stdin()
282298

Lib/test/test_multiprocessing_forkserver/test_preload.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import multiprocessing
44
import multiprocessing.forkserver
5+
import os
6+
import tempfile
57
import unittest
8+
from multiprocessing.forkserver import _handle_preload
69

710

811
class TestForkserverPreload(unittest.TestCase):
@@ -116,5 +119,88 @@ def test_preload_invalid_on_error_value(self):
116119
self.assertIn("on_error must be 'ignore', 'warn', or 'fail'", str(cm.exception))
117120

118121

122+
class TestHandlePreload(unittest.TestCase):
123+
"""Unit tests for _handle_preload() function."""
124+
125+
def test_handle_preload_main_on_error_fail(self):
126+
"""Test that __main__ import failures raise with on_error='fail'."""
127+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
128+
f.write('raise RuntimeError("test error in __main__")\n')
129+
bad_main_path = f.name
130+
131+
try:
132+
with self.assertRaises(RuntimeError) as cm:
133+
_handle_preload(['__main__'], main_path=bad_main_path, on_error='fail')
134+
self.assertIn("test error in __main__", str(cm.exception))
135+
finally:
136+
os.unlink(bad_main_path)
137+
138+
def test_handle_preload_main_on_error_warn(self):
139+
"""Test that __main__ import failures warn with on_error='warn'."""
140+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
141+
f.write('raise ImportError("test import error")\n')
142+
bad_main_path = f.name
143+
144+
try:
145+
with self.assertWarns(ImportWarning) as cm:
146+
_handle_preload(['__main__'], main_path=bad_main_path, on_error='warn')
147+
self.assertIn("Failed to preload __main__", str(cm.warning))
148+
self.assertIn("test import error", str(cm.warning))
149+
finally:
150+
os.unlink(bad_main_path)
151+
152+
def test_handle_preload_main_on_error_ignore(self):
153+
"""Test that __main__ import failures are ignored with on_error='ignore'."""
154+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
155+
f.write('raise ImportError("test import error")\n')
156+
bad_main_path = f.name
157+
158+
try:
159+
# Should not raise
160+
_handle_preload(['__main__'], main_path=bad_main_path, on_error='ignore')
161+
finally:
162+
os.unlink(bad_main_path)
163+
164+
def test_handle_preload_main_valid(self):
165+
"""Test that valid __main__ preload works."""
166+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
167+
f.write('test_var = 42\n')
168+
valid_main_path = f.name
169+
170+
try:
171+
_handle_preload(['__main__'], main_path=valid_main_path, on_error='fail')
172+
# Should complete without raising
173+
finally:
174+
os.unlink(valid_main_path)
175+
176+
def test_handle_preload_module_on_error_fail(self):
177+
"""Test that module import failures raise with on_error='fail'."""
178+
with self.assertRaises(ModuleNotFoundError):
179+
_handle_preload(['nonexistent_test_module_xyz'], on_error='fail')
180+
181+
def test_handle_preload_module_on_error_warn(self):
182+
"""Test that module import failures warn with on_error='warn'."""
183+
with self.assertWarns(ImportWarning) as cm:
184+
_handle_preload(['nonexistent_test_module_xyz'], on_error='warn')
185+
self.assertIn("Failed to preload module", str(cm.warning))
186+
187+
def test_handle_preload_module_on_error_ignore(self):
188+
"""Test that module import failures are ignored with on_error='ignore'."""
189+
# Should not raise
190+
_handle_preload(['nonexistent_test_module_xyz'], on_error='ignore')
191+
192+
def test_handle_preload_combined(self):
193+
"""Test preloading both __main__ and modules."""
194+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
195+
f.write('import sys\n')
196+
valid_main_path = f.name
197+
198+
try:
199+
_handle_preload(['__main__', 'os', 'sys'], main_path=valid_main_path, on_error='fail')
200+
# Should complete without raising
201+
finally:
202+
os.unlink(valid_main_path)
203+
204+
119205
if __name__ == '__main__':
120206
unittest.main()

0 commit comments

Comments
 (0)