1717
1818import base64
1919import importlib
20+ import threading
21+ import time
2022from abc import ABC , abstractmethod
2123from typing import Any , Dict , Optional , Type
2224
25+ import requests
2326from requests import HTTPError , PreparedRequest , Session
2427from requests .auth import AuthBase
2528
@@ -42,11 +45,15 @@ def auth_header(self) -> Optional[str]:
4245
4346
4447class NoopAuthManager (AuthManager ):
48+ """Auth Manager implementation with no auth."""
49+
4550 def auth_header (self ) -> Optional [str ]:
4651 return None
4752
4853
4954class BasicAuthManager (AuthManager ):
55+ """AuthManager implementation that supports basic password auth."""
56+
5057 def __init__ (self , username : str , password : str ):
5158 credentials = f"{ username } :{ password } "
5259 self ._token = base64 .b64encode (credentials .encode ()).decode ()
@@ -56,6 +63,12 @@ def auth_header(self) -> str:
5663
5764
5865class LegacyOAuth2AuthManager (AuthManager ):
66+ """Legacy OAuth2 AuthManager implementation.
67+
68+ This class exists for backward compatibility, and will be removed in
69+ PyIceberg 1.0.0 in favor of OAuth2AuthManager.
70+ """
71+
5972 _session : Session
6073 _auth_url : Optional [str ]
6174 _token : Optional [str ]
@@ -109,6 +122,80 @@ def auth_header(self) -> str:
109122 return f"Bearer { self ._token } "
110123
111124
125+ class OAuth2TokenProvider :
126+ """Thread-safe OAuth2 token provider with token refresh support."""
127+
128+ client_id : str
129+ client_secret : str
130+ token_url : str
131+ scope : Optional [str ]
132+ refresh_margin : int
133+ expires_in : Optional [int ]
134+
135+ _token : Optional [str ]
136+ _expires_at : int
137+ _lock : threading .Lock
138+
139+ def __init__ (
140+ self ,
141+ client_id : str ,
142+ client_secret : str ,
143+ token_url : str ,
144+ scope : Optional [str ] = None ,
145+ refresh_margin : int = 60 ,
146+ expires_in : Optional [int ] = None ,
147+ ):
148+ self .client_id = client_id
149+ self .client_secret = client_secret
150+ self .token_url = token_url
151+ self .scope = scope
152+ self .refresh_margin = refresh_margin
153+ self .expires_in = expires_in
154+
155+ self ._token = None
156+ self ._expires_at = 0
157+ self ._lock = threading .Lock ()
158+
159+ def _refresh_token (self ) -> None :
160+ data = {
161+ "grant_type" : "client_credentials" ,
162+ "client_id" : self .client_id ,
163+ "client_secret" : self .client_secret ,
164+ }
165+ if self .scope :
166+ data ["scope" ] = self .scope
167+
168+ response = requests .post (self .token_url , data = data )
169+ response .raise_for_status ()
170+ result = response .json ()
171+
172+ self ._token = result ["access_token" ]
173+ expires_in = result .get ("expires_in" , self .expires_in )
174+ if expires_in is None :
175+ raise ValueError (
176+ "The expiration time of the Token must be provided by the Server in the Access Token Response in `expired_in` field, or by the PyIceberg Client."
177+ )
178+ self ._expires_at = time .time () + expires_in - self .refresh_margin
179+
180+ def get_token (self ) -> str :
181+ with self ._lock :
182+ if not self ._token or time .time () >= self ._expires_at :
183+ self ._refresh_token ()
184+ if self ._token is None :
185+ raise ValueError ("Authorization token is None after refresh" )
186+ return self ._token
187+
188+
189+ class OAuth2AuthManager (AuthManager ):
190+ """Auth Manager implementation that supports OAuth2 as defined in IETF RFC6749."""
191+
192+ def __init__ (self , token_provider : OAuth2TokenProvider ):
193+ self .token_provider = token_provider
194+
195+ def auth_header (self ) -> str :
196+ return f"Bearer { self .token_provider .get_token ()} "
197+
198+
112199class AuthManagerAdapter (AuthBase ):
113200 """A `requests.auth.AuthBase` adapter that integrates an `AuthManager` into a `requests.Session` to automatically attach the appropriate Authorization header to every request.
114201
0 commit comments