Skip to content

Commit f806f51

Browse files
author
Ian Atkinson
committed
Added support for multiprocessing executor
1 parent 2bff5ef commit f806f51

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

pyiceberg/utils/concurrent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
"""Concurrency concepts that support efficient multi-threading."""
1818

19+
import os
1920
from concurrent.futures import Executor, ThreadPoolExecutor
2021
from typing import Optional
2122

@@ -24,6 +25,7 @@
2425

2526
class ExecutorFactory:
2627
_instance: Optional[Executor] = None
28+
_instance_pid: Optional[int] = None
2729

2830
@staticmethod
2931
def max_workers() -> Optional[int]:
@@ -33,6 +35,14 @@ def max_workers() -> Optional[int]:
3335
@staticmethod
3436
def get_or_create() -> Executor:
3537
"""Return the same executor in each call."""
38+
39+
# ThreadPoolExecutor cannot be shared across processes. If a new pid is found it means
40+
# there has been a fork so a new executor is needed. Otherwise, the executor may be in
41+
# an invalid state and tasks submitted will not be started.
42+
if ExecutorFactory._instance_pid != os.getpid():
43+
ExecutorFactory._instance_pid = os.getpid()
44+
ExecutorFactory._instance = None
45+
3646
if ExecutorFactory._instance is None:
3747
max_workers = ExecutorFactory.max_workers()
3848
ExecutorFactory._instance = ThreadPoolExecutor(max_workers=max_workers)

tests/utils/test_concurrent.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import multiprocessing
1819
import os
19-
from concurrent.futures import ThreadPoolExecutor
20+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
2021
from typing import Dict, Optional
2122
from unittest import mock
2223

@@ -28,6 +29,39 @@
2829
VALID_ENV = {"PYICEBERG_MAX_WORKERS": "5"}
2930
INVALID_ENV = {"PYICEBERG_MAX_WORKERS": "invalid"}
3031

32+
@pytest.fixture
def fork_process():
    """Switch multiprocessing to the ``fork`` start method for one test.

    The start method in effect beforehand is recorded and put back once the
    test using this fixture has finished.
    """
    previous_method = multiprocessing.get_start_method()

    # The fixture is only meaningful on platforms that offer "fork".
    assert "fork" in multiprocessing.get_all_start_methods()

    multiprocessing.set_start_method("fork", force=True)

    yield

    # Teardown: restore whatever start method was active before the test.
    multiprocessing.set_start_method(previous_method, force=True)
44+
45+
46+
@pytest.fixture
def spawn_process():
    """Switch multiprocessing to the ``spawn`` start method for one test.

    The start method in effect beforehand is recorded and put back once the
    test using this fixture has finished.
    """
    previous_method = multiprocessing.get_start_method()

    # The fixture is only meaningful on platforms that offer "spawn".
    assert "spawn" in multiprocessing.get_all_start_methods()

    multiprocessing.set_start_method("spawn", force=True)

    yield

    # Teardown: restore whatever start method was active before the test.
    multiprocessing.set_start_method(previous_method, force=True)
58+
59+
60+
def _use_executor_to_return(value):
    """Return ``value`` after round-tripping it through the shared executor.

    A trivial task that just yields ``value`` is submitted to the executor
    obtained from ``ExecutorFactory.get_or_create()``, and the call blocks
    until that task's result is available.
    """
    return ExecutorFactory.get_or_create().submit(lambda: value).result()
64+
3165

3266
def test_create_reused() -> None:
3367
first = ExecutorFactory.get_or_create()
@@ -50,3 +84,40 @@ def test_max_workers() -> None:
5084
def test_max_workers_invalid() -> None:
5185
with pytest.raises(ValueError):
5286
ExecutorFactory.max_workers()
87+
88+
89+
@pytest.mark.parametrize(
    "fixture",
    [
        pytest.param(
            "fork_process",
            marks=pytest.mark.skipif(
                "fork" not in multiprocessing.get_all_start_methods(), reason="Fork start method is not available"
            ),
        ),
        pytest.param(
            "spawn_process",
            marks=pytest.mark.skipif(
                "spawn" not in multiprocessing.get_all_start_methods(), reason="Spawn start method is not available"
            ),
        ),
    ],
)
def test_use_executor_in_different_process(fixture, request):
    """Check the shared executor works in the parent and in child processes.

    ``fixture`` names the start-method fixture to activate; it is resolved
    dynamically through ``request`` so one test body covers each start method.
    """
    # Activate the start-method fixture selected by the parametrization.
    request.getfixturevalue(fixture)

    # Use the executor in the parent process first, before any child exists.
    parent_result = _use_executor_to_return(10)

    # Two separate process pools, each running the helper in a child process.
    with ProcessPoolExecutor() as first_pool:
        first_child = first_pool.submit(_use_executor_to_return, 20)
    with ProcessPoolExecutor() as second_pool:
        second_child = second_pool.submit(_use_executor_to_return, 30)

    assert parent_result == 10
    assert first_child.result() == 20
    assert second_child.result() == 30
120+
121+
122+
if __name__ == "__main__":
    # Allow running this test module directly as a script. ``pytest.main``
    # returns the session's exit code; propagate it so the process exit
    # status reflects test failures instead of always being 0.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)