|
18 | 18 | import base64 |
19 | 19 | import importlib |
20 | 20 | import logging |
| 21 | +import threading |
| 22 | +import time |
21 | 23 | from abc import ABC, abstractmethod |
| 24 | +from functools import cached_property |
22 | 25 | from typing import Any, Dict, List, Optional, Type |
23 | 26 |
|
| 27 | +import requests |
24 | 28 | from requests import HTTPError, PreparedRequest, Session |
25 | 29 | from requests.auth import AuthBase |
26 | 30 |
|
@@ -121,6 +125,98 @@ def auth_header(self) -> str: |
121 | 125 | return f"Bearer {self._token}" |
122 | 126 |
|
123 | 127 |
|
| 128 | +class OAuth2TokenProvider: |
| 129 | + """Thread-safe OAuth2 token provider with token refresh support.""" |
| 130 | + |
| 131 | + client_id: str |
| 132 | + client_secret: str |
| 133 | + token_url: str |
| 134 | + scope: Optional[str] |
| 135 | + refresh_margin: int |
| 136 | + expires_in: Optional[int] |
| 137 | + |
| 138 | + _token: Optional[str] |
| 139 | + _expires_at: int |
| 140 | + _lock: threading.Lock |
| 141 | + |
| 142 | + def __init__( |
| 143 | + self, |
| 144 | + client_id: str, |
| 145 | + client_secret: str, |
| 146 | + token_url: str, |
| 147 | + scope: Optional[str] = None, |
| 148 | + refresh_margin: int = 60, |
| 149 | + expires_in: Optional[int] = None, |
| 150 | + ): |
| 151 | + self.client_id = client_id |
| 152 | + self.client_secret = client_secret |
| 153 | + self.token_url = token_url |
| 154 | + self.scope = scope |
| 155 | + self.refresh_margin = refresh_margin |
| 156 | + self.expires_in = expires_in |
| 157 | + |
| 158 | + self._token = None |
| 159 | + self._expires_at = 0 |
| 160 | + self._lock = threading.Lock() |
| 161 | + |
| 162 | + @cached_property |
| 163 | + def _client_secret_header(self) -> str: |
| 164 | + creds = f"{self.client_id}:{self.client_secret}" |
| 165 | + creds_bytes = creds.encode("utf-8") |
| 166 | + b64_creds = base64.b64encode(creds_bytes).decode("utf-8") |
| 167 | + return f"Basic {b64_creds}" |
| 168 | + |
| 169 | + def _refresh_token(self) -> None: |
| 170 | + data = {"grant_type": "client_credentials"} |
| 171 | + if self.scope: |
| 172 | + data["scope"] = self.scope |
| 173 | + |
| 174 | + response = requests.post(self.token_url, data=data, headers={"Authorization": self._client_secret_header}) |
| 175 | + response.raise_for_status() |
| 176 | + result = response.json() |
| 177 | + |
| 178 | + self._token = result["access_token"] |
| 179 | + expires_in = result.get("expires_in", self.expires_in) |
| 180 | + if expires_in is None: |
| 181 | + raise ValueError( |
| 182 | + "The expiration time of the Token must be provided by the Server in the Access Token Response in `expires_in` field, or by the PyIceberg Client." |
| 183 | + ) |
| 184 | + self._expires_at = time.monotonic() + expires_in - self.refresh_margin |
| 185 | + |
| 186 | + def get_token(self) -> str: |
| 187 | + with self._lock: |
| 188 | + if not self._token or time.monotonic() >= self._expires_at: |
| 189 | + self._refresh_token() |
| 190 | + if self._token is None: |
| 191 | + raise ValueError("Authorization token is None after refresh") |
| 192 | + return self._token |
| 193 | + |
| 194 | + |
| 195 | +class OAuth2AuthManager(AuthManager): |
| 196 | + """Auth Manager implementation that supports OAuth2 as defined in IETF RFC6749.""" |
| 197 | + |
| 198 | + def __init__( |
| 199 | + self, |
| 200 | + client_id: str, |
| 201 | + client_secret: str, |
| 202 | + token_url: str, |
| 203 | + scope: Optional[str] = None, |
| 204 | + refresh_margin: int = 60, |
| 205 | + expires_in: Optional[int] = None, |
| 206 | + ): |
| 207 | + self.token_provider = OAuth2TokenProvider( |
| 208 | + client_id, |
| 209 | + client_secret, |
| 210 | + token_url, |
| 211 | + scope, |
| 212 | + refresh_margin, |
| 213 | + expires_in, |
| 214 | + ) |
| 215 | + |
| 216 | + def auth_header(self) -> str: |
| 217 | + return f"Bearer {self.token_provider.get_token()}" |
| 218 | + |
| 219 | + |
124 | 220 | class GoogleAuthManager(AuthManager): |
125 | 221 | """An auth manager that is responsible for handling Google credentials.""" |
126 | 222 |
|
@@ -228,4 +324,5 @@ def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager: |
228 | 324 | AuthManagerFactory.register("noop", NoopAuthManager) |
229 | 325 | AuthManagerFactory.register("basic", BasicAuthManager) |
230 | 326 | AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager) |
| 327 | +AuthManagerFactory.register("oauth2", OAuth2AuthManager) |
231 | 328 | AuthManagerFactory.register("google", GoogleAuthManager) |
0 commit comments