Skip to content

Commit ae5ee50

Browse files
committed
address comment
1 parent 02d16d2 commit ae5ee50

File tree

5 files changed

+498
-454
lines changed

5 files changed

+498
-454
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
AzureServicePrincipalCredentialProvider,
99
)
1010
from databricks.sql.auth.common import AuthType, ClientContext
11-
from databricks.sql.auth.token_federation import TokenFederationProvider, ExternalTokenProvider
11+
from databricks.sql.auth.token_federation import TokenFederationProvider
1212

1313

1414
def get_auth_provider(cfg: ClientContext, http_client):
1515
# Determine the base auth provider
1616
base_provider = None
17-
17+
1818
if cfg.credentials_provider:
1919
base_provider = ExternalAuthProvider(cfg.credentials_provider)
2020
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
@@ -64,16 +64,16 @@ def get_auth_provider(cfg: ClientContext, http_client):
6464
)
6565
else:
6666
raise RuntimeError("No valid authentication settings!")
67-
68-
# Wrap with token federation if enabled
69-
if cfg.enable_token_federation and base_provider:
67+
68+
# Always wrap with token federation (falls back gracefully if not needed)
69+
if base_provider:
7070
return TokenFederationProvider(
7171
hostname=cfg.hostname,
7272
external_provider=base_provider,
7373
http_client=http_client,
7474
identity_federation_client_id=cfg.identity_federation_client_id,
7575
)
76-
76+
7777
return base_provider
7878

7979

@@ -129,8 +129,6 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs)
129129
else redirect_port_range,
130130
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
131131
credentials_provider=kwargs.get("credentials_provider"),
132-
# Token federation parameters
133-
enable_token_federation=kwargs.get("enable_token_federation", False),
134132
identity_federation_client_id=kwargs.get("identity_federation_client_id"),
135133
)
136134
return get_auth_provider(cfg, http_client)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import logging
2+
import jwt
3+
from datetime import datetime, timedelta
4+
from typing import Optional, Dict, Tuple
5+
from urllib.parse import urlparse
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def parse_hostname(hostname: str) -> str:
11+
"""
12+
Normalize the hostname to include scheme and trailing slash.
13+
14+
Args:
15+
hostname: The hostname to normalize
16+
17+
Returns:
18+
Normalized hostname with scheme and trailing slash
19+
"""
20+
if not hostname.startswith("http://") and not hostname.startswith("https://"):
21+
hostname = f"https://{hostname}"
22+
if not hostname.endswith("/"):
23+
hostname = f"{hostname}/"
24+
return hostname
25+
26+
27+
def decode_token(access_token: str) -> Optional[Dict]:
28+
"""
29+
Decode a JWT token without verification to extract claims.
30+
31+
Args:
32+
access_token: The JWT access token to decode
33+
34+
Returns:
35+
Decoded token claims or None if decoding fails
36+
"""
37+
try:
38+
return jwt.decode(access_token, options={"verify_signature": False})
39+
except Exception as e:
40+
logger.debug("Failed to decode JWT token: %s", e)
41+
return None
42+
43+
44+
def is_same_host(url1: str, url2: str) -> bool:
45+
"""
46+
Check if two URLs have the same host.
47+
48+
Args:
49+
url1: First URL
50+
url2: Second URL
51+
52+
Returns:
53+
True if hosts are the same, False otherwise
54+
"""
55+
try:
56+
host1 = urlparse(url1).netloc
57+
host2 = urlparse(url2).netloc
58+
# Handle port differences (e.g., example.com vs example.com:443)
59+
host1_without_port = host1.split(":")[0]
60+
host2_without_port = host2.split(":")[0]
61+
return host1_without_port == host2_without_port
62+
except Exception as e:
63+
logger.debug("Failed to parse URLs: %s", e)
64+
return False
65+
66+
67+
class Token:
68+
"""
69+
Represents an OAuth token with expiration management.
70+
"""
71+
72+
def __init__(self, access_token: str, token_type: str = "Bearer"):
73+
"""
74+
Initialize a token.
75+
76+
Args:
77+
access_token: The access token string
78+
token_type: The token type (default: Bearer)
79+
"""
80+
self.access_token = access_token
81+
self.token_type = token_type
82+
self.expiry_time = self._calculate_expiry()
83+
84+
def _calculate_expiry(self) -> datetime:
85+
"""
86+
Calculate the token expiry time from JWT claims.
87+
88+
Returns:
89+
The token expiry datetime
90+
"""
91+
decoded = decode_token(self.access_token)
92+
if decoded and "exp" in decoded:
93+
# Use JWT exp claim with 1 minute buffer
94+
return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1)
95+
# Default to 1 hour if no expiry info
96+
return datetime.now() + timedelta(hours=1)
97+
98+
def is_expired(self) -> bool:
99+
"""
100+
Check if the token is expired.
101+
102+
Returns:
103+
True if token is expired, False otherwise
104+
"""
105+
return datetime.now() >= self.expiry_time
106+
107+
def to_dict(self) -> Dict[str, str]:
108+
"""
109+
Convert token to dictionary format.
110+
111+
Returns:
112+
Dictionary with access_token and token_type
113+
"""
114+
return {
115+
"access_token": self.access_token,
116+
"token_type": self.token_type,
117+
}

src/databricks/sql/auth/common.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ def __init__(
3737
tls_client_cert_file: Optional[str] = None,
3838
oauth_persistence=None,
3939
credentials_provider=None,
40-
# Token federation parameters
41-
enable_token_federation: bool = False,
4240
identity_federation_client_id: Optional[str] = None,
4341
# HTTP client configuration parameters
4442
ssl_options=None, # SSLOptions type
@@ -68,8 +66,6 @@ def __init__(
6866
self.tls_client_cert_file = tls_client_cert_file
6967
self.oauth_persistence = oauth_persistence
7068
self.credentials_provider = credentials_provider
71-
# Token federation
72-
self.enable_token_federation = enable_token_federation
7369
self.identity_federation_client_id = identity_federation_client_id
7470

7571
# HTTP client configuration

0 commit comments

Comments
 (0)