11import base64
22import json
33import logging
4- import urllib .parse
54from datetime import datetime , timezone , timedelta
6- from typing import Dict , Optional , Any , Tuple , List , Union
5+ from typing import Dict , Optional , Any , Tuple
76from urllib .parse import urlparse
87
98import requests
109from requests .exceptions import RequestException
1110
1211from databricks .sql .auth .authenticators import CredentialsProvider , HeaderFactory
1312from databricks .sql .auth .endpoint import (
14- get_databricks_oidc_url ,
1513 get_oauth_endpoints ,
1614 infer_cloud_from_host ,
1715)
2523 "return_original_token_if_authenticated" : "true" ,
2624}
2725
28- # Special client IDs for different IdPs
29- AZURE_AD_MULTI_TENANT_APP_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
30-
31- # Buffer time in seconds before token expiry to trigger a refresh (5 minutes)
32- TOKEN_REFRESH_BUFFER_SECONDS = 300
26+ TOKEN_REFRESH_BUFFER_SECONDS = 10
3327
3428
3529class Token :
@@ -85,7 +79,6 @@ def __init__(
8579 self .hostname = hostname
8680 self .identity_federation_client_id = identity_federation_client_id
8781 self .external_provider_headers : Dict [str , str ] = {}
88- self .token = None
8982 self .token_endpoint : Optional [str ] = None
9083 self .idp_endpoints = None
9184 self .openid_config = None
@@ -123,9 +116,7 @@ def get_headers() -> Dict[str, str]:
123116 self .external_provider_headers = header_factory ()
124117
125118 # Extract the token from the headers
126- token_info = self ._extract_token_info_from_header (
127- self .external_provider_headers
128- )
119+ token_info = self ._extract_token_info_from_header (self .external_provider_headers )
129120 token_type , access_token = token_info
130121
131122 try :
@@ -148,10 +139,7 @@ def get_headers() -> Dict[str, str]:
148139 return self .external_provider_headers
149140 else :
150141 # Token is from a different host, need to exchange
151- return self ._try_token_exchange_or_fallback (
152- access_token , token_type
153- )
154-
142+ return self ._try_token_exchange_or_fallback (access_token , token_type )
155143 except Exception as e :
156144 logger .error (f"Failed to process token: { str (e )} " )
157145 # Fall back to original headers in case of error
@@ -171,10 +159,8 @@ def _init_oidc_discovery(self):
171159
172160 if self .idp_endpoints :
173161 # Get the OpenID configuration URL
174- openid_config_url = self .idp_endpoints .get_openid_config_url (
175- self .hostname
176- )
177-
162+ openid_config_url = self .idp_endpoints .get_openid_config_url (self .hostname )
163+
178164 # Fetch the OpenID configuration
179165 response = requests .get (openid_config_url )
180166 if response .status_code == 200 :
@@ -189,33 +175,26 @@ def _init_oidc_discovery(self):
189175
190176 # Fallback to default token endpoint if discovery fails
191177 if not self .token_endpoint :
192- # Make sure hostname has proper format with https:// prefix and trailing slash
193- hostname = self .hostname
194- if not hostname .startswith ("https://" ):
195- hostname = f"https://{ hostname } "
196- if not hostname .endswith ("/" ):
197- hostname = f"{ hostname } /"
178+ hostname = self ._format_hostname (self .hostname )
198179 self .token_endpoint = f"{ hostname } oidc/v1/token"
199180 logger .info (f"Using default token endpoint: { self .token_endpoint } " )
200-
201181 except Exception as e :
202182 logger .warning (
203183 f"OIDC discovery failed: { str (e )} . Using default token endpoint."
204184 )
205- # Make sure hostname has proper format with https:// prefix and trailing slash
206- hostname = self .hostname
207- if not hostname .startswith ("https://" ):
208- hostname = f"https://{ hostname } "
209- if not hostname .endswith ("/" ):
210- hostname = f"{ hostname } /"
185+ hostname = self ._format_hostname (self .hostname )
211186 self .token_endpoint = f"{ hostname } oidc/v1/token"
212- logger .info (
213- f"Using default token endpoint after error: { self .token_endpoint } "
214- )
187+ logger .info (f"Using default token endpoint after error: { self .token_endpoint } " )
215188
216- def _extract_token_info_from_header (
217- self , headers : Dict [str , str ]
218- ) -> Tuple [str , str ]:
189+ def _format_hostname (self , hostname : str ) -> str :
190+ """Format hostname to ensure it has proper https:// prefix and trailing slash."""
191+ if not hostname .startswith ("https://" ):
192+ hostname = f"https://{ hostname } "
193+ if not hostname .endswith ("/" ):
194+ hostname = f"{ hostname } /"
195+ return hostname
196+
197+ def _extract_token_info_from_header (self , headers : Dict [str , str ]) -> Tuple [str , str ]:
219198 """Extract token type and token value from authorization header."""
220199 auth_header = headers .get ("Authorization" )
221200 if not auth_header :
@@ -286,10 +265,6 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
286265 Attempt to refresh an expired token by first getting a fresh external token
287266 and then exchanging it for a new Databricks token.
288267
289- This implementation follows the JDBC driver approach by first requesting
290- a fresh token from the underlying credentials provider before performing
291- the token exchange.
292-
293268 Args:
294269 access_token: The original external access token (will be replaced)
295270 token_type: The token type (Bearer, etc.)
@@ -300,7 +275,7 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
300275 try :
301276 logger .info ("Refreshing expired token by getting a new external token" )
302277
303- # ENHANCEMENT: Get a fresh token from the underlying credentials provider
278+ # Get a fresh token from the underlying credentials provider
304279 # instead of reusing the same access_token
305280 fresh_headers = self .credentials_provider ()()
306281
@@ -333,20 +308,14 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
333308
334309 # Create new headers with the refreshed token
335310 headers = dict (fresh_headers ) # Use the fresh headers as base
336- headers [
337- "Authorization"
338- ] = f"{ refreshed_token .token_type } { refreshed_token .access_token } "
311+ headers ["Authorization" ] = f"{ refreshed_token .token_type } { refreshed_token .access_token } "
339312 return headers
340313 except Exception as e :
341- logger .error (
342- f"Token refresh failed, falling back to original token: { str (e )} "
343- )
314+ logger .error (f"Token refresh failed, falling back to original token: { str (e )} " )
344315 # If refresh fails, fall back to the original headers
345316 return self .external_provider_headers
346317
347- def _try_token_exchange_or_fallback (
348- self , access_token : str , token_type : str
349- ) -> Dict [str , str ]:
318+ def _try_token_exchange_or_fallback (self , access_token : str , token_type : str ) -> Dict [str , str ]:
350319 """Try to exchange the token or fall back to the original token."""
351320 try :
352321 # Parse the token to get claims for IdP-specific adjustments
@@ -362,14 +331,10 @@ def _try_token_exchange_or_fallback(
362331
363332 # Create new headers with the exchanged token
364333 headers = dict (self .external_provider_headers )
365- headers [
366- "Authorization"
367- ] = f"{ exchanged_token .token_type } { exchanged_token .access_token } "
334+ headers ["Authorization" ] = f"{ exchanged_token .token_type } { exchanged_token .access_token } "
368335 return headers
369336 except Exception as e :
370- logger .error (
371- f"Token exchange failed, falling back to using external token: { str (e )} "
372- )
337+ logger .error (f"Token exchange failed, falling back to using external token: { str (e )} " )
373338 # Fall back to original headers
374339 return self .external_provider_headers
375340
@@ -431,14 +396,10 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
431396 try :
432397 # Calculate expiry by adding expires_in seconds to current time
433398 expires_in_seconds = int (resp_data ["expires_in" ])
434- token .expiry = datetime .now (tz = timezone .utc ) + timedelta (
435- seconds = expires_in_seconds
436- )
399+ token .expiry = datetime .now (tz = timezone .utc ) + timedelta (seconds = expires_in_seconds )
437400 logger .debug (f"Token expiry set from expires_in: { token .expiry } " )
438401 except (ValueError , TypeError ) as e :
439- logger .warning (
440- f"Could not parse expires_in from response: { str (e )} "
441- )
402+ logger .warning (f"Could not parse expires_in from response: { str (e )} " )
442403
443404 # If expires_in wasn't available, try to parse expiry from the token JWT
444405 if token .expiry == datetime .now (tz = timezone .utc ):
@@ -447,9 +408,7 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
447408 exp_time = token_claims .get ("exp" )
448409 if exp_time :
449410 token .expiry = datetime .fromtimestamp (exp_time , tz = timezone .utc )
450- logger .debug (
451- f"Token expiry set from JWT exp claim: { token .expiry } "
452- )
411+ logger .debug (f"Token expiry set from JWT exp claim: { token .expiry } " )
453412 except Exception as e :
454413 logger .warning (f"Could not parse expiry from token: { str (e )} " )
455414
0 commit comments