Skip to content

Commit bf0a2f6

Browse files
committed
Merge remote-tracking branch 'target/main' into close-conn
2 parents 4e66230 + d3df719 commit bf0a2f6

File tree

11 files changed

+118
-110
lines changed

11 files changed

+118
-110
lines changed

src/databricks/sql/auth/authenticators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def __init__(
199199
self.azure_client_secret = azure_client_secret
200200
self.azure_workspace_resource_id = azure_workspace_resource_id
201201
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
202-
hostname
202+
hostname, http_client
203203
)
204204
self._http_client = http_client
205205

src/databricks/sql/auth/common.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,14 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str:
115115
login_url = f"{host}/aad/auth"
116116
logger.debug("Loading tenant ID from %s", login_url)
117117

118-
with http_client.request_context(
119-
HttpMethod.GET, login_url, allow_redirects=False
120-
) as resp:
121-
if resp.status // 100 != 3:
118+
with http_client.request_context(HttpMethod.GET, login_url) as resp:
119+
entra_id_endpoint = resp.retries.history[-1].redirect_location
120+
if entra_id_endpoint is None:
122121
raise ValueError(
123-
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
122+
f"No Location header in response from {login_url}: {entra_id_endpoint}"
124123
)
125-
entra_id_endpoint = dict(resp.headers).get("Location")
126-
if entra_id_endpoint is None:
127-
raise ValueError(f"No Location header in response from {login_url}")
128124

129-
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
125+
# The final redirect URL has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
130126
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
131127
url = urlparse(entra_id_endpoint)
132128
path_segments = url.path.split("/")

src/databricks/sql/auth/oauth.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -337,17 +337,17 @@ def refresh(self) -> Token:
337337
}
338338
)
339339

340-
with self._http_client.execute(
341-
method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
342-
) as response:
343-
if response.status_code == 200:
344-
oauth_response = OAuthResponse(**response.json())
345-
return Token(
346-
oauth_response.access_token,
347-
oauth_response.token_type,
348-
oauth_response.refresh_token,
349-
)
350-
else:
351-
raise Exception(
352-
f"Failed to get token: {response.status_code} {response.text}"
353-
)
340+
response = self._http_client.request(
341+
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
342+
)
343+
if response.status == 200:
344+
oauth_response = OAuthResponse(**json.loads(response.data.decode("utf-8")))
345+
return Token(
346+
oauth_response.access_token,
347+
oauth_response.token_type,
348+
oauth_response.refresh_token,
349+
)
350+
else:
351+
raise Exception(
352+
f"Failed to get token: {response.status} {response.data.decode('utf-8')}"
353+
)

src/databricks/sql/auth/retry.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,14 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
355355
logger.info(f"Received status code {status_code} for {method} request")
356356

357357
# Request succeeded. Don't retry.
358-
if status_code // 100 == 2:
359-
return False, "2xx codes are not retried"
358+
if status_code // 100 <= 3:
359+
return False, "2xx/3xx codes are not retried"
360+
361+
if status_code == 400:
362+
return (
363+
False,
364+
"Received 400 - BAD_REQUEST. Please check the request parameters.",
365+
)
360366

361367
if status_code == 401:
362368
return (

src/databricks/sql/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def read(self) -> Optional[OAuthToken]:
279279
host_url=server_hostname,
280280
http_path=http_path,
281281
port=kwargs.get("_port", 443),
282-
http_client=self.http_client,
282+
client_context=client_context,
283283
user_agent=self.session.useragent_header
284284
if hasattr(self, "session")
285285
else None,
@@ -301,7 +301,7 @@ def read(self) -> Optional[OAuthToken]:
301301
auth_provider=self.session.auth_provider,
302302
host_url=self.session.host,
303303
batch_size=self.telemetry_batch_size,
304-
http_client=self.http_client,
304+
client_context=client_context,
305305
)
306306

