Skip to content
37 changes: 23 additions & 14 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Any,
Callable,
Dict,
Optional,
Union,
)
from urllib.parse import urlparse
Expand Down Expand Up @@ -194,13 +195,18 @@ def _gs(properties: Properties) -> AbstractFileSystem:
def _adls(properties: Properties) -> AbstractFileSystem:
from adlfs import AzureBlobFileSystem

for key, sas_token in {
key.replace(f"{ADLS_SAS_TOKEN}.", ""): value for key, value in properties.items() if key.startswith(ADLS_SAS_TOKEN)
}.items():
if ADLS_ACCOUNT_NAME not in properties:
properties[ADLS_ACCOUNT_NAME] = key.split(".")[0]
if ADLS_SAS_TOKEN not in properties:
properties[ADLS_SAS_TOKEN] = sas_token
# https://learn.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction-abfs-uri#uri-syntax
if netloc := properties.get("netloc"):
account_uri = netloc.split("@")[-1]
else:
account_uri = None

if not properties.get(ADLS_ACCOUNT_NAME) and account_uri:
properties[ADLS_ACCOUNT_NAME] = account_uri.split(".")[0]

# Fixes https://github.com/apache/iceberg-python/issues/1146
if not properties.get(ADLS_SAS_TOKEN) and account_uri:
properties[ADLS_SAS_TOKEN] = properties.get(f"{ADLS_SAS_TOKEN}.{account_uri}")

return AzureBlobFileSystem(
connection_string=properties.get(ADLS_CONNECTION_STRING),
Expand Down Expand Up @@ -340,7 +346,7 @@ class FsspecFileIO(FileIO):
def __init__(self, properties: Properties):
self._scheme_to_fs = {}
self._scheme_to_fs.update(SCHEME_TO_FS)
self.get_fs: Callable[[str], AbstractFileSystem] = lru_cache(self._get_fs)
self.get_fs: Callable[[str, Optional[str]], AbstractFileSystem] = lru_cache(self._get_fs)
super().__init__(properties=properties)

def new_input(self, location: str) -> FsspecInputFile:
Expand All @@ -353,7 +359,7 @@ def new_input(self, location: str) -> FsspecInputFile:
FsspecInputFile: An FsspecInputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
return FsspecInputFile(location=location, fs=fs)

def new_output(self, location: str) -> FsspecOutputFile:
Expand All @@ -366,7 +372,7 @@ def new_output(self, location: str) -> FsspecOutputFile:
FsspecOutputFile: An FsspecOutputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
return FsspecOutputFile(location=location, fs=fs)

def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
Expand All @@ -383,14 +389,17 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
str_location = location

uri = urlparse(str_location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
fs.rm(str_location)

def _get_fs(self, scheme: str) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme."""
def _get_fs(self, scheme: str, netloc: Optional[str] = None) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme and netloc."""
if scheme not in self._scheme_to_fs:
raise ValueError(f"No registered filesystem for scheme: {scheme}")
return self._scheme_to_fs[scheme](self.properties)
properties = self.properties.copy()
if netloc:
properties["netloc"] = netloc
return self._scheme_to_fs[scheme](properties)

def __getstate__(self) -> Dict[str, Any]:
"""Create a dictionary of the FsSpecFileIO fields used when pickling."""
Expand Down
Loading