Skip to content

Commit 928e128

Browse files
committed
Add chunk download latency telemetry: propagate chunk_id, session_id_hex, and statement_id through the cloud-fetch download path and log per-chunk download latency
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 6d4701f commit 928e128

File tree

7 files changed

+105
-25
lines changed

7 files changed

+105
-25
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
import threading
88
from typing import List, Optional, Union, Any, TYPE_CHECKING
9+
from uuid import UUID
910

1011
from databricks.sql.result_set import ThriftResultSet
1112

@@ -1021,7 +1022,7 @@ def execute_command(
10211022
self._handle_execute_response_async(resp, cursor)
10221023
return None
10231024
else:
1024-
execute_response, is_direct_results = self._handle_execute_response(
1025+
execute_response, is_direct_results, statement_id = self._handle_execute_response(
10251026
resp, cursor
10261027
)
10271028

@@ -1040,6 +1041,8 @@ def execute_command(
10401041
max_download_threads=self.max_download_threads,
10411042
ssl_options=self._ssl_options,
10421043
is_direct_results=is_direct_results,
1044+
session_id_hex=self._session_id_hex,
1045+
statement_id=statement_id,
10431046
)
10441047

10451048
def get_catalogs(
@@ -1061,7 +1064,7 @@ def get_catalogs(
10611064
)
10621065
resp = self.make_request(self._client.GetCatalogs, req)
10631066

1064-
execute_response, is_direct_results = self._handle_execute_response(
1067+
execute_response, is_direct_results, _ = self._handle_execute_response(
10651068
resp, cursor
10661069
)
10671070

@@ -1107,7 +1110,7 @@ def get_schemas(
11071110
)
11081111
resp = self.make_request(self._client.GetSchemas, req)
11091112

1110-
execute_response, is_direct_results = self._handle_execute_response(
1113+
execute_response, is_direct_results, _ = self._handle_execute_response(
11111114
resp, cursor
11121115
)
11131116

@@ -1157,7 +1160,7 @@ def get_tables(
11571160
)
11581161
resp = self.make_request(self._client.GetTables, req)
11591162

1160-
execute_response, is_direct_results = self._handle_execute_response(
1163+
execute_response, is_direct_results, _ = self._handle_execute_response(
11611164
resp, cursor
11621165
)
11631166

@@ -1207,7 +1210,7 @@ def get_columns(
12071210
)
12081211
resp = self.make_request(self._client.GetColumns, req)
12091212

1210-
execute_response, is_direct_results = self._handle_execute_response(
1213+
execute_response, is_direct_results, _ = self._handle_execute_response(
12111214
resp, cursor
12121215
)
12131216

@@ -1241,7 +1244,11 @@ def _handle_execute_response(self, resp, cursor):
12411244
resp.directResults and resp.directResults.operationStatus,
12421245
)
12431246

1244-
return self._results_message_to_execute_response(resp, final_operation_state)
1247+
execute_response, is_direct_results = self._results_message_to_execute_response(
1248+
resp, final_operation_state
1249+
)
1250+
1251+
return execute_response, is_direct_results, cursor.active_command_id.to_hex_guid()
12451252

12461253
def _handle_execute_response_async(self, resp, cursor):
12471254
command_id = CommandId.from_thrift_handle(resp.operationHandle)
@@ -1261,6 +1268,7 @@ def fetch_results(
12611268
arrow_schema_bytes,
12621269
description,
12631270
use_cloud_fetch=True,
1271+
statement_id=None,
12641272
):
12651273
thrift_handle = command_id.to_thrift_handle()
12661274
if not thrift_handle:
@@ -1297,6 +1305,8 @@ def fetch_results(
12971305
lz4_compressed=lz4_compressed,
12981306
description=description,
12991307
ssl_options=self._ssl_options,
1308+
session_id_hex=self._session_id_hex,
1309+
statement_id=statement_id
13001310
)
13011311

13021312
return queue, resp.hasMoreRows

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from concurrent.futures import ThreadPoolExecutor, Future
4-
from typing import List, Union
4+
from typing import List, Union, Optional, Tuple
55

66
from databricks.sql.cloudfetch.downloader import (
77
ResultSetDownloadHandler,
@@ -22,24 +22,28 @@ def __init__(
2222
max_download_threads: int,
2323
lz4_compressed: bool,
2424
ssl_options: SSLOptions,
25+
session_id_hex: Optional[str] = None,
26+
statement_id: Optional[str] = None,
2527
):
26-
self._pending_links: List[TSparkArrowResultLink] = []
27-
for link in links:
28+
self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
29+
for i, link in enumerate(links):
2830
if link.rowCount <= 0:
2931
continue
3032
logger.debug(
31-
"ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
32-
link.startRowOffset, link.rowCount
33+
"ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}".format(
34+
i, link.startRowOffset, link.rowCount
3335
)
3436
)
35-
self._pending_links.append(link)
37+
self._pending_links.append((i, link))
3638

3739
self._download_tasks: List[Future[DownloadedFile]] = []
3840
self._max_download_threads: int = max_download_threads
3941
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
4042

4143
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
4244
self._ssl_options = ssl_options
45+
self.session_id_hex = session_id_hex
46+
self.statement_id = statement_id
4347

4448
def get_next_downloaded_file(
4549
self, next_row_offset: int
@@ -89,14 +93,17 @@ def _schedule_downloads(self):
8993
while (len(self._download_tasks) < self._max_download_threads) and (
9094
len(self._pending_links) > 0
9195
):
92-
link = self._pending_links.pop(0)
96+
chunk_id, link = self._pending_links.pop(0)
9397
logger.debug(
94-
"- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
98+
"- chunk: {}, start: {}, row count: {}".format(chunk_id, link.startRowOffset, link.rowCount)
9599
)
96100
handler = ResultSetDownloadHandler(
97101
settings=self._downloadable_result_settings,
98102
link=link,
99103
ssl_options=self._ssl_options,
104+
chunk_id=chunk_id,
105+
session_id_hex=self.session_id_hex,
106+
statement_id=self.statement_id
100107
)
101108
task = self._thread_pool.submit(handler.run)
102109
self._download_tasks.append(task)

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from dataclasses import dataclass
3+
from typing import Optional
34

45
import requests
56
from requests.adapters import HTTPAdapter, Retry
@@ -9,6 +10,7 @@
910
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
1011
from databricks.sql.exc import Error
1112
from databricks.sql.types import SSLOptions
13+
from databricks.sql.telemetry.latency_logger import log_latency
1214

1315
logger = logging.getLogger(__name__)
1416

@@ -66,11 +68,18 @@ def __init__(
6668
settings: DownloadableResultSettings,
6769
link: TSparkArrowResultLink,
6870
ssl_options: SSLOptions,
71+
chunk_id: int,
72+
session_id_hex: Optional[str] = None,
73+
statement_id: Optional[str] = None,
6974
):
7075
self.settings = settings
7176
self.link = link
7277
self._ssl_options = ssl_options
78+
self.chunk_id = chunk_id
79+
self.session_id_hex = session_id_hex
80+
self.statement_id = statement_id
7381

82+
@log_latency()
7483
def run(self) -> DownloadedFile:
7584
"""
7685
Download the file described in the cloud fetch link.
@@ -80,8 +89,8 @@ def run(self) -> DownloadedFile:
8089
"""
8190

8291
logger.debug(
83-
"ResultSetDownloadHandler: starting file download, offset {}, row count {}".format(
84-
self.link.startRowOffset, self.link.rowCount
92+
"ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format(
93+
self.chunk_id, self.link.startRowOffset, self.link.rowCount
8594
)
8695
)
8796

src/databricks/sql/result_set.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def __init__(
198198
max_download_threads: int = 10,
199199
ssl_options=None,
200200
is_direct_results: bool = True,
201+
session_id_hex: Optional[str] = None,
202+
statement_id: Optional[str] = None,
201203
):
202204
"""
203205
Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -233,6 +235,8 @@ def __init__(
233235
lz4_compressed=execute_response.lz4_compressed,
234236
description=execute_response.description,
235237
ssl_options=ssl_options,
238+
session_id_hex=session_id_hex,
239+
statement_id=statement_id,
236240
)
237241

238242
# Call parent constructor with common attributes

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
SqlExecutionEvent,
88
)
99
from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
10-
from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue
1110
from uuid import UUID
1211

1312
logger = logging.getLogger(__name__)
@@ -42,6 +41,9 @@ def get_execution_result(self):
4241
def get_retry_count(self):
4342
pass
4443

44+
def get_chunk_id(self):
45+
pass
46+
4547

4648
class CursorExtractor(TelemetryExtractor):
4749
"""
@@ -63,7 +65,8 @@ def get_is_compressed(self) -> bool:
6365
def get_execution_result(self) -> ExecutionResultFormat:
6466
if self.active_result_set is None:
6567
return ExecutionResultFormat.FORMAT_UNSPECIFIED
66-
68+
69+
from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue
6770
if isinstance(self.active_result_set.results, ColumnQueue):
6871
return ExecutionResultFormat.COLUMNAR_INLINE
6972
elif isinstance(self.active_result_set.results, CloudFetchQueue):
@@ -74,11 +77,14 @@ def get_execution_result(self) -> ExecutionResultFormat:
7477

7578
def get_retry_count(self) -> int:
7679
if (
77-
hasattr(self.thrift_backend, "retry_policy")
78-
and self.thrift_backend.retry_policy
80+
hasattr(self.backend, "retry_policy")
81+
and self.backend.retry_policy
7982
):
80-
return len(self.thrift_backend.retry_policy.history)
83+
return len(self.backend.retry_policy.history)
8184
return 0
85+
86+
def get_chunk_id(self):
87+
return None
8288

8389

8490
class ResultSetExtractor(TelemetryExtractor):
@@ -101,6 +107,7 @@ def get_is_compressed(self) -> bool:
101107
return self.lz4_compressed
102108

103109
def get_execution_result(self) -> ExecutionResultFormat:
110+
from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue
104111
if isinstance(self.results, ColumnQueue):
105112
return ExecutionResultFormat.COLUMNAR_INLINE
106113
elif isinstance(self.results, CloudFetchQueue):
@@ -116,7 +123,34 @@ def get_retry_count(self) -> int:
116123
):
117124
return len(self.thrift_backend.retry_policy.history)
118125
return 0
126+
127+
def get_chunk_id(self):
128+
return None
129+
130+
131+
class ResultSetDownloadHandlerExtractor(TelemetryExtractor):
132+
"""
133+
Telemetry extractor specialized for ResultSetDownloadHandler objects.
134+
"""
135+
def get_session_id_hex(self) -> Optional[str]:
136+
return self._obj.session_id_hex
137+
138+
def get_statement_id(self) -> Optional[str]:
139+
return self._obj.statement_id
140+
141+
def get_is_compressed(self) -> bool:
142+
return self._obj.settings.is_lz4_compressed
143+
144+
def get_execution_result(self) -> ExecutionResultFormat:
145+
return ExecutionResultFormat.EXTERNAL_LINKS
146+
147+
def get_retry_count(self) -> Optional[int]:
148+
# standard requests and urllib3 libraries don't expose retry count
149+
return None
119150

151+
def get_chunk_id(self) -> Optional[int]:
152+
return self._obj.chunk_id
153+
120154

121155
def get_extractor(obj):
122156
"""
@@ -133,12 +167,15 @@ def get_extractor(obj):
133167
TelemetryExtractor: A specialized extractor instance:
134168
- CursorExtractor for Cursor objects
135169
- ResultSetExtractor for ResultSet objects
170+
- ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects
136171
- None for all other objects
137172
"""
138173
if obj.__class__.__name__ == "Cursor":
139174
return CursorExtractor(obj)
140175
elif obj.__class__.__name__ == "ResultSet":
141176
return ResultSetExtractor(obj)
177+
elif obj.__class__.__name__=="ResultSetDownloadHandler":
178+
return ResultSetDownloadHandlerExtractor(obj)
142179
else:
143180
logger.debug("No extractor found for %s", obj.__class__.__name__)
144181
return None
@@ -196,6 +233,7 @@ def _safe_call(func_to_call):
196233
duration_ms = int((end_time - start_time) * 1000)
197234

198235
extractor = get_extractor(self)
236+
print("function name", func.__name__, "latency", duration_ms, "session_id_hex", extractor.get_session_id_hex(), "statement_id", extractor.get_statement_id(), flush=True)
199237

200238
if extractor is not None:
201239
session_id_hex = _safe_call(extractor.get_session_id_hex)
@@ -205,7 +243,8 @@ def _safe_call(func_to_call):
205243
statement_type=statement_type,
206244
is_compressed=_safe_call(extractor.get_is_compressed),
207245
execution_result=_safe_call(extractor.get_execution_result),
208-
retry_count=_safe_call(extractor.get_retry_count),
246+
retry_count=extractor.get_retry_count(),
247+
chunk_id=_safe_call(extractor.get_chunk_id),
209248
)
210249

211250
telemetry_client = TelemetryClientFactory.get_telemetry_client(

src/databricks/sql/telemetry/models/event.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,14 @@ class SqlExecutionEvent(JsonSerializableMixin):
122122
is_compressed (bool): Whether the result is compressed
123123
execution_result (ExecutionResultFormat): Format of the execution result
124124
retry_count (int): Number of retry attempts made
125+
chunk_id (int): ID of the chunk if applicable
125126
"""
126127

127128
statement_type: StatementType
128129
is_compressed: bool
129130
execution_result: ExecutionResultFormat
130-
retry_count: int
131-
131+
retry_count: Optional[int]
132+
chunk_id: Optional[int]
132133

133134
@dataclass
134135
class TelemetryEvent(JsonSerializableMixin):

0 commit comments

Comments (0)