diff --git a/dev/Dockerfile b/dev/Dockerfile index 97f6ac642f..3b7e593a1a 100644 --- a/dev/Dockerfile +++ b/dev/Dockerfile @@ -52,7 +52,7 @@ RUN curl --retry 5 -s https://repository.apache.org/content/groups/snapshots/org # Download AWS bundle -RUN curl --retry 5 -s https://repository.apache.org/content/groups/snapshots/org/apache/iceberg/iceberg-aws-bundle/1.9.0-SNAPSHOT/iceberg-aws-bundle-1.9.0-20250408.002722-86.jar \ +RUN curl --retry 5 -s https://repository.apache.org/content/groups/snapshots/org/apache/iceberg/iceberg-aws-bundle/1.9.0-SNAPSHOT/iceberg-aws-bundle-1.9.0-20250409.002731-88.jar \ -Lo /opt/spark/jars/iceberg-aws-bundle-${ICEBERG_VERSION}.jar COPY spark-defaults.conf /opt/spark/conf diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest/__init__.py similarity index 100% rename from pyiceberg/catalog/rest.py rename to pyiceberg/catalog/rest/__init__.py diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py new file mode 100644 index 0000000000..041a8a4cd1 --- /dev/null +++ b/pyiceberg/catalog/rest/auth.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import base64
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from requests import PreparedRequest
+from requests.auth import AuthBase
+
+
+class AuthManager(ABC):
+    """Abstract provider of Authorization header values; subclasses implement `auth_header`."""
+
+    @abstractmethod
+    def auth_header(self) -> Optional[str]:
+        """Return the value for the Authorization header, or None when no header applies."""
+
+
+class NoopAuthManager(AuthManager):
+    """Auth manager that never supplies an Authorization header."""
+
+    def auth_header(self) -> Optional[str]:
+        return None
+
+
+class BasicAuthManager(AuthManager):
+    """Auth manager that supplies a `Basic` header built from a username and password."""
+
+    def __init__(self, username: str, password: str):
+        """Base64-encode `username:password` once and keep the result for later use."""
+        raw = f"{username}:{password}".encode()
+        self._token = base64.b64encode(raw).decode()
+
+    def auth_header(self) -> str:
+        """Return `Basic ` followed by the encoded credentials."""
+        return "Basic " + self._token
+
+
+class AuthManagerAdapter(AuthBase):
+    """Expose an `AuthManager` through the `requests.auth.AuthBase` callable interface.
+
+    Assign an instance to `requests.Session.auth` so the manager's header is applied to
+    requests sent through that session. Intended only for use against a REST Catalog
+    server that expects the Authorization header.
+    """
+
+    def __init__(self, auth_manager: AuthManager):
+        """Store the manager that is asked for the header value.
+
+        Args:
+            auth_manager (AuthManager): An instance of an AuthManager subclass.
+        """
+        self.auth_manager = auth_manager
+
+    def __call__(self, request: PreparedRequest) -> PreparedRequest:
+        """Set the Authorization header on `request` when the manager returns a truthy value.
+
+        Args:
+            request (requests.PreparedRequest): The HTTP request being prepared.
+
+        Returns:
+            requests.PreparedRequest: The same request object that was passed in.
+        """
+        header_value = self.auth_manager.auth_header()
+        if header_value:
+            request.headers["Authorization"] = header_value
+        return request
diff --git a/tests/catalog/test_rest_auth.py b/tests/catalog/test_rest_auth.py
new file mode 100644
index 0000000000..3d3d4a807d
--- /dev/null
+++ b/tests/catalog/test_rest_auth.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import base64
+
+import pytest
+import requests
+from requests_mock import Mocker
+
+from pyiceberg.catalog.rest.auth import AuthManagerAdapter, BasicAuthManager, NoopAuthManager
+
+TEST_URI = "https://iceberg-test-catalog/"
+
+
+@pytest.fixture
+def rest_mock(requests_mock: Mocker) -> Mocker:
+    """Register an empty-JSON 200 response for GET TEST_URI and return the mocker."""
+    requests_mock.get(TEST_URI, json={}, status_code=200)
+    return requests_mock
+
+
+def test_noop_auth_header(rest_mock: Mocker) -> None:
+    """A session using NoopAuthManager sends no Authorization header."""
+    http = requests.Session()
+    http.auth = AuthManagerAdapter(NoopAuthManager())
+
+    # Send one request through the session so the adapter is invoked.
+    http.get(TEST_URI)
+
+    # The mock should have recorded exactly that one request.
+    sent = rest_mock.request_history
+    assert len(sent) == 1
+    assert "Authorization" not in sent[0].headers
+
+
+def test_basic_auth_header(rest_mock: Mocker) -> None:
+    """A session using BasicAuthManager sends `Basic <base64(username:password)>`."""
+    username, password = "testuser", "testpassword"
+    encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
+    expected = f"Basic {encoded}"
+
+    http = requests.Session()
+    http.auth = AuthManagerAdapter(BasicAuthManager(username=username, password=password))
+
+    # Send one request through the session so the adapter is invoked.
+    http.get(TEST_URI)
+
+    # The mock should have recorded exactly that one request.
+    sent = rest_mock.request_history
+    assert len(sent) == 1
+    assert sent[0].headers["Authorization"] == expected