Skip to content

Commit f9d5678

Browse files
committed
substitute existing oauth with LegacyOAuth2AuthManager
1 parent 7c416e2 commit f9d5678

File tree

3 files changed

+58
-40
lines changed

3 files changed

+58
-40
lines changed

pyiceberg/catalog/rest/__init__.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
Catalog,
3939
PropertiesUpdateSummary,
4040
)
41-
from pyiceberg.catalog.rest.auth import AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
41+
from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
4242
from pyiceberg.catalog.rest.util import _handle_non_200_response
4343
from pyiceberg.exceptions import (
4444
AuthorizationExpiredError,
@@ -244,15 +244,7 @@ def _create_session(self) -> Session:
244244
elif ssl_client_cert := ssl_client.get(CERT):
245245
session.cert = ssl_client_cert
246246

247-
auth_config = {
248-
"session": self._session,
249-
"credential": self.properties.get(CREDENTIAL),
250-
"initial_token": self.properties.get(TOKEN),
251-
"optional_oauth_params": self._extract_optional_oauth_params(),
252-
}
253-
254-
auth_manager = AuthManagerFactory.create("legacyoauth2", auth_config)
255-
session.auth = AuthManagerAdapter(auth_manager)
247+
session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))
256248
# Set HTTP headers
257249
self._config_headers(session)
258250

@@ -262,6 +254,26 @@ def _create_session(self) -> Session:
262254

263255
return session
264256

257+
def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager:
258+
"""Create the LegacyOAuth2AuthManager by fetching required properties.
259+
260+
This will be deprecated in PyIceberg 1.0
261+
"""
262+
client_credentials = self.properties.get(CREDENTIAL)
263+
# We want to call `self.auth_url` only when we are using CREDENTIAL
264+
# with the legacy OAUTH2 flow as it will raise a DeprecationWarning
265+
auth_url = self.auth_url if client_credentials is not None else None
266+
267+
auth_config = {
268+
"session": session,
269+
"auth_url": auth_url,
270+
"credential": client_credentials,
271+
"initial_token": self.properties.get(TOKEN),
272+
"optional_oauth_params": self._extract_optional_oauth_params(),
273+
}
274+
275+
return AuthManagerFactory.create("legacyoauth2", auth_config)
276+
265277
def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
266278
"""Check if the identifier has at least one element."""
267279
identifier_tuple = Catalog.identifier_to_tuple(identifier)
@@ -437,8 +449,9 @@ def _refresh_token(self) -> None:
437449
# Reactive token refresh is atypical - we should proactively refresh tokens in a separate thread
438450
# instead of retrying on Auth Exceptions. Keeping refresh behavior for the LegacyOAuth2AuthManager
439451
# for backward compatibility
440-
if isinstance(self._session.auth.auth_manager, LegacyOAuth2AuthManager):
441-
self._session.auth.auth_manager._refresh_token()
452+
auth_manager = self._session.auth.auth_manager # type: ignore[union-attr]
453+
if isinstance(auth_manager, LegacyOAuth2AuthManager):
454+
auth_manager._refresh_token()
442455

443456
def _config_headers(self, session: Session) -> None:
444457
header_properties = get_header_properties(self.properties)

pyiceberg/catalog/rest/auth.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,25 @@ def auth_header(self) -> str:
5656

5757

5858
class LegacyOAuth2AuthManager(AuthManager):
59+
_session: Session
60+
_auth_url: Optional[str]
61+
_token: Optional[str]
62+
_credential: Optional[str]
63+
_optional_oauth_params: Optional[Dict[str, str]]
64+
5965
def __init__(
6066
self,
6167
session: Session,
62-
auth_url: str,
68+
auth_url: Optional[str] = None,
6369
credential: Optional[str] = None,
6470
initial_token: Optional[str] = None,
6571
optional_oauth_params: Optional[Dict[str, str]] = None,
6672
):
67-
self._token: Optional[str] = None
68-
self._initial_token = initial_token
69-
self._credential = credential
73+
self._session = session
7074
self._auth_url = auth_url
75+
self._token = initial_token
76+
self._credential = credential
7177
self._optional_oauth_params = optional_oauth_params
72-
self._session = session
7378
self._refresh_token()
7479