307307
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(

src/databricks/sql/common/unified_http_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ def request_context(
154154
Yields:
155155
urllib3.HTTPResponse: The HTTP response object
156156
"""
157-
logger.debug("Making %s request to %s", method, url)
157+
logger.debug(
158+
"Making %s request to %s", method, urllib.parse.urlparse(url).netloc
159+
)
158160

159161
request_headers = self._prepare_headers(headers)
160162

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def __init__(
172172
host_url,
173173
executor,
174174
batch_size,
175-
http_client,
175+
client_context,
176176
):
177177
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
178178
self._telemetry_enabled = telemetry_enabled
@@ -186,8 +186,8 @@ def __init__(
186186
self._host_url = host_url
187187
self._executor = executor
188188

189-
# Use the provided HTTP client directly
190-
self._http_client = http_client
189+
# Create own HTTP client from client context
190+
self._http_client = UnifiedHttpClient(client_context)
191191

192192
def _export_event(self, event):
193193
"""Add an event to the batch queue and flush if batch is full"""
@@ -456,7 +456,7 @@ def initialize_telemetry_client(
456456
auth_provider,
457457
host_url,
458458
batch_size,
459-
http_client,
459+
client_context,
460460
):
461461
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
462462
try:
@@ -479,7 +479,7 @@ def initialize_telemetry_client(
479479
host_url=host_url,
480480
executor=TelemetryClientFactory._executor,
481481
batch_size=batch_size,
482-
http_client=http_client,
482+
client_context=client_context,
483483
)
484484
else:
485485
TelemetryClientFactory._clients[
@@ -532,10 +532,10 @@ def connection_failure_log(
532532
host_url: str,
533533
http_path: str,
534534
port: int,
535-
http_client,
535+
client_context,
536536
user_agent: Optional[str] = None,
537537
):
538-
"""Send error telemetry when connection creation fails, using existing HTTP client"""
538+
"""Send error telemetry when connection creation fails, using provided client context"""
539539

540540
UNAUTH_DUMMY_SESSION_ID = "unauth_session_id"
541541

@@ -545,7 +545,7 @@ def connection_failure_log(
545545
auth_provider=None,
546546
host_url=host_url,
547547
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
548-
http_client=http_client,
548+
client_context=client_context,
549549
)
550550

551551
telemetry_client = TelemetryClientFactory.get_telemetry_client(

tests/e2e/common/retry_test_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def test_retry_dangerous_codes(self, extra_params):
346346

347347
# These http codes are not retried by default
348348
# For some applications, idempotency is not important so we give users a way to force retries anyway
349-
DANGEROUS_CODES = [502, 504, 400]
349+
DANGEROUS_CODES = [502, 504]
350350

351351
additional_settings = {
352352
"_retry_dangerous_codes": DANGEROUS_CODES,

tests/e2e/test_concurrent_telemetry.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from unittest.mock import patch
66
import pytest
7+
import json
78

89
from databricks.sql.telemetry.models.enums import StatementType
910
from databricks.sql.telemetry.telemetry_client import (
@@ -119,8 +120,12 @@ def execute_query_worker(thread_id):
119120
for future in done:
120121
try:
121122
response = future.result()
122-
response.raise_for_status()
123-
captured_responses.append(response.json())
123+
# Check status using urllib3 method (response.status instead of response.raise_for_status())
124+
if response.status >= 400:
125+
raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}")
126+
# Parse JSON using urllib3 method (response.data.decode() instead of response.json())
127+
response_data = json.loads(response.data.decode()) if response.data else {}
128+
captured_responses.append(response_data)
124129
except Exception as e:
125130
captured_exceptions.append(e)
126131

tests/unit/test_auth.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,11 @@ def test_get_token_success(self, token_source, http_response):
263263
with patch.object(token_source, "_http_client", mock_http_client):
264264
# Create a mock response with the expected format
265265
mock_response = MagicMock()
266-
mock_response.status_code = 200
267-
mock_response.json.return_value = {
268-
"access_token": "abc123",
269-
"token_type": "Bearer",
270-
"refresh_token": None,
271-
}
272-
# Mock the context manager (execute returns context manager)
273-
mock_http_client.execute.return_value.__enter__.return_value = mock_response
274-
mock_http_client.execute.return_value.__exit__.return_value = None
266+
mock_response.status = 200
267+
mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}'
268+
269+
# Mock the request method to return the response directly
270+
mock_http_client.request.return_value = mock_response
275271

276272
token = token_source.get_token()
277273

@@ -287,12 +283,11 @@ def test_get_token_failure(self, token_source, http_response):
287283
with patch.object(token_source, "_http_client", mock_http_client):
288284
# Create a mock response with error
289285
mock_response = MagicMock()
290-
mock_response.status_code = 400
291-
mock_response.text = "Bad Request"
292-
mock_response.json.return_value = {"error": "invalid_client"}
293-
# Mock the context manager (execute returns context manager)
294-
mock_http_client.execute.return_value.__enter__.return_value = mock_response
295-
mock_http_client.execute.return_value.__exit__.return_value = None
286+
mock_response.status = 400
287+
mock_response.data.decode.return_value = "Bad Request"
288+
289+
# Mock the request method to return the response directly
290+
mock_http_client.request.return_value = mock_response
296291

297292
with pytest.raises(Exception) as e:
298293
token_source.get_token()

0 commit comments

Comments
 (0)