1111# Private API: this is an evolving interface and it will change in the future.
1212# Please must not depend on it in your applications.
1313from databricks .sql .experimental .oauth_persistence import OAuthToken , OAuthPersistence
14- from databricks .sql .auth .endpoint import AzureOAuthEndpointCollection , InHouseOAuthEndpointCollection
14+ from databricks .sql .auth .endpoint import (
15+ AzureOAuthEndpointCollection ,
16+ InHouseOAuthEndpointCollection ,
17+ )
18+
1519
1620class AuthProvider :
1721 def add_headers (self , request_headers : Dict [str , str ]):
@@ -56,8 +60,10 @@ def auth_type(self) -> str:
5660 def __call__ (self , * args , ** kwargs ) -> HeaderFactory :
5761 def get_headers ():
5862 return {"Authorization" : self .__authorization_header_value }
63+
5964 return get_headers
6065
66+
6167# Private API: this is an evolving interface and it will change in the future.
6268# Please must not depend on it in your applications.
6369class DatabricksOAuthProvider (AuthProvider , CredentialsProvider ):
@@ -81,11 +87,8 @@ def __init__(
8187
8288 idp_endpoint = get_oauth_endpoints (hostname , auth_type == "azure-oauth" )
8389 if not idp_endpoint :
84- raise NotImplementedError (
85- f"OAuth is not supported for host ${ hostname } "
86- )
90+ raise NotImplementedError (f"OAuth is not supported for host ${ hostname } " )
8791
88-
8992 cloud_scopes = idp_endpoint .get_scopes_mapping (scopes )
9093 self ._scopes_as_str = self .SCOPE_DELIM .join (cloud_scopes )
9194
@@ -107,6 +110,7 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
107110 def get_headers ():
108111 self ._update_token_if_expired ()
109112 return {"Authorization" : "Bearer {}" .format (self ._access_token )}
113+
110114 return get_headers
111115
112116 def _initial_get_token (self ):
@@ -170,25 +174,24 @@ def __init__(
170174 client_id : str ,
171175 client_secret : str ,
172176 token_endpoint : str ,
173- auth_type_value : str = "client-credentials"
177+ auth_type_value : str = "client-credentials" ,
174178 ):
175179 """
176180 Initialize a ClientCredentialsProvider.
177-
181+
178182 Args:
179183 client_id: OAuth client ID
180- client_secret: OAuth client secret
184+ client_secret: OAuth client secret
181185 token_endpoint: OAuth token endpoint URL
182186 auth_type_value: Auth type identifier
183187 """
184188 self .client_id = client_id
185189 self .client_secret = client_secret
186190 self .token_endpoint = token_endpoint
187191 self .auth_type_value = auth_type_value
188-
192+
189193 self ._cached_token = None
190194 self ._token_expires_at = None
191-
192195
193196 def auth_type (self ) -> str :
194197 return self .auth_type_value
@@ -197,50 +200,54 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
197200 def get_headers () -> Dict [str , str ]:
198201 token = self ._get_access_token ()
199202 return {"Authorization" : "Bearer {}" .format (token )}
203+
200204 return get_headers
201-
205+
202206 def add_headers (self , request_headers : Dict [str , str ]):
203207 token = self ._get_access_token ()
204208 request_headers ["Authorization" ] = "Bearer {}" .format (token )
205209
206210 def _get_access_token (self ) -> str :
207211 """Get a valid access token using client credentials flow, with caching."""
208212 # Check if we have a valid cached token (with 40 second buffer since azure doesn't respect a token with less than 30s expiry)
209- if (self ._cached_token and self ._token_expires_at and
210- time .time () < self ._token_expires_at - 40 ):
213+ if (
214+ self ._cached_token
215+ and self ._token_expires_at
216+ and time .time () < self ._token_expires_at - 40
217+ ):
211218 return self ._cached_token
212-
219+
213220 # Get new token using client credentials flow
214221 token_data = self ._request_token ()
215-
216- self ._cached_token = token_data [' access_token' ]
222+
223+ self ._cached_token = token_data [" access_token" ]
217224 # expires_in is in seconds, convert to absolute time
218- self ._token_expires_at = time .time () + token_data .get (' expires_in' , 3600 )
219-
225+ self ._token_expires_at = time .time () + token_data .get (" expires_in" , 3600 )
226+
220227 return self ._cached_token
221228
222229 def _request_token (self ) -> dict :
223230 """Request a new token using OAuth client credentials flow."""
224231 data = {
225- ' grant_type' : ' client_credentials' ,
226- ' client_id' : self .client_id ,
227- ' client_secret' : self .client_secret ,
228- ' scope' : self .AZURE_DATABRICKS_SCOPE ,
232+ " grant_type" : " client_credentials" ,
233+ " client_id" : self .client_id ,
234+ " client_secret" : self .client_secret ,
235+ " scope" : self .AZURE_DATABRICKS_SCOPE ,
229236 }
230-
231- headers = {' Content-Type' : ' application/x-www-form-urlencoded' }
232-
237+
238+ headers = {" Content-Type" : " application/x-www-form-urlencoded" }
239+
233240 try :
234241 response = requests .post (self .token_endpoint , data = data , headers = headers )
235242 response .raise_for_status ()
236-
243+
237244 token_data = response .json ()
238-
239- if ' access_token' not in token_data :
245+
246+ if " access_token" not in token_data :
240247 raise ValueError ("No access_token in response: {}" .format (token_data ))
241-
248+
242249 return token_data
243-
250+
244251 except requests .exceptions .RequestException as e :
245252 raise RuntimeError ("Token request failed: {}" .format (e )) from e
246253 except ValueError as e :
0 commit comments