7580
def _fetch_access_token(self, credential: str) -> str:
@@ -83,6 +88,9 @@ def _fetch_access_token(self, credential: str) -> str:
8388
if self._optional_oauth_params:
8489
data.update(self._optional_oauth_params)
8590

91+
if self._auth_url is None:
92+
raise ValueError("Cannot fetch access token from undefined auth_url")
93+
8694
response = self._session.post(
8795
url=self._auth_url, data=data, headers={**self._session.headers, "Content-type": "application/x-www-form-urlencoded"}
8896
)
@@ -94,9 +102,7 @@ def _fetch_access_token(self, credential: str) -> str:
94102
return TokenResponse.model_validate_json(response.text).access_token
95103

96104
def _refresh_token(self) -> None:
97-
if self._initial_token is not None:
98-
self._token = self._initial_token
99-
elif self._credential:
105+
if self._credential is not None:
100106
self._token = self._fetch_access_token(self._credential)
101107

102108
def auth_header(self) -> str:
@@ -140,14 +146,21 @@ class AuthManagerFactory:
140146
_registry: Dict[str, Type["AuthManager"]] = {}
141147

142148
@classmethod
143-
def register(cls, name: str, manager_cls: Type["AuthManager"]) -> None:
149+
def register(cls, name: str, auth_manager_class: Type["AuthManager"]) -> None:
144150
"""
145151
Register a string name to a known AuthManager class.
152+
153+
Args:
154+
name (str): unique name like 'oauth2' to register the AuthManager with
155+
auth_manager_class (Type["AuthManager"]): Implementation of AuthManager
156+
157+
Returns:
158+
None
146159
"""
147-
cls._registry[name] = manager_cls
160+
cls._registry[name] = auth_manager_class
148161

149162
@classmethod
150-
def create(cls, class_or_name: str, config: Dict[str, Any]) -> "AuthManager":
163+
def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager:
151164
"""
152165
Create an AuthManager by name or fully-qualified class path.
153166

pyiceberg/catalog/rest/util.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,24 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
from requests import HTTPError
1817
from json import JSONDecodeError
19-
from pyiceberg.typedef import IcebergBaseModel
20-
from pydantic import Field
21-
from typing import Dict, Type, Optional, Literal
18+
from typing import Dict, Literal, Optional, Type
19+
20+
from pydantic import Field, ValidationError
21+
from requests import HTTPError
22+
2223
from pyiceberg.exceptions import (
2324
AuthorizationExpiredError,
2425
BadRequestError,
25-
CommitFailedException,
26-
CommitStateUnknownException,
2726
ForbiddenError,
28-
NamespaceAlreadyExistsError,
29-
NamespaceNotEmptyError,
30-
NoSuchIdentifierError,
31-
NoSuchNamespaceError,
32-
NoSuchTableError,
33-
NoSuchViewError,
3427
OAuthError,
3528
RESTError,
3629
ServerError,
3730
ServiceUnavailableError,
38-
TableAlreadyExistsError,
3931
UnauthorizedError,
40-
ValidationError
4132
)
33+
from pyiceberg.typedef import IcebergBaseModel
34+
4235

4336
class TokenResponse(IcebergBaseModel):
4437
access_token: str = Field()
@@ -48,6 +41,7 @@ class TokenResponse(IcebergBaseModel):
4841
refresh_token: Optional[str] = Field(default=None)
4942
scope: Optional[str] = Field(default=None)
5043

44+
5145
class ErrorResponseMessage(IcebergBaseModel):
5246
message: str = Field()
5347
type: str = Field()
@@ -112,8 +106,6 @@ def _handle_non_200_response(exc: HTTPError, error_handler: Dict[int, Type[Excep
112106
except ValidationError as e:
113107
# In the case we don't have a proper response
114108
errs = ", ".join(err["msg"] for err in e.errors())
115-
response = (
116-
f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}"
117-
)
109+
response = f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}"
118110

119111
raise exception(response) from exc

0 commit comments

Comments
 (0)