Skip to content

Commit 2a1f719

Browse files
more fixes
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent cba3da7 commit 2a1f719

File tree

10 files changed

+121
-55
lines changed

10 files changed

+121
-55
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def get_auth_provider(cfg: ClientContext, http_client):
1919
cfg.hostname,
2020
cfg.azure_client_id,
2121
cfg.azure_client_secret,
22+
http_client,
2223
cfg.azure_tenant_id,
2324
cfg.azure_workspace_resource_id,
2425
)
@@ -34,8 +35,8 @@ def get_auth_provider(cfg: ClientContext, http_client):
3435
cfg.oauth_redirect_port_range,
3536
cfg.oauth_client_id,
3637
cfg.oauth_scopes,
38+
http_client,
3739
cfg.auth_type,
38-
http_client=http_client,
3940
)
4041
elif cfg.access_token is not None:
4142
return AccessTokenAuthProvider(cfg.access_token)
@@ -54,7 +55,8 @@ def get_auth_provider(cfg: ClientContext, http_client):
5455
cfg.oauth_redirect_port_range,
5556
cfg.oauth_client_id,
5657
cfg.oauth_scopes,
57-
http_client=http_client,
58+
http_client,
59+
cfg.auth_type or "databricks-oauth",
5860
)
5961
else:
6062
raise RuntimeError("No valid authentication settings!")

src/databricks/sql/auth/authenticators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def __init__(
190190
hostname,
191191
azure_client_id,
192192
azure_client_secret,
193+
http_client,
193194
azure_tenant_id=None,
194195
azure_workspace_resource_id=None,
195196
):
@@ -200,6 +201,7 @@ def __init__(
200201
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
201202
hostname
202203
)
204+
self._http_client = http_client
203205

204206
def auth_type(self) -> str:
205207
return AuthType.AZURE_SP_M2M.value
@@ -209,6 +211,7 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource:
209211
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
210212
client_id=self.azure_client_id,
211213
client_secret=self.azure_client_secret,
214+
http_client=self._http_client,
212215
extra_params={"resource": resource},
213216
)
214217

