Skip to content

Commit 496d7f7

Browse files
committed
Revert "formatting (black) - fix some closures"
This reverts commit 67020f1.
1 parent 67020f1 commit 496d7f7

File tree

13 files changed

+60
-101
lines changed

13 files changed

+60
-101
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def __init__(
232232
try:
233233
self._transport.open()
234234
except:
235-
self._transport.close()
235+
self._transport.release_connection()
236236
raise
237237

238238
self._request_lock = threading.RLock()
@@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId:
607607
self._session_id_hex = session_id.hex_guid
608608
return session_id
609609
except:
610-
self._transport.close()
610+
self._transport.release_connection()
611611
raise
612612

613613
def close_session(self, session_id: SessionId) -> None:

src/databricks/sql/client.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,6 @@ def read(self) -> Optional[OAuthToken]:
284284
if hasattr(self, "session")
285285
else None,
286286
)
287-
if self.http_client:
288-
self.http_client.close()
289287
raise e
290288

291289
self.use_inline_params = self._set_use_inline_params_with_warning(

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,6 @@ def close(self):
359359
"""Flush remaining events before closing"""
360360
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
361361
self._flush()
362-
self._http_client.close()
363362

364363

365364
class TelemetryClientFactory:
@@ -461,6 +460,7 @@ def initialize_telemetry_client(
461460
):
462461
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
463462
try:
463+
464464
with TelemetryClientFactory._lock:
465465
TelemetryClientFactory._initialize()
466466

tests/e2e/common/staging_ingestion_tests.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user):
8080

8181
# GET after REMOVE should fail
8282

83-
with pytest.raises(Error, match="too many 404 error responses"):
83+
with pytest.raises(
84+
Error, match="too many 404 error responses"
85+
):
8486
cursor = conn.cursor()
8587
query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'"
8688
cursor.execute(query)

tests/e2e/common/uc_volume_tests.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def test_uc_volume_life_cycle(self, catalog, schema):
8080

8181
# GET after REMOVE should fail
8282

83-
with pytest.raises(Error, match="too many 404 error responses"):
83+
with pytest.raises(
84+
Error, match="too many 404 error responses"
85+
):
8486
cursor = conn.cursor()
8587
query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'"
8688
cursor.execute(query)

tests/e2e/test_concurrent_telemetry.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,9 @@ def execute_query_worker(thread_id):
122122
response = future.result()
123123
# Check status using urllib3 method (response.status instead of response.raise_for_status())
124124
if response.status >= 400:
125-
raise Exception(
126-
f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}"
127-
)
125+
raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}")
128126
# Parse JSON using urllib3 method (response.data.decode() instead of response.json())
129-
response_data = (
130-
json.loads(response.data.decode()) if response.data else {}
131-
)
127+
response_data = json.loads(response.data.decode()) if response.data else {}
132128
captured_responses.append(response_data)
133129
except Exception as e:
134130
captured_exceptions.append(e)

tests/e2e/test_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464
for name in test_loader.getTestCaseNames(DecimalTestsMixin):
6565
if name.startswith("test_"):
6666
fn = getattr(DecimalTestsMixin, name)
67-
decorated = skipUnless(
68-
pysql_supports_arrow(), "Decimal tests need arrow support"
69-
)(fn)
67+
decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(
68+
fn
69+
)
7070
setattr(DecimalTestsMixin, name, decorated)
7171

7272

tests/unit/test_auth.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,7 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
145145
hostname = "moderakh-test.cloud.databricks.com"
146146
kwargs = {"access_token": "dpi123"}
147147
mock_http_client = MagicMock()
148-
auth_provider = get_python_sql_connector_auth_provider(
149-
hostname, mock_http_client, **kwargs
150-
)
148+
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
151149
self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider")
152150

153151
headers = {}
@@ -165,9 +163,7 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
165163
hostname = "moderakh-test.cloud.databricks.com"
166164
kwargs = {"credentials_provider": MyProvider()}
167165
mock_http_client = MagicMock()
168-
auth_provider = get_python_sql_connector_auth_provider(
169-
hostname, mock_http_client, **kwargs
170-
)
166+
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
171167
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")
172168

173169
headers = {}
@@ -183,9 +179,7 @@ def test_get_python_sql_connector_auth_provider_noop(self):
183179
"_use_cert_as_auth": use_cert_as_auth,
184180
}
185181
mock_http_client = MagicMock()
186-
auth_provider = get_python_sql_connector_auth_provider(
187-
hostname, mock_http_client, **kwargs
188-
)
182+
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
189183
self.assertTrue(type(auth_provider).__name__, "CredentialProvider")
190184

191185
def test_get_python_sql_connector_basic_auth(self):
@@ -195,9 +189,7 @@ def test_get_python_sql_connector_basic_auth(self):
195189
}
196190
mock_http_client = MagicMock()
197191
with self.assertRaises(ValueError) as e:
198-
get_python_sql_connector_auth_provider(
199-
"foo.cloud.databricks.com", mock_http_client, **kwargs
200-
)
192+
get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs)
201193
self.assertIn(
202194
"Username/password authentication is no longer supported", str(e.exception)
203195
)
@@ -206,9 +198,7 @@ def test_get_python_sql_connector_basic_auth(self):
206198
def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
207199
hostname = "foo.cloud.databricks.com"
208200
mock_http_client = MagicMock()
209-
auth_provider = get_python_sql_connector_auth_provider(
210-
hostname, mock_http_client
211-
)
201+
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client)
212202
self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider")
213203
self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
214204

