Skip to content

Commit 34413f3

Browse files
committed
fmt
1 parent aa2d1b9 commit 34413f3

File tree

1 file changed

+27
-68
lines changed

1 file changed

+27
-68
lines changed

src/databricks/sql/auth/token_federation.py

Lines changed: 27 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
import base64
22
import json
33
import logging
4-
import urllib.parse
54
from datetime import datetime, timezone, timedelta
6-
from typing import Dict, Optional, Any, Tuple, List, Union
5+
from typing import Dict, Optional, Any, Tuple
76
from urllib.parse import urlparse
87

98
import requests
109
from requests.exceptions import RequestException
1110

1211
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
1312
from databricks.sql.auth.endpoint import (
14-
get_databricks_oidc_url,
1513
get_oauth_endpoints,
1614
infer_cloud_from_host,
1715
)
@@ -25,11 +23,7 @@
2523
"return_original_token_if_authenticated": "true",
2624
}
2725

28-
# Special client IDs for different IdPs
29-
AZURE_AD_MULTI_TENANT_APP_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
30-
31-
# Buffer time in seconds before token expiry to trigger a refresh (5 minutes)
32-
TOKEN_REFRESH_BUFFER_SECONDS = 300
26+
TOKEN_REFRESH_BUFFER_SECONDS = 10
3327

3428

