Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import logging
import os
import threading
from copy import copy
from functools import lru_cache, partial
from typing import (
Expand Down Expand Up @@ -370,7 +371,7 @@ class FsspecFileIO(FileIO):
def __init__(self, properties: Properties):
self._scheme_to_fs = {}
self._scheme_to_fs.update(SCHEME_TO_FS)
self.get_fs: Callable[[str], AbstractFileSystem] = lru_cache(self._get_fs)
self._thread_locals = threading.local()
super().__init__(properties=properties)

def new_input(self, location: str) -> FsspecInputFile:
Expand Down Expand Up @@ -416,6 +417,13 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
fs = self.get_fs(uri.scheme)
fs.rm(str_location)

def get_fs(self, scheme: str) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme, cached per thread."""
if not hasattr(self._thread_locals, "get_fs_cached"):
self._thread_locals.get_fs_cached = lru_cache(self._get_fs)

return self._thread_locals.get_fs_cached(scheme)

def _get_fs(self, scheme: str) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme."""
if scheme not in self._scheme_to_fs:
Expand All @@ -425,10 +433,10 @@ def _get_fs(self, scheme: str) -> AbstractFileSystem:
def __getstate__(self) -> Dict[str, Any]:
"""Create a dictionary of the FsSpecFileIO fields used when pickling."""
fileio_copy = copy(self.__dict__)
fileio_copy["get_fs"] = None
del fileio_copy["_thread_locals"]
return fileio_copy

def __setstate__(self, state: Dict[str, Any]) -> None:
"""Deserialize the state into a FsSpecFileIO instance."""
self.__dict__ = state
self.get_fs = lru_cache(self._get_fs)
self._thread_locals = threading.local()
39 changes: 39 additions & 0 deletions tests/io/test_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
import os
import pickle
import tempfile
import threading
import uuid
from typing import List
from unittest import mock

import pytest
from botocore.awsrequest import AWSRequest
from fsspec.implementations.local import LocalFileSystem
from fsspec.spec import AbstractFileSystem
from requests_mock import Mocker

from pyiceberg.exceptions import SignError
Expand Down Expand Up @@ -54,6 +57,42 @@ def test_fsspec_local_fs_can_create_path_without_parent_dir(fsspec_fileio: Fsspe
pytest.fail("Failed to write to file without parent directory")


def test_fsspec_get_fs_instance_per_thread_caching(fsspec_fileio: FsspecFileIO) -> None:
"""Test that filesystem instances are cached per-thread by `FsspecFileIO.get_fs`"""
fs_instances: List[AbstractFileSystem] = []
start_work_events: List[threading.Event] = [threading.Event() for _ in range(2)]

def get_fs(start_work_event: threading.Event) -> None:
# Wait to be told to actually start getting the filesystem instances
start_work_event.wait()

# Call twice to ensure caching within the same thread
for _ in range(2):
fs_instances.append(fsspec_fileio.get_fs("file"))

threads = [threading.Thread(target=get_fs, args=[start_work_event]) for start_work_event in start_work_events]

# Start both threads (which will immediately block on their `Event`s) as we want to ensure distinct
# `threading.get_ident()` values that are used in the `fsspec.spec.AbstractFileSystem`s cache keys..
for thread in threads:
thread.start()

# Get the filesystem instances in the first thread and wait for completion
start_work_events[0].set()
threads[0].join()

# Get the filesystem instances in the second thread and wait for completion
start_work_events[1].set()
threads[1].join()

# Same thread, same instance
assert fs_instances[0] is fs_instances[1]
assert fs_instances[2] is fs_instances[3]

# Different threads, different instances
assert fs_instances[0] is not fs_instances[2]


@pytest.mark.s3
def test_fsspec_new_input_file(fsspec_fileio: FsspecFileIO) -> None:
"""Test creating a new input file from a fsspec file-io"""
Expand Down