@@ -269,16 +259,16 @@ def test_no_token_refresh__when_token_is_not_expired(
269259

270260
def test_get_token_success(self, token_source, http_response):
271261
mock_http_client = MagicMock()
272-
262+
273263
with patch.object(token_source, "_http_client", mock_http_client):
274264
# Create a mock response with the expected format
275265
mock_response = MagicMock()
276266
mock_response.status = 200
277267
mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}'
278-
268+
279269
# Mock the request method to return the response directly
280270
mock_http_client.request.return_value = mock_response
281-
271+
282272
token = token_source.get_token()
283273

284274
# Assert
@@ -289,16 +279,16 @@ def test_get_token_success(self, token_source, http_response):
289279

290280
def test_get_token_failure(self, token_source, http_response):
291281
mock_http_client = MagicMock()
292-
282+
293283
with patch.object(token_source, "_http_client", mock_http_client):
294284
# Create a mock response with error
295285
mock_response = MagicMock()
296286
mock_response.status = 400
297287
mock_response.data.decode.return_value = "Bad Request"
298-
288+
299289
# Mock the request method to return the response directly
300290
mock_http_client.request.return_value = mock_response
301-
291+
302292
with pytest.raises(Exception) as e:
303293
token_source.get_token()
304294
assert "Failed to get token: 400" in str(e.value)

tests/unit/test_cloud_fetch_queue.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,22 @@
1313

1414
@pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
1515
class CloudFetchQueueSuite(unittest.TestCase):
16-
def create_queue(
17-
self, schema_bytes=None, result_links=None, description=None, **kwargs
18-
):
16+
def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs):
1917
"""Helper method to create ThriftCloudFetchQueue with sensible defaults"""
2018
# Set up defaults for commonly used parameters
2119
defaults = {
22-
"max_download_threads": 10,
23-
"ssl_options": SSLOptions(),
24-
"session_id_hex": Mock(),
25-
"statement_id": Mock(),
26-
"chunk_id": 0,
27-
"start_row_offset": 0,
28-
"lz4_compressed": True,
20+
'max_download_threads': 10,
21+
'ssl_options': SSLOptions(),
22+
'session_id_hex': Mock(),
23+
'statement_id': Mock(),
24+
'chunk_id': 0,
25+
'start_row_offset': 0,
26+
'lz4_compressed': True,
2927
}
30-
28+
3129
# Override defaults with any provided kwargs
3230
defaults.update(kwargs)
33-
31+
3432
mock_http_client = MagicMock()
3533
return utils.ThriftCloudFetchQueue(
3634
schema_bytes=schema_bytes or MagicMock(),
@@ -200,12 +198,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
200198
def test_next_n_rows_empty_table(self, mock_create_next_table):
201199
schema_bytes = self.get_schema_bytes()
202200
# Create description that matches the 4-column schema
203-
description = [
204-
("col0", "uint32"),
205-
("col1", "uint32"),
206-
("col2", "uint32"),
207-
("col3", "uint32"),
208-
]
201+
description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
209202
queue = self.create_queue(schema_bytes=schema_bytes, description=description)
210203
assert queue.table is None
211204

@@ -284,12 +277,7 @@ def test_remaining_rows_multiple_tables_fully_returned(
284277
def test_remaining_rows_empty_table(self, mock_create_next_table):
285278
schema_bytes = self.get_schema_bytes()
286279
# Create description that matches the 4-column schema
287-
description = [
288-
("col0", "uint32"),
289-
("col1", "uint32"),
290-
("col2", "uint32"),
291-
("col3", "uint32"),
292-
]
280+
description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
293281
queue = self.create_queue(schema_bytes=schema_bytes, description=description)
294282
assert queue.table is None
295283

tests/unit/test_downloader.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_run_uncompressed_successful(self, mock_time):
131131
self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes)
132132

133133
# Patch the log metrics method to avoid division by zero
134-
with patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"):
134+
with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
135135
d = downloader.ResultSetDownloadHandler(
136136
settings,
137137
result_link,
@@ -160,16 +160,11 @@ def test_run_compressed_successful(self, mock_time):
160160
result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
161161

162162
# Setup mock HTTP response using helper method
163-
self._setup_mock_http_response(
164-
mock_http_client, status=200, data=compressed_bytes
165-
)
163+
self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes)
166164

167165
# Mock the decompression method and log metrics to avoid issues
168-
with patch.object(
169-
downloader.ResultSetDownloadHandler,
170-
"_decompress_data",
171-
return_value=file_bytes,
172-
), patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"):
166+
with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \
167+
patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
173168
d = downloader.ResultSetDownloadHandler(
174169
settings,
175170
result_link,

0 commit comments

Comments
 (0)