src/databricks/sql/auth/oauth.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ def __send_auth_code_token_request(
190190
data = f"{token_request_body}&code_verifier={verifier}"
191191
return self.__send_token_request(token_request_url, data)
192192

193-
@staticmethod
194-
def __send_token_request(token_request_url, data):
193+
def __send_token_request(self, token_request_url, data):
195194
headers = {
196195
"Accept": "application/json",
197196
"Content-Type": "application/x-www-form-urlencoded",
@@ -210,7 +209,7 @@ def __send_refresh_token_request(self, hostname, refresh_token):
210209
token_request_body = client.prepare_refresh_body(
211210
refresh_token=refresh_token, client_id=client.client_id
212211
)
213-
return OAuthManager.__send_token_request(token_request_url, token_request_body)
212+
return self.__send_token_request(token_request_url, token_request_body)
214213

215214
@staticmethod
216215
def __get_tokens_from_response(oauth_response):

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import List, Optional, Union, Any, TYPE_CHECKING
99
from uuid import UUID
1010

11+
from databricks.sql.common.unified_http_client import UnifiedHttpClient
1112
from databricks.sql.result_set import ThriftResultSet
1213
from databricks.sql.telemetry.models.event import StatementType
1314

@@ -105,7 +106,7 @@ def __init__(
105106
http_headers,
106107
auth_provider: AuthProvider,
107108
ssl_options: SSLOptions,
108-
http_client,
109+
http_client: UnifiedHttpClient,
109110
**kwargs,
110111
):
111112
# Internal arguments in **kwargs:

src/databricks/sql/client.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -799,12 +799,6 @@ def _handle_staging_put(
799799
r = self.connection.session.http_client.request(
800800
"PUT", presigned_url, body=fh.read(), headers=headers
801801
)
802-
# Add compatibility attributes for urllib3 response
803-
r.status_code = r.status
804-
if hasattr(r, "data"):
805-
r.content = r.data
806-
r.ok = r.status < 400
807-
r.text = r.data.decode() if r.data else ""
808802

809803
# fmt: off
810804
# HTTP status codes
@@ -814,13 +808,15 @@ def _handle_staging_put(
814808
NO_CONTENT = 204
815809
# fmt: on
816810

817-
if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
811+
if r.status not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
812+
# Decode response data for error message
813+
error_text = r.data.decode() if r.data else ""
818814
raise OperationalError(
819-
f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}",
815+
f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}",
820816
session_id_hex=self.connection.get_session_id_hex(),
821817
)
822818

823-
if r.status_code == ACCEPTED:
819+
if r.status == ACCEPTED:
824820
logger.debug(
825821
f"Response code {ACCEPTED} from server indicates ingestion command was accepted "
826822
+ "but not yet applied on the server. It's possible this command may fail later."
@@ -844,23 +840,19 @@ def _handle_staging_get(
844840
r = self.connection.session.http_client.request(
845841
"GET", presigned_url, headers=headers
846842
)
847-
# Add compatibility attributes for urllib3 response
848-
r.status_code = r.status
849-
if hasattr(r, "data"):
850-
r.content = r.data
851-
r.ok = r.status < 400
852-
r.text = r.data.decode() if r.data else ""
853843

854844
# response.ok verifies the status code is not between 400-600.
855845
# Any 2xx or 3xx will evaluate r.ok == True
856-
if not r.ok:
846+
if r.status >= 400:
847+
# Decode response data for error message
848+
error_text = r.data.decode() if r.data else ""
857849
raise OperationalError(
858-
f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}",
850+
f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}",
859851
session_id_hex=self.connection.get_session_id_hex(),
860852
)
861853

862854
with open(local_file, "wb") as fp:
863-
fp.write(r.content)
855+
fp.write(r.data)
864856

865857
@log_latency(StatementType.SQL)
866858
def _handle_staging_remove(
@@ -871,16 +863,12 @@ def _handle_staging_remove(
871863
r = self.connection.session.http_client.request(
872864
"DELETE", presigned_url, headers=headers
873865
)
874-
# Add compatibility attributes for urllib3 response
875-
r.status_code = r.status
876-
if hasattr(r, "data"):
877-
r.content = r.data
878-
r.ok = r.status < 400
879-
r.text = r.data.decode() if r.data else ""
880-
881-
if not r.ok:
866+
867+
if r.status >= 400:
868+
# Decode response data for error message
869+
error_text = r.data.decode() if r.data else ""
882870
raise OperationalError(
883-
f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}",
871+
f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}",
884872
session_id_hex=self.connection.get_session_id_hex(),
885873
)
886874

src/databricks/sql/common/feature_flag.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,11 @@ def _refresh_flags(self):
113113
response = self._http_client.request(
114114
"GET", self._feature_flag_endpoint, headers=headers, timeout=30
115115
)
116-
# Add compatibility attributes for urllib3 response
117-
response.status_code = response.status
118-
response.json = lambda: json.loads(response.data.decode())
119116

120-
if response.status_code == 200:
121-
ff_response = FeatureFlagsResponse.from_dict(response.json())
117+
if response.status == 200:
118+
# Parse JSON response from urllib3 response data
119+
response_data = json.loads(response.data.decode())
120+
ff_response = FeatureFlagsResponse.from_dict(response_data)
122121
self._update_cache_from_response(ff_response)
123122
else:
124123
# On failure, initialize with an empty dictionary to prevent re-blocking.

tests/unit/test_auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def credential_provider(self):
306306
hostname="hostname",
307307
azure_client_id="client_id",
308308
azure_client_secret="client_secret",
309+
http_client=MagicMock(),
309310
azure_tenant_id="tenant_id",
310311
)
311312

tests/unit/test_retry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history):
3434
retry_policy.history = [error_history, error_history]
3535
retry_policy.sleep(HTTPResponse(status=503))
3636

