Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyiceberg/utils/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Concurrency concepts that support efficient multi-threading."""

import os
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Optional

Expand All @@ -24,6 +25,7 @@

class ExecutorFactory:
_instance: Optional[Executor] = None
_instance_pid: Optional[int] = None

@staticmethod
def max_workers() -> Optional[int]:
Expand All @@ -33,6 +35,13 @@ def max_workers() -> Optional[int]:
@staticmethod
def get_or_create() -> Executor:
"""Return the same executor in each call."""
# ThreadPoolExecutor cannot be shared across processes. If a new pid is found it means
# there is a new process so a new executor is needed. Otherwise, the executor may be in
# an invalid state and tasks submitted will not be started.
if ExecutorFactory._instance_pid != os.getpid():
ExecutorFactory._instance_pid = os.getpid()
ExecutorFactory._instance = None

if ExecutorFactory._instance is None:
max_workers = ExecutorFactory.max_workers()
ExecutorFactory._instance = ThreadPoolExecutor(max_workers=max_workers)
Expand Down
75 changes: 73 additions & 2 deletions tests/utils/test_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.

import multiprocessing
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Dict, Generator, Optional
from unittest import mock

import pytest
Expand All @@ -29,6 +30,41 @@
INVALID_ENV = {"PYICEBERG_MAX_WORKERS": "invalid"}


@pytest.fixture
def fork_process() -> Generator[None, None, None]:
    """Force the 'fork' multiprocessing start method for one test, then restore it."""
    previous = multiprocessing.get_start_method()

    # Fail loudly if this platform cannot fork (e.g. Windows) rather than
    # silently running the test under the wrong start method.
    assert "fork" in multiprocessing.get_all_start_methods()

    multiprocessing.set_start_method("fork", force=True)

    yield

    # Teardown: put the process-global start method back for later tests.
    multiprocessing.set_start_method(previous, force=True)


@pytest.fixture
def spawn_process() -> Generator[None, None, None]:
    """Force the 'spawn' multiprocessing start method for one test, then restore it."""
    previous = multiprocessing.get_start_method()

    # Fail loudly if 'spawn' is unavailable rather than silently running the
    # test under the wrong start method.
    assert "spawn" in multiprocessing.get_all_start_methods()

    multiprocessing.set_start_method("spawn", force=True)

    yield

    # Teardown: put the process-global start method back for later tests.
    multiprocessing.set_start_method(previous, force=True)


def _use_executor_to_return(value: int) -> int:
    """Round-trip *value* through the shared singleton executor and return it.

    Defined at module level (not inline) so it can be pickled for use with
    ProcessPoolExecutor.
    """
    return ExecutorFactory.get_or_create().submit(lambda: value).result()


def test_create_reused() -> None:
first = ExecutorFactory.get_or_create()
second = ExecutorFactory.get_or_create()
Expand All @@ -50,3 +86,38 @@ def test_max_workers() -> None:
def test_max_workers_invalid() -> None:
    # A non-integer PYICEBERG_MAX_WORKERS should surface as ValueError from int().
    # NOTE(review): presumably runs under mock.patch.dict with INVALID_ENV
    # ({"PYICEBERG_MAX_WORKERS": "invalid"}); the decorator appears to be
    # collapsed out of this diff view — confirm against the full file.
    with pytest.raises(ValueError):
        ExecutorFactory.max_workers()


@pytest.mark.parametrize(
    "fixture_name",
    [
        pytest.param(
            "fork_process",
            marks=pytest.mark.skipif(
                "fork" not in multiprocessing.get_all_start_methods(), reason="Fork start method is not available"
            ),
        ),
        pytest.param(
            "spawn_process",
            marks=pytest.mark.skipif(
                "spawn" not in multiprocessing.get_all_start_methods(), reason="Spawn start method is not available"
            ),
        ),
    ],
)
def test_use_executor_in_different_process(fixture_name: str, request: pytest.FixtureRequest) -> None:
    """The singleton executor must be rebuilt in each child process, not inherited."""
    # Activate whichever start-method fixture (fork/spawn) this param selects.
    request.getfixturevalue(fixture_name)

    # Initialize the singleton in the parent process before any children exist.
    main_value = _use_executor_to_return(10)

    # Two distinct ProcessPoolExecutors guarantee work lands in different
    # child processes; each child must detect the new pid and rebuild.
    with ProcessPoolExecutor() as first_pool:
        first_future = first_pool.submit(_use_executor_to_return, 20)
    with ProcessPoolExecutor() as second_pool:
        second_future = second_pool.submit(_use_executor_to_return, 30)

    assert main_value == 10
    assert first_future.result() == 20
    assert second_future.result() == 30