From 7c416e2d016bc438a77a6f406529cb0198b48388 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+sungwy@users.noreply.github.com> Date: Fri, 2 May 2025 03:52:52 +0000 Subject: [PATCH 1/4] legacyoauth2 --- pyiceberg/catalog/rest/__init__.py | 174 ++++++----------------------- pyiceberg/catalog/rest/auth.py | 98 +++++++++++++++- pyiceberg/catalog/rest/util.py | 119 ++++++++++++++++++++ 3 files changed, 250 insertions(+), 141 deletions(-) create mode 100644 pyiceberg/catalog/rest/util.py diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 8ee9e5fdc9..85380c7d34 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -15,21 +15,18 @@ # specific language governing permissions and limitations # under the License. from enum import Enum -from json import JSONDecodeError from typing import ( TYPE_CHECKING, Any, Dict, List, - Literal, Optional, Set, Tuple, - Type, Union, ) -from pydantic import Field, ValidationError, field_validator +from pydantic import Field, field_validator from requests import HTTPError, Session from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt @@ -41,22 +38,18 @@ Catalog, PropertiesUpdateSummary, ) +from pyiceberg.catalog.rest.auth import AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager +from pyiceberg.catalog.rest.util import _handle_non_200_response from pyiceberg.exceptions import ( AuthorizationExpiredError, - BadRequestError, CommitFailedException, CommitStateUnknownException, - ForbiddenError, NamespaceAlreadyExistsError, NamespaceNotEmptyError, NoSuchIdentifierError, NoSuchNamespaceError, NoSuchTableError, NoSuchViewError, - OAuthError, - RESTError, - ServerError, - ServiceUnavailableError, TableAlreadyExistsError, UnauthorizedError, ) @@ -181,15 +174,6 @@ class RegisterTableRequest(IcebergBaseModel): metadata_location: str = Field(..., alias="metadata-location") -class TokenResponse(IcebergBaseModel): - access_token: str = Field() - token_type: str = Field() - expires_in: Optional[int] = Field(default=None) - issued_token_type: Optional[str] = Field(default=None) - refresh_token: Optional[str] = Field(default=None) - scope: Optional[str] = Field(default=None) - - class ConfigResponse(IcebergBaseModel): defaults: Properties = Field() overrides: Properties = Field() @@ -228,24 +212,6 @@ class ListViewsResponse(IcebergBaseModel): identifiers: List[ListViewResponseEntry] = Field() -class ErrorResponseMessage(IcebergBaseModel): - message: str = Field() - type: str = Field() - code: int = Field() - - -class ErrorResponse(IcebergBaseModel): - error: ErrorResponseMessage = Field() - - -class OAuthErrorResponse(IcebergBaseModel): - error: Literal[ - "invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope" - ] - error_description: Optional[str] = None - error_uri: Optional[str] = None - - class RestCatalog(Catalog): uri: str _session: Session @@ -278,8 +244,15 @@ def _create_session(self) -> Session: elif ssl_client_cert := ssl_client.get(CERT): session.cert = ssl_client_cert - self._refresh_token(session, self.properties.get(TOKEN)) + auth_config = { + "session": self._session, + "credential": self.properties.get(CREDENTIAL), + "initial_token": self.properties.get(TOKEN), + "optional_oauth_params": self._extract_optional_oauth_params(), + } + auth_manager = AuthManagerFactory.create("legacyoauth2", auth_config) + session.auth = AuthManagerAdapter(auth_manager) # Set HTTP headers self._config_headers(session) @@ -351,27 +324,6 @@ def _extract_optional_oauth_params(self) -> Dict[str, str]: return optional_oauth_param - def _fetch_access_token(self, session: Session, credential: str) -> str: - if SEMICOLON in credential: - client_id, client_secret = credential.split(SEMICOLON) - else: - client_id, client_secret = None, credential - - data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret} - - optional_oauth_params = self._extract_optional_oauth_params() - data.update(optional_oauth_params) - - response = session.post( - url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"} - ) - try: - response.raise_for_status() - except HTTPError as exc: - self._handle_non_200_response(exc, {400: OAuthError, 401: OAuthError}) - - return TokenResponse.model_validate_json(response.text).access_token - def _fetch_config(self) -> None: params = {} if warehouse_location := self.properties.get(WAREHOUSE_LOCATION): @@ -382,7 +334,7 @@ def _fetch_config(self) -> None: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) config_response = ConfigResponse.model_validate_json(response.text) config = config_response.defaults @@ -412,58 +364,6 @@ def _split_identifier_for_json(self, identifier: Union[str, Identifier]) -> Dict identifier_tuple = self._identifier_to_validated_tuple(identifier) return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]} - def _handle_non_200_response(self, exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None: - exception: Type[Exception] - - if exc.response is None: - raise ValueError("Did not receive a response") - - code = exc.response.status_code - if code in error_handler: - exception = error_handler[code] - elif code == 400: - exception = BadRequestError - elif code == 401: - exception = UnauthorizedError - elif code == 403: - exception = ForbiddenError - elif code == 422: - exception = RESTError - elif code == 419: - exception = AuthorizationExpiredError - elif code == 501: - exception = NotImplementedError - elif code == 503: - exception = ServiceUnavailableError - elif 500 <= code < 600: - exception = ServerError - else: - exception = RESTError - - try: - if exception == OAuthError: - # The OAuthErrorResponse has a different format - error = OAuthErrorResponse.model_validate_json(exc.response.text) - response = str(error.error) - if description := error.error_description: - response += f": {description}" - if uri := error.error_uri: - response += f" ({uri})" - else: - error = ErrorResponse.model_validate_json(exc.response.text).error - response = f"{error.type}: {error.message}" - except JSONDecodeError: - # In the case we don't have a proper response - response = f"RESTError {exc.response.status_code}: Could not decode json payload: {exc.response.text}" - except ValidationError as e: - # In the case we don't have a proper response - errs = ", ".join(err["msg"] for err in e.errors()) - response = ( - f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}" - ) - - raise exception(response) from exc - def _init_sigv4(self, session: Session) -> None: from urllib import parse @@ -533,16 +433,12 @@ def _response_to_staged_table(self, identifier_tuple: Tuple[str, ...], table_res catalog=self, ) - def _refresh_token(self, session: Optional[Session] = None, initial_token: Optional[str] = None) -> None: - session = session or self._session - if initial_token is not None: - self.properties[TOKEN] = initial_token - elif CREDENTIAL in self.properties: - self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL]) - - # Set Auth token for subsequent calls in the session - if token := self.properties.get(TOKEN): - session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" + def _refresh_token(self) -> None: + # Reactive token refresh is atypical - we should proactively refresh tokens in a separate thread + # instead of retrying on Auth Exceptions. Keeping refresh behavior for the LegacyOAuth2AuthManager + # for backward compatibility + if isinstance(self._session.auth.auth_manager, LegacyOAuth2AuthManager): + self._session.auth.auth_manager._refresh_token() def _config_headers(self, session: Session) -> None: header_properties = get_header_properties(self.properties) @@ -587,7 +483,7 @@ def _create_table( try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {409: TableAlreadyExistsError}) + _handle_non_200_response(exc, {409: TableAlreadyExistsError}) return TableResponse.model_validate_json(response.text) @retry(**_RETRY_ARGS) @@ -660,7 +556,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {409: TableAlreadyExistsError}) + _handle_non_200_response(exc, {409: TableAlreadyExistsError}) table_response = TableResponse.model_validate_json(response.text) return self._response_to_table(self.identifier_to_tuple(identifier), table_response) @@ -673,7 +569,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) return [(*table.namespace, table.name) for table in ListTablesResponse.model_validate_json(response.text).identifiers] @retry(**_RETRY_ARGS) @@ -682,7 +578,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchTableError}) + _handle_non_200_response(exc, {404: NoSuchTableError}) table_response = TableResponse.model_validate_json(response.text) return self._response_to_table(self.identifier_to_tuple(identifier), table_response) @@ -695,7 +591,7 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchTableError}) + _handle_non_200_response(exc, {404: NoSuchTableError}) @retry(**_RETRY_ARGS) def purge_table(self, identifier: Union[str, Identifier]) -> None: @@ -711,7 +607,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError}) + _handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError}) return self.load_table(to_identifier) @@ -734,7 +630,7 @@ def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) return [(*view.namespace, view.name) for view in ListViewsResponse.model_validate_json(response.text).identifiers] @retry(**_RETRY_ARGS) @@ -772,7 +668,7 @@ def commit_table( try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response( + _handle_non_200_response( exc, { 409: CommitFailedException, @@ -791,7 +687,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {409: NamespaceAlreadyExistsError}) + _handle_non_200_response(exc, {409: NamespaceAlreadyExistsError}) @retry(**_RETRY_ARGS) def drop_namespace(self, namespace: Union[str, Identifier]) -> None: @@ -801,7 +697,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError}) @retry(**_RETRY_ARGS) def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: @@ -816,7 +712,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) return ListNamespaceResponse.model_validate_json(response.text).namespaces @@ -828,7 +724,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) return NamespaceResponse.model_validate_json(response.text).properties @@ -843,7 +739,7 @@ def update_namespace_properties( try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) parsed_response = UpdateNamespacePropertiesResponse.model_validate_json(response.text) return PropertiesUpdateSummary( removed=parsed_response.removed, @@ -865,7 +761,7 @@ def namespace_exists(self, namespace: Union[str, Identifier]) -> bool: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) return False @@ -891,7 +787,7 @@ def table_exists(self, identifier: Union[str, Identifier]) -> bool: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) return False @@ -916,7 +812,7 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) return False @@ -928,4 +824,4 @@ def drop_view(self, identifier: Union[str]) -> None: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchViewError}) + _handle_non_200_response(exc, {404: NoSuchViewError}) diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py index 041a8a4cd1..3833ddc571 100644 --- a/pyiceberg/catalog/rest/auth.py +++ b/pyiceberg/catalog/rest/auth.py @@ -16,12 +16,18 @@ # under the License. import base64 +import importlib from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Dict, Optional, Type -from requests import PreparedRequest +from requests import HTTPError, PreparedRequest, Session from requests.auth import AuthBase +from pyiceberg.catalog.rest.util import TokenResponse, _handle_non_200_response +from pyiceberg.exceptions import OAuthError + +COLON = ":" + class AuthManager(ABC): """ @@ -49,6 +55,54 @@ def auth_header(self) -> str: return f"Basic {self._token}" +class LegacyOAuth2AuthManager(AuthManager): + def __init__( + self, + session: Session, + auth_url: str, + credential: Optional[str] = None, + initial_token: Optional[str] = None, + optional_oauth_params: Optional[Dict[str, str]] = None, + ): + self._token: Optional[str] = None + self._initial_token = initial_token + self._credential = credential + self._auth_url = auth_url + self._optional_oauth_params = optional_oauth_params + self._session = session + self._refresh_token() + + def _fetch_access_token(self, credential: str) -> str: + if COLON in credential: + client_id, client_secret = credential.split(COLON) + else: + client_id, client_secret = None, credential + + data = {"grant_type": "client_credentials", "client_id": client_id, "client_secret": client_secret} + + if self._optional_oauth_params: + data.update(self._optional_oauth_params) + + response = self._session.post( + url=self._auth_url, data=data, headers={**self._session.headers, "Content-type": "application/x-www-form-urlencoded"} + ) + try: + response.raise_for_status() + except HTTPError as exc: + _handle_non_200_response(exc, {400: OAuthError, 401: OAuthError}) + + return TokenResponse.model_validate_json(response.text).access_token + + def _refresh_token(self) -> None: + if self._initial_token is not None: + self._token = self._initial_token + elif self._credential: + self._token = self._fetch_access_token(self._credential) + + def auth_header(self) -> str: + return f"Bearer {self._token}" + + class AuthManagerAdapter(AuthBase): """A `requests.auth.AuthBase` adapter that integrates an `AuthManager` into a `requests.Session` to automatically attach the appropriate Authorization header to every request. @@ -80,3 +134,43 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: if auth_header := self.auth_manager.auth_header(): request.headers["Authorization"] = auth_header return request + + +class AuthManagerFactory: + _registry: Dict[str, Type["AuthManager"]] = {} + + @classmethod + def register(cls, name: str, manager_cls: Type["AuthManager"]) -> None: + """ + Register a string name to a known AuthManager class. + """ + cls._registry[name] = manager_cls + + @classmethod + def create(cls, class_or_name: str, config: Dict[str, Any]) -> "AuthManager": + """ + Create an AuthManager by name or fully-qualified class path. + + Args: + class_or_name (str): Either a name like 'oauth2' or a full class path like 'my.module.CustomAuthManager' + config (Dict[str, Any]): Configuration passed to the AuthManager constructor + + Returns: + AuthManager: An instantiated AuthManager subclass + """ + if class_or_name in cls._registry: + manager_cls = cls._registry[class_or_name] + else: + try: + module_path, class_name = class_or_name.rsplit(".", 1) + module = importlib.import_module(module_path) + manager_cls = getattr(module, class_name) + except Exception as err: + raise ValueError(f"Could not load AuthManager class for '{class_or_name}'") from err + + return manager_cls(**config) + + +AuthManagerFactory.register("noop", NoopAuthManager) +AuthManagerFactory.register("basic", BasicAuthManager) +AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager) diff --git a/pyiceberg/catalog/rest/util.py b/pyiceberg/catalog/rest/util.py new file mode 100644 index 0000000000..6862741383 --- /dev/null +++ b/pyiceberg/catalog/rest/util.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from requests import HTTPError +from json import JSONDecodeError +from pyiceberg.typedef import IcebergBaseModel +from pydantic import Field +from typing import Dict, Type, Optional, Literal +from pyiceberg.exceptions import ( + AuthorizationExpiredError, + BadRequestError, + CommitFailedException, + CommitStateUnknownException, + ForbiddenError, + NamespaceAlreadyExistsError, + NamespaceNotEmptyError, + NoSuchIdentifierError, + NoSuchNamespaceError, + NoSuchTableError, + NoSuchViewError, + OAuthError, + RESTError, + ServerError, + ServiceUnavailableError, + TableAlreadyExistsError, + UnauthorizedError, + ValidationError +) + +class TokenResponse(IcebergBaseModel): + access_token: str = Field() + token_type: str = Field() + expires_in: Optional[int] = Field(default=None) + issued_token_type: Optional[str] = Field(default=None) + refresh_token: Optional[str] = Field(default=None) + scope: Optional[str] = Field(default=None) + +class ErrorResponseMessage(IcebergBaseModel): + message: str = Field() + type: str = Field() + code: int = Field() + + +class ErrorResponse(IcebergBaseModel): + error: ErrorResponseMessage = Field() + + +class OAuthErrorResponse(IcebergBaseModel): + error: Literal[ + "invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope" + ] + error_description: Optional[str] = None + error_uri: Optional[str] = None + + +def _handle_non_200_response(exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None: + exception: Type[Exception] + + if exc.response is None: + raise ValueError("Did not receive a response") + + code = exc.response.status_code + if code in error_handler: + exception = error_handler[code] + elif code == 400: + exception = BadRequestError + elif code == 401: + exception = UnauthorizedError + elif code == 403: + exception = ForbiddenError + elif code == 422: + exception = RESTError + elif code == 419: + exception = AuthorizationExpiredError + elif code == 501: + exception = NotImplementedError + elif code == 503: + exception = ServiceUnavailableError + elif 500 <= code < 600: + exception = ServerError + else: + exception = RESTError + + try: + if exception == OAuthError: + # The OAuthErrorResponse has a different format + error = OAuthErrorResponse.model_validate_json(exc.response.text) + response = str(error.error) + if description := error.error_description: + response += f": {description}" + if uri := error.error_uri: + response += f" ({uri})" + else: + error = ErrorResponse.model_validate_json(exc.response.text).error + response = f"{error.type}: {error.message}" + except JSONDecodeError: + # In the case we don't have a proper response + response = f"RESTError {exc.response.status_code}: Could not decode json payload: {exc.response.text}" + except ValidationError as e: + # In the case we don't have a proper response + errs = ", ".join(err["msg"] for err in e.errors()) + response = ( + f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}" + ) + + raise exception(response) from exc From f9d5678831bee62a7f9dd20cddf9365501cd965b Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+sungwy@users.noreply.github.com> Date: Sat, 3 May 2025 22:07:45 +0000 Subject: [PATCH 2/4] substitute existing oauth with LegacyOAuth2AuthManager --- pyiceberg/catalog/rest/__init__.py | 37 ++++++++++++++++++++---------- pyiceberg/catalog/rest/auth.py | 35 +++++++++++++++++++--------- pyiceberg/catalog/rest/util.py | 26 ++++++++------------- 3 files changed, 58 insertions(+), 40 deletions(-) diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 85380c7d34..6a04097a0d 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -38,7 +38,7 @@ Catalog, PropertiesUpdateSummary, ) -from pyiceberg.catalog.rest.auth import AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager +from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager from pyiceberg.catalog.rest.util import _handle_non_200_response from pyiceberg.exceptions import ( AuthorizationExpiredError, @@ -244,15 +244,7 @@ def _create_session(self) -> Session: elif ssl_client_cert := ssl_client.get(CERT): session.cert = ssl_client_cert - auth_config = { - "session": self._session, - "credential": self.properties.get(CREDENTIAL), - "initial_token": self.properties.get(TOKEN), - "optional_oauth_params": self._extract_optional_oauth_params(), - } - - auth_manager = AuthManagerFactory.create("legacyoauth2", auth_config) - session.auth = AuthManagerAdapter(auth_manager) + session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session)) # Set HTTP headers self._config_headers(session) @@ -262,6 +254,26 @@ def _create_session(self) -> Session: return session + def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager: + """Create the LegacyOAuth2AuthManager by fetching required properties. + + This will be deprecated in PyIceberg 1.0 + """ + client_credentials = self.properties.get(CREDENTIAL) + # We want to call `self.auth_url` only when we are using CREDENTIAL + # with the legacy OAUTH2 flow as it will raise a DeprecationWarning + auth_url = self.auth_url if client_credentials is not None else None + + auth_config = { + "session": session, + "auth_url": auth_url, + "credential": client_credentials, + "initial_token": self.properties.get(TOKEN), + "optional_oauth_params": self._extract_optional_oauth_params(), + } + + return AuthManagerFactory.create("legacyoauth2", auth_config) + def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier: """Check if the identifier has at least one element.""" identifier_tuple = Catalog.identifier_to_tuple(identifier) @@ -437,8 +449,9 @@ def _refresh_token(self) -> None: # Reactive token refresh is atypical - we should proactively refresh tokens in a separate thread # instead of retrying on Auth Exceptions. Keeping refresh behavior for the LegacyOAuth2AuthManager # for backward compatibility - if isinstance(self._session.auth.auth_manager, LegacyOAuth2AuthManager): - self._session.auth.auth_manager._refresh_token() + auth_manager = self._session.auth.auth_manager # type: ignore[union-attr] + if isinstance(auth_manager, LegacyOAuth2AuthManager): + auth_manager._refresh_token() def _config_headers(self, session: Session) -> None: header_properties = get_header_properties(self.properties) diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py index 3833ddc571..d45852f3b4 100644 --- a/pyiceberg/catalog/rest/auth.py +++ b/pyiceberg/catalog/rest/auth.py @@ -56,20 +56,25 @@ def auth_header(self) -> str: class LegacyOAuth2AuthManager(AuthManager): + _session: Session + _auth_url: Optional[str] + _token: Optional[str] + _credential: Optional[str] + _optional_oauth_params: Optional[Dict[str, str]] + def __init__( self, session: Session, - auth_url: str, + auth_url: Optional[str] = None, credential: Optional[str] = None, initial_token: Optional[str] = None, optional_oauth_params: Optional[Dict[str, str]] = None, ): - self._token: Optional[str] = None - self._initial_token = initial_token - self._credential = credential + self._session = session self._auth_url = auth_url + self._token = initial_token + self._credential = credential self._optional_oauth_params = optional_oauth_params - self._session = session self._refresh_token() def _fetch_access_token(self, credential: str) -> str: @@ -83,6 +88,9 @@ def _fetch_access_token(self, credential: str) -> str: if self._optional_oauth_params: data.update(self._optional_oauth_params) + if self._auth_url is None: + raise ValueError("Cannot fetch access token from undefined auth_url") + response = self._session.post( url=self._auth_url, data=data, headers={**self._session.headers, "Content-type": "application/x-www-form-urlencoded"} ) @@ -94,9 +102,7 @@ def _fetch_access_token(self, credential: str) -> str: return TokenResponse.model_validate_json(response.text).access_token def _refresh_token(self) -> None: - if self._initial_token is not None: - self._token = self._initial_token - elif self._credential: + if self._credential is not None: self._token = self._fetch_access_token(self._credential) def auth_header(self) -> str: @@ -140,14 +146,21 @@ class AuthManagerFactory: _registry: Dict[str, Type["AuthManager"]] = {} @classmethod - def register(cls, name: str, manager_cls: Type["AuthManager"]) -> None: + def register(cls, name: str, auth_manager_class: Type["AuthManager"]) -> None: """ Register a string name to a known AuthManager class. + + Args: + name (str): unique name like 'oauth2' to register the AuthManager with + auth_manager_class (Type["AuthManager"]): Implementation of AuthManager + + Returns: + None """ - cls._registry[name] = manager_cls + cls._registry[name] = auth_manager_class @classmethod - def create(cls, class_or_name: str, config: Dict[str, Any]) -> "AuthManager": + def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager: """ Create an AuthManager by name or fully-qualified class path. diff --git a/pyiceberg/catalog/rest/util.py b/pyiceberg/catalog/rest/util.py index 6862741383..8f23af8c35 100644 --- a/pyiceberg/catalog/rest/util.py +++ b/pyiceberg/catalog/rest/util.py @@ -14,31 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from requests import HTTPError from json import JSONDecodeError -from pyiceberg.typedef import IcebergBaseModel -from pydantic import Field -from typing import Dict, Type, Optional, Literal +from typing import Dict, Literal, Optional, Type + +from pydantic import Field, ValidationError +from requests import HTTPError + from pyiceberg.exceptions import ( AuthorizationExpiredError, BadRequestError, - CommitFailedException, - CommitStateUnknownException, ForbiddenError, - NamespaceAlreadyExistsError, - NamespaceNotEmptyError, - NoSuchIdentifierError, - NoSuchNamespaceError, - NoSuchTableError, - NoSuchViewError, OAuthError, RESTError, ServerError, ServiceUnavailableError, - TableAlreadyExistsError, UnauthorizedError, - ValidationError ) +from pyiceberg.typedef import IcebergBaseModel + class TokenResponse(IcebergBaseModel): access_token: str = Field() @@ -48,6 +41,7 @@ class TokenResponse(IcebergBaseModel): refresh_token: Optional[str] = Field(default=None) scope: Optional[str] = Field(default=None) + class ErrorResponseMessage(IcebergBaseModel): message: str = Field() type: str = Field() @@ -112,8 +106,6 @@ def _handle_non_200_response(exc: HTTPError, error_handler: Dict[int, Type[Excep except ValidationError as e: # In the case we don't have a proper response errs = ", ".join(err["msg"] for err in e.errors()) - response = ( - f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}" - ) + response = f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}" raise exception(response) from exc From 613ca537c0d7f69004633f1030d7e36fd0f80b50 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+sungwy@users.noreply.github.com> Date: Thu, 8 May 2025 02:49:15 +0000 Subject: [PATCH 3/4] fix test --- tests/catalog/test_rest.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index f2fc6ceb6b..3b7f530754 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -602,6 +602,10 @@ def test_list_namespaces_token_expired_success_on_retries(rest_mock: Mocker, sta status_code=200, ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, credential=TEST_CREDENTIALS) + # LegacyOAuth2AuthManager is created twice through `_create_session()` + # which results in the token being refreshed twice when the RestCatalog is initialized. + assert tokens.call_count == 2 + assert catalog.list_namespaces() == [ ("default",), ("examples",), @@ -609,7 +613,7 @@ def test_list_namespaces_token_expired_success_on_retries(rest_mock: Mocker, sta ("system",), ] assert namespaces.call_count == 2 - assert tokens.call_count == 1 + assert tokens.call_count == 3 assert catalog.list_namespaces() == [ ("default",), @@ -618,7 +622,7 @@ def test_list_namespaces_token_expired_success_on_retries(rest_mock: Mocker, sta ("system",), ] assert namespaces.call_count == 3 - assert tokens.call_count == 1 + assert tokens.call_count == 3 def test_create_namespace_200(rest_mock: Mocker) -> None: From fc4a1922e94590f8e3997ef8a2507c03f35013fe Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+sungwy@users.noreply.github.com> Date: Thu, 15 May 2025 01:49:56 +0000 Subject: [PATCH 4/4] adopt feedback - thanks Fokko! --- pyiceberg/catalog/rest/__init__.py | 4 ++-- pyiceberg/catalog/rest/auth.py | 2 +- pyiceberg/catalog/rest/{util.py => response.py} | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename pyiceberg/catalog/rest/{util.py => response.py} (100%) diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 74651b475c..f2e1989613 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -39,7 +39,7 @@ PropertiesUpdateSummary, ) from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager -from pyiceberg.catalog.rest.util import _handle_non_200_response +from pyiceberg.catalog.rest.response import _handle_non_200_response from pyiceberg.exceptions import ( AuthorizationExpiredError, CommitFailedException, @@ -257,7 +257,7 @@ def _create_session(self) -> Session: def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager: """Create the LegacyOAuth2AuthManager by fetching required properties. - This will be deprecated in PyIceberg 1.0 + This will be removed in PyIceberg 1.0 """ client_credentials = self.properties.get(CREDENTIAL) # We want to call `self.auth_url` only when we are using CREDENTIAL diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py index d45852f3b4..89395f1158 100644 --- a/pyiceberg/catalog/rest/auth.py +++ b/pyiceberg/catalog/rest/auth.py @@ -23,7 +23,7 @@ from requests import HTTPError, PreparedRequest, Session from requests.auth import AuthBase -from pyiceberg.catalog.rest.util import TokenResponse, _handle_non_200_response +from pyiceberg.catalog.rest.response import TokenResponse, _handle_non_200_response from pyiceberg.exceptions import OAuthError COLON = ":" diff --git a/pyiceberg/catalog/rest/util.py b/pyiceberg/catalog/rest/response.py similarity index 100% rename from pyiceberg/catalog/rest/util.py rename to pyiceberg/catalog/rest/response.py