Skip to content

Commit b2134e5

Browse files
committed
add google auth and test
1 parent a67c559 commit b2134e5

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

pyiceberg/catalog/rest/auth.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import base64
1919
import importlib
2020
from abc import ABC, abstractmethod
21-
from typing import Any, Dict, Optional, Type
21+
import logging
22+
from typing import Any, Dict, List, Optional, Type
2223

2324
from requests import HTTPError, PreparedRequest, Session
2425
from requests.auth import AuthBase
@@ -109,6 +110,38 @@ def auth_header(self) -> str:
109110
return f"Bearer {self._token}"
110111

111112

113+
class GoogleAuthManager(AuthManager):
114+
"""
115+
An auth manager that is responsible for handling Google credentials.
116+
"""
117+
118+
def __init__(self, credentials_path: Optional[str] = None, scopes: Optional[List[str]] = None):
119+
"""
120+
Initialize GoogleAuthManager.
121+
122+
Args:
123+
credentials_path: Optional path to Google credentials JSON file.
124+
scopes: Optional list of OAuth2 scopes.
125+
"""
126+
try:
127+
import google.auth
128+
import google.auth.transport.requests
129+
except ImportError as e:
130+
raise ImportError(
131+
"Google Auth libraries not found. Please install 'google-auth'."
132+
) from e
133+
134+
if credentials_path:
135+
self.credentials, _ = google.auth.load_credentials_from_file(credentials_path, scopes=scopes)
136+
else:
137+
logging.info("Using Google Default Application Credentials")
138+
self.credentials, _ = google.auth.default(scopes=scopes)
139+
self._auth_request = google.auth.transport.requests.Request()
140+
141+
def auth_header(self) -> Optional[str]:
142+
self.credentials.refresh(self._auth_request)
143+
return f"Bearer {self.credentials.token}"
144+
112145
class AuthManagerAdapter(AuthBase):
113146
"""A `requests.auth.AuthBase` adapter that integrates an `AuthManager` into a `requests.Session` to automatically attach the appropriate Authorization header to every request.
114147
@@ -187,3 +220,4 @@ def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager:
187220
AuthManagerFactory.register("noop", NoopAuthManager)
188221
AuthManagerFactory.register("basic", BasicAuthManager)
189222
AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager)
223+
AuthManagerFactory.register("google", GoogleAuthManager)

tests/catalog/test_rest_auth.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import requests
2222
from requests_mock import Mocker
2323

24-
from pyiceberg.catalog.rest.auth import AuthManagerAdapter, BasicAuthManager, NoopAuthManager
24+
from pyiceberg.catalog.rest.auth import AuthManagerAdapter, BasicAuthManager, GoogleAuthManager, NoopAuthManager
2525

2626
TEST_URI = "https://iceberg-test-catalog/"
27+
GOOGLE_CREDS_URI = "https://oauth2.googleapis.com/token"
2728

2829

2930
@pytest.fixture
@@ -35,6 +36,17 @@ def rest_mock(requests_mock: Mocker) -> Mocker:
3536
)
3637
return requests_mock
3738

39+
@pytest.fixture
40+
def google_mock(requests_mock: Mocker) -> Mocker:
41+
requests_mock.post(GOOGLE_CREDS_URI,
42+
json={"access_token": "aaaabbb"},
43+
status_code=200)
44+
requests_mock.get(
45+
TEST_URI,
46+
json={},
47+
status_code=200,
48+
)
49+
return requests_mock
3850

3951
def test_noop_auth_header(rest_mock: Mocker) -> None:
4052
auth_manager = NoopAuthManager()
@@ -63,3 +75,17 @@ def test_basic_auth_header(rest_mock: Mocker) -> None:
6375
assert len(history) == 1
6476
actual_headers = history[0].headers
6577
assert actual_headers["Authorization"] == expected_header
78+
79+
def test_google_auth_header(google_mock: Mocker) -> None:
80+
expected_token = "aaaabbb"
81+
expected_header = f"Bearer {expected_token}"
82+
83+
auth_manager = GoogleAuthManager()
84+
session = requests.Session()
85+
session.auth = AuthManagerAdapter(auth_manager)
86+
87+
session.get(TEST_URI)
88+
history = google_mock.request_history
89+
assert len(history) == 2
90+
actual_headers = history[1].headers
91+
assert actual_headers["Authorization"] == expected_header

0 commit comments

Comments
 (0)