diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py index 55368e1948..378370311b 100644 --- a/pyiceberg/io/fsspec.py +++ b/pyiceberg/io/fsspec.py @@ -16,18 +16,20 @@ # under the License. """FileIO implementation for reading and writing table files that uses fsspec compatible filesystems.""" +import abc import errno import json import logging import os import threading from copy import copy -from functools import lru_cache, partial +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, Callable, Dict, + Type, Union, ) from urllib.parse import urlparse @@ -95,38 +97,58 @@ from botocore.awsrequest import AWSRequest -def s3v4_rest_signer(properties: Properties, request: "AWSRequest", **_: Any) -> "AWSRequest": - signer_url = properties.get(S3_SIGNER_URI, properties[URI]).rstrip("/") # type: ignore - signer_endpoint = properties.get(S3_SIGNER_ENDPOINT, S3_SIGNER_ENDPOINT_DEFAULT) +class S3RequestSigner(abc.ABC): + """Abstract base class for S3 request signers.""" - signer_headers = {} - if token := properties.get(TOKEN): - signer_headers = {"Authorization": f"Bearer {token}"} - signer_headers.update(get_header_properties(properties)) + properties: Properties - signer_body = { - "method": request.method, - "region": request.context["client_region"], - "uri": request.url, - "headers": {key: [val] for key, val in request.headers.items()}, - } + def __init__(self, properties: Properties) -> None: + self.properties = properties + + @abc.abstractmethod + def __call__(self, request: "AWSRequest", **_: Any) -> None: + pass + + +class S3V4RestSigner(S3RequestSigner): + """An S3 request signer that uses an external REST signing service to sign requests.""" + + _session: requests.Session - response = requests.post(f"{signer_url}/{signer_endpoint.strip()}", headers=signer_headers, json=signer_body) - try: - response.raise_for_status() - response_json = response.json() - except HTTPError as e: - raise SignError(f"Failed to sign request {response.status_code}: {signer_body}") from e + def __init__(self, properties: Properties) -> None: + super().__init__(properties) + self._session = requests.Session() - for key, value in response_json["headers"].items(): - request.headers.add_header(key, ", ".join(value)) + def __call__(self, request: "AWSRequest", **_: Any) -> None: + signer_url = self.properties.get(S3_SIGNER_URI, self.properties[URI]).rstrip("/") # type: ignore + signer_endpoint = self.properties.get(S3_SIGNER_ENDPOINT, S3_SIGNER_ENDPOINT_DEFAULT) + + signer_headers = {} + if token := self.properties.get(TOKEN): + signer_headers = {"Authorization": f"Bearer {token}"} + signer_headers.update(get_header_properties(self.properties)) + + signer_body = { + "method": request.method, + "region": request.context["client_region"], + "uri": request.url, + "headers": {key: [val] for key, val in request.headers.items()}, + } + + response = self._session.post(f"{signer_url}/{signer_endpoint.strip()}", headers=signer_headers, json=signer_body) + try: + response.raise_for_status() + response_json = response.json() + except HTTPError as e: + raise SignError(f"Failed to sign request {response.status_code}: {signer_body}") from e - request.url = response_json["uri"] + for key, value in response_json["headers"].items(): + request.headers.add_header(key, ", ".join(value)) - return request + request.url = response_json["uri"] -SIGNERS: Dict[str, Callable[[Properties, "AWSRequest"], "AWSRequest"]] = {"S3V4RestSigner": s3v4_rest_signer} +SIGNERS: Dict[str, Type[S3RequestSigner]] = {"S3V4RestSigner": S3V4RestSigner} def _file(_: Properties) -> LocalFileSystem: @@ -144,13 +166,13 @@ def _s3(properties: Properties) -> AbstractFileSystem: "region_name": get_first_property_value(properties, S3_REGION, AWS_REGION), } config_kwargs = {} - register_events: Dict[str, Callable[[Properties], None]] = {} + register_events: Dict[str, Callable[[AWSRequest], None]] = {} if signer := properties.get(S3_SIGNER): logger.info("Loading signer %s", signer) - if signer_func := SIGNERS.get(signer): - signer_func_with_properties = partial(signer_func, properties) - register_events["before-sign.s3"] = signer_func_with_properties + if signer_cls := SIGNERS.get(signer): + signer = signer_cls(properties) + register_events["before-sign.s3"] = signer # Disable the AWS Signer from botocore import UNSIGNED diff --git a/tests/io/test_fsspec.py b/tests/io/test_fsspec.py index 6924d6b1c3..d7d0c3c1e0 100644 --- a/tests/io/test_fsspec.py +++ b/tests/io/test_fsspec.py @@ -31,7 +31,7 @@ from pyiceberg.exceptions import SignError from pyiceberg.io import fsspec -from pyiceberg.io.fsspec import FsspecFileIO, s3v4_rest_signer +from pyiceberg.io.fsspec import FsspecFileIO, S3V4RestSigner from pyiceberg.io.pyarrow import PyArrowFileIO from pyiceberg.typedef import Properties from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES @@ -814,10 +814,11 @@ def test_s3v4_rest_signer(requests_mock: Mocker) -> None: "retries": {"attempt": 1, "invocation-id": "75d143fb-0219-439b-872c-18213d1c8d54"}, } - signed_request = s3v4_rest_signer({"token": "abc", "uri": TEST_URI, "header.X-Custom-Header": "value"}, request) + signer = S3V4RestSigner(properties={"token": "abc", "uri": TEST_URI, "header.X-Custom-Header": "value"}) + signer(request) - assert signed_request.url == new_uri - assert dict(signed_request.headers) == { + assert request.url == new_uri + assert dict(request.headers) == { "Authorization": "AWS4-HMAC-SHA256 Credential=ASIAQPRZZYGHUT57DL3I/20221017/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, Signature=430582a17d61ab02c272896fa59195f277af4bdf2121c441685e589f044bbe02", "Host": "bucket.s3.us-west-2.amazonaws.com", "User-Agent": "Botocore/1.27.59 Python/3.10.7 Darwin/21.5.0", @@ -868,10 +869,11 @@ def test_s3v4_rest_signer_endpoint(requests_mock: Mocker) -> None: "retries": {"attempt": 1, "invocation-id": "75d143fb-0219-439b-872c-18213d1c8d54"}, } - signed_request = s3v4_rest_signer({"token": "abc", "uri": TEST_URI, "s3.signer.endpoint": endpoint}, request) + signer = S3V4RestSigner(properties={"token": "abc", "uri": TEST_URI, "s3.signer.endpoint": endpoint}) + signer(request) - assert signed_request.url == new_uri - assert dict(signed_request.headers) == { + assert request.url == new_uri + assert dict(request.headers) == { "Authorization": "AWS4-HMAC-SHA256 Credential=ASIAQPRZZYGHUT57DL3I/20221017/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, Signature=430582a17d61ab02c272896fa59195f277af4bdf2121c441685e589f044bbe02", "Host": "bucket.s3.us-west-2.amazonaws.com", "User-Agent": "Botocore/1.27.59 Python/3.10.7 Darwin/21.5.0", @@ -909,8 +911,9 @@ def test_s3v4_rest_signer_forbidden(requests_mock: Mocker) -> None: "retries": {"attempt": 1, "invocation-id": "75d143fb-0219-439b-872c-18213d1c8d54"}, } + signer = S3V4RestSigner(properties={"token": "abc", "uri": TEST_URI}) with pytest.raises(SignError) as exc_info: - _ = s3v4_rest_signer({"token": "abc", "uri": TEST_URI}, request) + signer(request) assert ( """Failed to sign request 401: {'method': 'HEAD', 'region': 'us-west-2', 'uri': 'https://bucket/metadata/snap-8048355899640248710-1-a5c8ea2d-aa1f-48e8-89f4-1fa69db8c742.avro', 'headers': {'User-Agent': ['Botocore/1.27.59 Python/3.10.7 Darwin/21.5.0']}}"""