3529
class Token:
@@ -85,7 +79,6 @@ def __init__(
8579
self.hostname = hostname
8680
self.identity_federation_client_id = identity_federation_client_id
8781
self.external_provider_headers: Dict[str, str] = {}
88-
self.token = None
8982
self.token_endpoint: Optional[str] = None
9083
self.idp_endpoints = None
9184
self.openid_config = None
@@ -123,9 +116,7 @@ def get_headers() -> Dict[str, str]:
123116
self.external_provider_headers = header_factory()
124117

125118
# Extract the token from the headers
126-
token_info = self._extract_token_info_from_header(
127-
self.external_provider_headers
128-
)
119+
token_info = self._extract_token_info_from_header(self.external_provider_headers)
129120
token_type, access_token = token_info
130121

131122
try:
@@ -148,10 +139,7 @@ def get_headers() -> Dict[str, str]:
148139
return self.external_provider_headers
149140
else:
150141
# Token is from a different host, need to exchange
151-
return self._try_token_exchange_or_fallback(
152-
access_token, token_type
153-
)
154-
142+
return self._try_token_exchange_or_fallback(access_token, token_type)
155143
except Exception as e:
156144
logger.error(f"Failed to process token: {str(e)}")
157145
# Fall back to original headers in case of error
@@ -171,10 +159,8 @@ def _init_oidc_discovery(self):
171159

172160
if self.idp_endpoints:
173161
# Get the OpenID configuration URL
174-
openid_config_url = self.idp_endpoints.get_openid_config_url(
175-
self.hostname
176-
)
177-
162+
openid_config_url = self.idp_endpoints.get_openid_config_url(self.hostname)
163+
178164
# Fetch the OpenID configuration
179165
response = requests.get(openid_config_url)
180166
if response.status_code == 200:
@@ -189,33 +175,26 @@ def _init_oidc_discovery(self):
189175

190176
# Fallback to default token endpoint if discovery fails
191177
if not self.token_endpoint:
192-
# Make sure hostname has proper format with https:// prefix and trailing slash
193-
hostname = self.hostname
194-
if not hostname.startswith("https://"):
195-
hostname = f"https://{hostname}"
196-
if not hostname.endswith("/"):
197-
hostname = f"{hostname}/"
178+
hostname = self._format_hostname(self.hostname)
198179
self.token_endpoint = f"{hostname}oidc/v1/token"
199180
logger.info(f"Using default token endpoint: {self.token_endpoint}")
200-
201181
except Exception as e:
202182
logger.warning(
203183
f"OIDC discovery failed: {str(e)}. Using default token endpoint."
204184
)
205-
# Make sure hostname has proper format with https:// prefix and trailing slash
206-
hostname = self.hostname
207-
if not hostname.startswith("https://"):
208-
hostname = f"https://{hostname}"
209-
if not hostname.endswith("/"):
210-
hostname = f"{hostname}/"
185+
hostname = self._format_hostname(self.hostname)
211186
self.token_endpoint = f"{hostname}oidc/v1/token"
212-
logger.info(
213-
f"Using default token endpoint after error: {self.token_endpoint}"
214-
)
187+
logger.info(f"Using default token endpoint after error: {self.token_endpoint}")
215188

216-
def _extract_token_info_from_header(
217-
self, headers: Dict[str, str]
218-
) -> Tuple[str, str]:
189+
def _format_hostname(self, hostname: str) -> str:
190+
"""Format hostname to ensure it has proper https:// prefix and trailing slash."""
191+
if not hostname.startswith("https://"):
192+
hostname = f"https://{hostname}"
193+
if not hostname.endswith("/"):
194+
hostname = f"{hostname}/"
195+
return hostname
196+
197+
def _extract_token_info_from_header(self, headers: Dict[str, str]) -> Tuple[str, str]:
219198
"""Extract token type and token value from authorization header."""
220199
auth_header = headers.get("Authorization")
221200
if not auth_header:
@@ -286,10 +265,6 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
286265
Attempt to refresh an expired token by first getting a fresh external token
287266
and then exchanging it for a new Databricks token.
288267
289-
This implementation follows the JDBC driver approach by first requesting
290-
a fresh token from the underlying credentials provider before performing
291-
the token exchange.
292-
293268
Args:
294269
access_token: The original external access token (will be replaced)
295270
token_type: The token type (Bearer, etc.)
@@ -300,7 +275,7 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
300275
try:
301276
logger.info("Refreshing expired token by getting a new external token")
302277

303-
# ENHANCEMENT: Get a fresh token from the underlying credentials provider
278+
# Get a fresh token from the underlying credentials provider
304279
# instead of reusing the same access_token
305280
fresh_headers = self.credentials_provider()()
306281

@@ -333,20 +308,14 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
333308

334309
# Create new headers with the refreshed token
335310
headers = dict(fresh_headers) # Use the fresh headers as base
336-
headers[
337-
"Authorization"
338-
] = f"{refreshed_token.token_type} {refreshed_token.access_token}"
311+
headers["Authorization"] = f"{refreshed_token.token_type} {refreshed_token.access_token}"
339312
return headers
340313
except Exception as e:
341-
logger.error(
342-
f"Token refresh failed, falling back to original token: {str(e)}"
343-
)
314+
logger.error(f"Token refresh failed, falling back to original token: {str(e)}")
344315
# If refresh fails, fall back to the original headers
345316
return self.external_provider_headers
346317

347-
def _try_token_exchange_or_fallback(
348-
self, access_token: str, token_type: str
349-
) -> Dict[str, str]:
318+
def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) -> Dict[str, str]:
350319
"""Try to exchange the token or fall back to the original token."""
351320
try:
352321
# Parse the token to get claims for IdP-specific adjustments
@@ -362,14 +331,10 @@ def _try_token_exchange_or_fallback(
362331

363332
# Create new headers with the exchanged token
364333
headers = dict(self.external_provider_headers)
365-
headers[
366-
"Authorization"
367-
] = f"{exchanged_token.token_type} {exchanged_token.access_token}"
334+
headers["Authorization"] = f"{exchanged_token.token_type} {exchanged_token.access_token}"
368335
return headers
369336
except Exception as e:
370-
logger.error(
371-
f"Token exchange failed, falling back to using external token: {str(e)}"
372-
)
337+
logger.error(f"Token exchange failed, falling back to using external token: {str(e)}")
373338
# Fall back to original headers
374339
return self.external_provider_headers
375340

@@ -431,14 +396,10 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
431396
try:
432397
# Calculate expiry by adding expires_in seconds to current time
433398
expires_in_seconds = int(resp_data["expires_in"])
434-
token.expiry = datetime.now(tz=timezone.utc) + timedelta(
435-
seconds=expires_in_seconds
436-
)
399+
token.expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in_seconds)
437400
logger.debug(f"Token expiry set from expires_in: {token.expiry}")
438401
except (ValueError, TypeError) as e:
439-
logger.warning(
440-
f"Could not parse expires_in from response: {str(e)}"
441-
)
402+
logger.warning(f"Could not parse expires_in from response: {str(e)}")
442403

443404
# If expires_in wasn't available, try to parse expiry from the token JWT
444405
if token.expiry == datetime.now(tz=timezone.utc):
@@ -447,9 +408,7 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token
447408
exp_time = token_claims.get("exp")
448409
if exp_time:
449410
token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc)
450-
logger.debug(
451-
f"Token expiry set from JWT exp claim: {token.expiry}"
452-
)
411+
logger.debug(f"Token expiry set from JWT exp claim: {token.expiry}")
453412
except Exception as e:
454413
logger.warning(f"Could not parse expiry from token: {str(e)}")
455414

0 commit comments

Comments
 (0)