37-
expected_backoff_time = min(
37+
expected_backoff_time = max(
3838
self.calculate_backoff_time(
3939
0, retry_policy.delay_min, retry_policy.delay_max
4040
),
@@ -57,7 +57,7 @@ def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_poli
5757
expected_backoff_times = []
5858
for attempt in range(num_attempts):
5959
expected_backoff_times.append(
60-
min(
60+
max(
6161
self.calculate_backoff_time(
6262
attempt, retry_policy.delay_min, retry_policy.delay_max
6363
),

tests/unit/test_telemetry.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919

2020

2121
@pytest.fixture
22-
def mock_telemetry_client():
22+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton")
23+
def mock_telemetry_client(mock_singleton_class):
2324
"""Create a mock telemetry client for testing."""
2425
session_id = str(uuid.uuid4())
2526
auth_provider = AccessTokenAuthProvider("test-token")
2627
executor = MagicMock()
27-
mock_http_client = MagicMock()
28+
mock_client_context = MagicMock()
29+
30+
# Mock the singleton to return a mock HTTP client
31+
mock_singleton = mock_singleton_class.return_value
32+
mock_singleton.get_http_client.return_value = MagicMock()
2833

2934
return TelemetryClient(
3035
telemetry_enabled=True,
@@ -33,7 +38,7 @@ def mock_telemetry_client():
3338
host_url="test-host.com",
3439
executor=executor,
3540
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
36-
http_client=mock_http_client,
41+
client_context=mock_client_context,
3742
)
3843

3944

@@ -212,11 +217,16 @@ def telemetry_system_reset(self):
212217
TelemetryClientFactory._executor = None
213218
TelemetryClientFactory._initialized = False
214219

215-
def test_client_lifecycle_flow(self):
220+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton")
221+
def test_client_lifecycle_flow(self, mock_singleton_class):
216222
"""Test complete client lifecycle: initialize -> use -> close."""
217223
session_id_hex = "test-session"
218224
auth_provider = AccessTokenAuthProvider("token")
219-
mock_http_client = MagicMock()
225+
mock_client_context = MagicMock()
226+
227+
# Mock the singleton to return a mock HTTP client
228+
mock_singleton = mock_singleton_class.return_value
229+
mock_singleton.get_http_client.return_value = MagicMock()
220230

221231
# Initialize enabled client
222232
TelemetryClientFactory.initialize_telemetry_client(
@@ -225,7 +235,7 @@ def test_client_lifecycle_flow(self):
225235
auth_provider=auth_provider,
226236
host_url="test-host.com",
227237
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
228-
http_client=mock_http_client,
238+
client_context=mock_client_context,
229239
)
230240

231241
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -241,27 +251,37 @@ def test_client_lifecycle_flow(self):
241251
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
242252
assert isinstance(client, NoopTelemetryClient)
243253

244-
def test_disabled_telemetry_flow(self):
254+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton")
255+
def test_disabled_telemetry_flow(self, mock_singleton_class):
245256
"""Test that disabled telemetry uses NoopTelemetryClient."""
246257
session_id_hex = "test-session"
247-
mock_http_client = MagicMock()
258+
mock_client_context = MagicMock()
259+
260+
# Mock the singleton to return a mock HTTP client
261+
mock_singleton = mock_singleton_class.return_value
262+
mock_singleton.get_http_client.return_value = MagicMock()
248263

249264
TelemetryClientFactory.initialize_telemetry_client(
250265
telemetry_enabled=False,
251266
session_id_hex=session_id_hex,
252267
auth_provider=None,
253268
host_url="test-host.com",
254269
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
255-
http_client=mock_http_client,
270+
client_context=mock_client_context,
256271
)
257272

258273
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
259274
assert isinstance(client, NoopTelemetryClient)
260275

261-
def test_factory_error_handling(self):
276+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton")
277+
def test_factory_error_handling(self, mock_singleton_class):
262278
"""Test that factory errors fall back to NoopTelemetryClient."""
263279
session_id = "test-session"
264-
mock_http_client = MagicMock()
280+
mock_client_context = MagicMock()
281+
282+
# Mock the singleton to return a mock HTTP client
283+
mock_singleton = mock_singleton_class.return_value
284+
mock_singleton.get_http_client.return_value = MagicMock()
265285

266286
# Simulate initialization error
267287
with patch(
@@ -274,18 +294,23 @@ def test_factory_error_handling(self):
274294
auth_provider=AccessTokenAuthProvider("token"),
275295
host_url="test-host.com",
276296
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
277-
http_client=mock_http_client,
297+
client_context=mock_client_context,
278298
)
279299

280300
# Should fall back to NoopTelemetryClient
281301
client = TelemetryClientFactory.get_telemetry_client(session_id)
282302
assert isinstance(client, NoopTelemetryClient)
283303

284-
def test_factory_shutdown_flow(self):
304+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton")
305+
def test_factory_shutdown_flow(self, mock_singleton_class):
285306
"""Test factory shutdown when last client is removed."""
286307
session1 = "session-1"
287308
session2 = "session-2"
288-
mock_http_client = MagicMock()
309+
mock_client_context = MagicMock()
310+
311+
# Mock the singleton to return a mock HTTP client
312+
mock_singleton = mock_singleton_class.return_value
313+
mock_singleton.get_http_client.return_value = MagicMock()
289314

290315
# Initialize multiple clients
291316
for session in [session1, session2]:
@@ -295,7 +320,7 @@ def test_factory_shutdown_flow(self):
295320
auth_provider=AccessTokenAuthProvider("token"),
296321
host_url="test-host.com",
297322
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
298-
http_client=mock_http_client,
323+
client_context=mock_client_context,
299324
)
300325

301326
# Factory should be initialized

0 commit comments

Comments
 (0)