diff --git a/pyiceberg/utils/concurrent.py b/pyiceberg/utils/concurrent.py index 751cbd9bbb..54e99dc0ba 100644 --- a/pyiceberg/utils/concurrent.py +++ b/pyiceberg/utils/concurrent.py @@ -16,6 +16,7 @@ # under the License. """Concurrency concepts that support efficient multi-threading.""" +import os from concurrent.futures import Executor, ThreadPoolExecutor from typing import Optional @@ -24,6 +25,7 @@ class ExecutorFactory: _instance: Optional[Executor] = None + _instance_pid: Optional[int] = None @staticmethod def max_workers() -> Optional[int]: @@ -33,6 +35,13 @@ def max_workers() -> Optional[int]: @staticmethod def get_or_create() -> Executor: """Return the same executor in each call.""" + # ThreadPoolExecutor cannot be shared across processes. If a new pid is found it means + # there is a new process so a new executor is needed. Otherwise, the executor may be in + # an invalid state and tasks submitted will not be started. + if ExecutorFactory._instance_pid != os.getpid(): + ExecutorFactory._instance_pid = os.getpid() + ExecutorFactory._instance = None + if ExecutorFactory._instance is None: max_workers = ExecutorFactory.max_workers() ExecutorFactory._instance = ThreadPoolExecutor(max_workers=max_workers) diff --git a/tests/utils/test_concurrent.py b/tests/utils/test_concurrent.py index 6d730cbe75..c703f764af 100644 --- a/tests/utils/test_concurrent.py +++ b/tests/utils/test_concurrent.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. 
+import multiprocessing import os -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from typing import Dict, Generator, Optional from unittest import mock import pytest @@ -29,6 +30,41 @@ INVALID_ENV = {"PYICEBERG_MAX_WORKERS": "invalid"} +@pytest.fixture +def fork_process() -> Generator[None, None, None]: + original = multiprocessing.get_start_method() + allowed = multiprocessing.get_all_start_methods() + + assert "fork" in allowed + + multiprocessing.set_start_method("fork", force=True) + + yield + + multiprocessing.set_start_method(original, force=True) + + +@pytest.fixture +def spawn_process() -> Generator[None, None, None]: + original = multiprocessing.get_start_method() + allowed = multiprocessing.get_all_start_methods() + + assert "spawn" in allowed + + multiprocessing.set_start_method("spawn", force=True) + + yield + + multiprocessing.set_start_method(original, force=True) + + +def _use_executor_to_return(value: int) -> int: + # Module-level function to enable pickling for use with ProcessPoolExecutor. 
+ executor = ExecutorFactory.get_or_create() + future = executor.submit(lambda: value) + return future.result() + + def test_create_reused() -> None: first = ExecutorFactory.get_or_create() second = ExecutorFactory.get_or_create() @@ -50,3 +86,38 @@ def test_max_workers() -> None: def test_max_workers_invalid() -> None: with pytest.raises(ValueError): ExecutorFactory.max_workers() + + +@pytest.mark.parametrize( + "fixture_name", + [ + pytest.param( + "fork_process", + marks=pytest.mark.skipif( + "fork" not in multiprocessing.get_all_start_methods(), reason="Fork start method is not available" + ), + ), + pytest.param( + "spawn_process", + marks=pytest.mark.skipif( + "spawn" not in multiprocessing.get_all_start_methods(), reason="Spawn start method is not available" + ), + ), + ], +) +def test_use_executor_in_different_process(fixture_name: str, request: pytest.FixtureRequest) -> None: + # Use the fixture, which sets up fork or spawn process start method. + request.getfixturevalue(fixture_name) + + # Use executor in main process to ensure the singleton is initialized. + main_value = _use_executor_to_return(10) + + # Use two separate ProcessPoolExecutors to ensure different processes are used. + with ProcessPoolExecutor() as process_executor: + future1 = process_executor.submit(_use_executor_to_return, 20) + with ProcessPoolExecutor() as process_executor: + future2 = process_executor.submit(_use_executor_to_return, 30) + + assert main_value == 10 + assert future1.result() == 20 + assert future2.result() == 30