Skip to content

Commit 2201765

Browse files
Merge branch 'main' into sea-http-client
2 parents 20c705f + 0a7a6ab commit 2201765

File tree

12 files changed

+236
-69
lines changed

12 files changed

+236
-69
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
790790
direct_results = resp.directResults
791791
has_been_closed_server_side = direct_results and direct_results.closeOperation
792792

793-
is_direct_results = (
793+
has_more_rows = (
794794
(not direct_results)
795795
or (not direct_results.resultSet)
796796
or direct_results.resultSet.hasMoreRows
@@ -831,7 +831,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
831831
result_format=t_result_set_metadata_resp.resultFormat,
832832
)
833833

834-
return execute_response, is_direct_results
834+
return execute_response, has_more_rows
835835

836836
def get_execution_result(
837837
self, command_id: CommandId, cursor: Cursor
@@ -876,7 +876,7 @@ def get_execution_result(
876876

877877
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
878878
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
879-
is_direct_results = resp.hasMoreRows
879+
has_more_rows = resp.hasMoreRows
880880

881881
status = CommandState.from_thrift_state(resp.status) or CommandState.RUNNING
882882

@@ -902,7 +902,7 @@ def get_execution_result(
902902
t_row_set=resp.results,
903903
max_download_threads=self.max_download_threads,
904904
ssl_options=self._ssl_options,
905-
is_direct_results=is_direct_results,
905+
has_more_rows=has_more_rows,
906906
)
907907

908908
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -1021,7 +1021,7 @@ def execute_command(
10211021
self._handle_execute_response_async(resp, cursor)
10221022
return None
10231023
else:
1024-
execute_response, is_direct_results = self._handle_execute_response(
1024+
execute_response, has_more_rows = self._handle_execute_response(
10251025
resp, cursor
10261026
)
10271027

@@ -1039,8 +1039,7 @@ def execute_command(
10391039
t_row_set=t_row_set,
10401040
max_download_threads=self.max_download_threads,
10411041
ssl_options=self._ssl_options,
1042-
is_direct_results=is_direct_results,
1043-
session_id_hex=self._session_id_hex,
1042+
has_more_rows=has_more_rows,
10441043
)
10451044

10461045
def get_catalogs(
@@ -1062,9 +1061,7 @@ def get_catalogs(
10621061
)
10631062
resp = self.make_request(self._client.GetCatalogs, req)
10641063

1065-
execute_response, is_direct_results = self._handle_execute_response(
1066-
resp, cursor
1067-
)
1064+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
10681065

10691066
t_row_set = None
10701067
if resp.directResults and resp.directResults.resultSet:
@@ -1080,8 +1077,7 @@ def get_catalogs(
10801077
t_row_set=t_row_set,
10811078
max_download_threads=self.max_download_threads,
10821079
ssl_options=self._ssl_options,
1083-
is_direct_results=is_direct_results,
1084-
session_id_hex=self._session_id_hex,
1080+
has_more_rows=has_more_rows,
10851081
)
10861082

10871083
def get_schemas(
@@ -1109,9 +1105,7 @@ def get_schemas(
11091105
)
11101106
resp = self.make_request(self._client.GetSchemas, req)
11111107

1112-
execute_response, is_direct_results = self._handle_execute_response(
1113-
resp, cursor
1114-
)
1108+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
11151109

11161110
t_row_set = None
11171111
if resp.directResults and resp.directResults.resultSet:
@@ -1127,8 +1121,7 @@ def get_schemas(
11271121
t_row_set=t_row_set,
11281122
max_download_threads=self.max_download_threads,
11291123
ssl_options=self._ssl_options,
1130-
is_direct_results=is_direct_results,
1131-
session_id_hex=self._session_id_hex,
1124+
has_more_rows=has_more_rows,
11321125
)
11331126

11341127
def get_tables(
@@ -1160,9 +1153,7 @@ def get_tables(
11601153
)
11611154
resp = self.make_request(self._client.GetTables, req)
11621155

1163-
execute_response, is_direct_results = self._handle_execute_response(
1164-
resp, cursor
1165-
)
1156+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
11661157

11671158
t_row_set = None
11681159
if resp.directResults and resp.directResults.resultSet:
@@ -1178,8 +1169,7 @@ def get_tables(
11781169
t_row_set=t_row_set,
11791170
max_download_threads=self.max_download_threads,
11801171
ssl_options=self._ssl_options,
1181-
is_direct_results=is_direct_results,
1182-
session_id_hex=self._session_id_hex,
1172+
has_more_rows=has_more_rows,
11831173
)
11841174

11851175
def get_columns(
@@ -1211,9 +1201,7 @@ def get_columns(
12111201
)
12121202
resp = self.make_request(self._client.GetColumns, req)
12131203

1214-
execute_response, is_direct_results = self._handle_execute_response(
1215-
resp, cursor
1216-
)
1204+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
12171205

12181206
t_row_set = None
12191207
if resp.directResults and resp.directResults.resultSet:
@@ -1229,8 +1217,7 @@ def get_columns(
12291217
t_row_set=t_row_set,
12301218
max_download_threads=self.max_download_threads,
12311219
ssl_options=self._ssl_options,
1232-
is_direct_results=is_direct_results,
1233-
session_id_hex=self._session_id_hex,
1220+
has_more_rows=has_more_rows,
12341221
)
12351222

12361223
def _handle_execute_response(self, resp, cursor):

src/databricks/sql/common/http.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import threading
66
from dataclasses import dataclass
77
from contextlib import contextmanager
8-
from typing import Generator
8+
from typing import Generator, Optional
99
import logging
10+
from requests.adapters import HTTPAdapter
11+
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -81,3 +83,70 @@ def execute(
8183

8284
def close(self):
8385
self.session.close()
86+
87+
88+
class TelemetryHTTPAdapter(HTTPAdapter):
89+
"""
90+
Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request.
91+
This ensures the retry timer is started and the command type is set correctly,
92+
allowing the policy to manage its state for the duration of the request retries.
93+
"""
94+
95+
def send(self, request, **kwargs):
96+
self.max_retries.command_type = CommandType.OTHER
97+
self.max_retries.start_retry_timer()
98+
return super().send(request, **kwargs)
99+
100+
101+
class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector
102+
"""Singleton HTTP client for sending telemetry data."""
103+
104+
_instance: Optional["TelemetryHttpClient"] = None
105+
_lock = threading.Lock()
106+
107+
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
108+
TELEMETRY_RETRY_DELAY_MIN = 1.0
109+
TELEMETRY_RETRY_DELAY_MAX = 10.0
110+
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0
111+
112+
def __init__(self):
113+
"""Initializes the session and mounts the custom retry adapter."""
114+
retry_policy = DatabricksRetryPolicy(
115+
delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
116+
delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
117+
stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
118+
stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
119+
delay_default=1.0,
120+
force_dangerous_codes=[],
121+
)
122+
adapter = TelemetryHTTPAdapter(max_retries=retry_policy)
123+
self.session = requests.Session()
124+
self.session.mount("https://", adapter)
125+
self.session.mount("http://", adapter)
126+
127+
@classmethod
128+
def get_instance(cls) -> "TelemetryHttpClient":
129+
"""Get the singleton instance of the TelemetryHttpClient."""
130+
if cls._instance is None:
131+
with cls._lock:
132+
if cls._instance is None:
133+
logger.debug("Initializing singleton TelemetryHttpClient")
134+
cls._instance = TelemetryHttpClient()
135+
return cls._instance
136+
137+
def post(self, url: str, **kwargs) -> requests.Response:
138+
"""
139+
Executes a POST request using the configured session.
140+
141+
This is a blocking call intended to be run in a background thread.
142+
"""
143+
logger.debug("Executing telemetry POST request to: %s", url)
144+
return self.session.post(url, **kwargs)
145+
146+
def close(self):
147+
"""Closes the underlying requests.Session."""
148+
logger.debug("Closing TelemetryHttpClient session.")
149+
self.session.close()
150+
# Clear the instance to allow for re-initialization if needed
151+
with TelemetryHttpClient._lock:
152+
TelemetryHttpClient._instance = None

src/databricks/sql/exc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import logging
33

44
logger = logging.getLogger(__name__)
5-
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
6-
75

86
### PEP-249 Mandated ###
97
# https://peps.python.org/pep-0249/#exceptions
@@ -22,6 +20,8 @@ def __init__(
2220

2321
error_name = self.__class__.__name__
2422
if session_id_hex:
23+
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
24+
2525
telemetry_client = TelemetryClientFactory.get_telemetry_client(
2626
session_id_hex
2727
)

src/databricks/sql/result_set.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
command_id: CommandId,
4444
status: CommandState,
4545
has_been_closed_server_side: bool = False,
46-
is_direct_results: bool = False,
46+
has_more_rows: bool = False,
4747
results_queue=None,
4848
description: List[Tuple] = [],
4949
is_staging_operation: bool = False,
@@ -61,7 +61,7 @@ def __init__(
6161
:param command_id: The command ID
6262
:param status: The command status
6363
:param has_been_closed_server_side: Whether the command has been closed on the server
64-
:param is_direct_results: Whether the command has more rows
64+
:param has_more_rows: Whether the command has more rows
6565
:param results_queue: The results queue
6666
:param description: column description of the results
6767
:param is_staging_operation: Whether the command is a staging operation
@@ -76,7 +76,7 @@ def __init__(
7676
self.command_id = command_id
7777
self.status = status
7878
self.has_been_closed_server_side = has_been_closed_server_side
79-
self.is_direct_results = is_direct_results
79+
self.has_more_rows = has_more_rows
8080
self.results = results_queue
8181
self._is_staging_operation = is_staging_operation
8282
self.lz4_compressed = lz4_compressed
@@ -170,7 +170,11 @@ def close(self) -> None:
170170
been closed on the server for some other reason, issue a request to the server to close it.
171171
"""
172172
try:
173-
self.results.close()
173+
if self.results is not None:
174+
self.results.close()
175+
else:
176+
logger.warning("result set close: queue not initialized")
177+
174178
if (
175179
self.status != CommandState.CLOSED
176180
and not self.has_been_closed_server_side
@@ -193,14 +197,13 @@ def __init__(
193197
connection: Connection,
194198
execute_response: ExecuteResponse,
195199
thrift_client: ThriftDatabricksClient,
196-
session_id_hex: Optional[str],
197200
buffer_size_bytes: int = 104857600,
198201
arraysize: int = 10000,
199202
use_cloud_fetch: bool = True,
200203
t_row_set=None,
201204
max_download_threads: int = 10,
202205
ssl_options=None,
203-
is_direct_results: bool = True,
206+
has_more_rows: bool = True,
204207
):
205208
"""
206209
Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -215,13 +218,13 @@ def __init__(
215218
:param t_row_set: The TRowSet containing result data (if available)
216219
:param max_download_threads: Maximum number of download threads for cloud fetch
217220
:param ssl_options: SSL options for cloud fetch
218-
:param is_direct_results: Whether there are more rows to fetch
221+
:param has_more_rows: Whether there are more rows to fetch
219222
"""
220-
self.num_downloaded_chunks = 0
223+
self.num_chunks = 0
221224

222225
# Initialize ThriftResultSet-specific attributes
223226
self._use_cloud_fetch = use_cloud_fetch
224-
self.is_direct_results = is_direct_results
227+
self.has_more_rows = has_more_rows
225228

226229
# Build the results queue if t_row_set is provided
227230
results_queue = None
@@ -237,12 +240,12 @@ def __init__(
237240
lz4_compressed=execute_response.lz4_compressed,
238241
description=execute_response.description,
239242
ssl_options=ssl_options,
240-
session_id_hex=session_id_hex,
243+
session_id_hex=connection.get_session_id_hex(),
241244
statement_id=execute_response.command_id.to_hex_guid(),
242-
chunk_id=self.num_downloaded_chunks,
245+
chunk_id=self.num_chunks,
243246
)
244247
if t_row_set.resultLinks:
245-
self.num_downloaded_chunks += len(t_row_set.resultLinks)
248+
self.num_chunks += len(t_row_set.resultLinks)
246249

247250
# Call parent constructor with common attributes
248251
super().__init__(
@@ -253,7 +256,7 @@ def __init__(
253256
command_id=execute_response.command_id,
254257
status=execute_response.status,
255258
has_been_closed_server_side=execute_response.has_been_closed_server_side,
256-
is_direct_results=is_direct_results,
259+
has_more_rows=has_more_rows,
257260
results_queue=results_queue,
258261
description=execute_response.description,
259262
is_staging_operation=execute_response.is_staging_operation,
@@ -266,7 +269,7 @@ def __init__(
266269
self._fill_results_buffer()
267270

268271
def _fill_results_buffer(self):
269-
results, is_direct_results, result_links_count = self.backend.fetch_results(
272+
results, has_more_rows, result_links_count = self.backend.fetch_results(
270273
command_id=self.command_id,
271274
max_rows=self.arraysize,
272275
max_bytes=self.buffer_size_bytes,
@@ -275,11 +278,11 @@ def _fill_results_buffer(self):
275278
arrow_schema_bytes=self._arrow_schema_bytes,
276279
description=self.description,
277280
use_cloud_fetch=self._use_cloud_fetch,
278-
chunk_id=self.num_downloaded_chunks,
281+
chunk_id=self.num_chunks,
279282
)
280283
self.results = results
281-
self.is_direct_results = is_direct_results
282-
self.num_downloaded_chunks += result_links_count
284+
self.has_more_rows = has_more_rows
285+
self.num_chunks += result_links_count
283286

284287
def _convert_columnar_table(self, table):
285288
column_names = [c[0] for c in self.description]
@@ -326,7 +329,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
326329
while (
327330
n_remaining_rows > 0
328331
and not self.has_been_closed_server_side
329-
and self.is_direct_results
332+
and self.has_more_rows
330333
):
331334
self._fill_results_buffer()
332335
partial_results = self.results.next_n_rows(n_remaining_rows)
@@ -351,7 +354,7 @@ def fetchmany_columnar(self, size: int):
351354
while (
352355
n_remaining_rows > 0
353356
and not self.has_been_closed_server_side
354-
and self.is_direct_results
357+
and self.has_more_rows
355358
):
356359
self._fill_results_buffer()
357360
partial_results = self.results.next_n_rows(n_remaining_rows)
@@ -366,7 +369,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
366369
results = self.results.remaining_rows()
367370
self._next_row_index += results.num_rows
368371
partial_result_chunks = [results]
369-
while not self.has_been_closed_server_side and self.is_direct_results:
372+
while not self.has_been_closed_server_side and self.has_more_rows:
370373
self._fill_results_buffer()
371374
partial_results = self.results.remaining_rows()
372375
if isinstance(results, ColumnTable) and isinstance(
@@ -392,7 +395,7 @@ def fetchall_columnar(self):
392395
results = self.results.remaining_rows()
393396
self._next_row_index += results.num_rows
394397

395-
while not self.has_been_closed_server_side and self.is_direct_results:
398+
while not self.has_been_closed_server_side and self.has_more_rows:
396399
self._fill_results_buffer()
397400
partial_results = self.results.remaining_rows()
398401
results = self.merge_columnar(results, partial_results)

0 commit comments

Comments (0)