Skip to content

Commit 6d122c4

Browse files
committed
remove defaults, fix chunk id
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent e79c325 commit 6d122c4

File tree

13 files changed

+236
-114
lines changed

13 files changed

+236
-114
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
if TYPE_CHECKING:
1515
from databricks.sql.client import Cursor
1616
from databricks.sql.result_set import ResultSet
17+
from databricks.sql.telemetry.models.event import StatementType
1718

1819
from databricks.sql.backend.types import (
1920
CommandState,
@@ -832,7 +833,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
832833
return execute_response, is_direct_results
833834

834835
def get_execution_result(
835-
self, command_id: CommandId, cursor: "Cursor"
836+
self, command_id: CommandId, cursor: "Cursor", statement_type: StatementType
836837
) -> "ResultSet":
837838
thrift_handle = command_id.to_thrift_handle()
838839
if not thrift_handle:
@@ -900,6 +901,8 @@ def get_execution_result(
900901
max_download_threads=self.max_download_threads,
901902
ssl_options=self._ssl_options,
902903
is_direct_results=is_direct_results,
904+
session_id_hex=self._session_id_hex,
905+
statement_type=statement_type,
903906
)
904907

905908
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -965,6 +968,7 @@ def execute_command(
965968
max_bytes: int,
966969
lz4_compression: bool,
967970
cursor: Cursor,
971+
statement_type: StatementType,
968972
use_cloud_fetch=True,
969973
parameters=[],
970974
async_op=False,
@@ -1018,11 +1022,9 @@ def execute_command(
10181022
self._handle_execute_response_async(resp, cursor)
10191023
return None
10201024
else:
1021-
(
1022-
execute_response,
1023-
is_direct_results,
1024-
statement_id,
1025-
) = self._handle_execute_response(resp, cursor)
1025+
execute_response, is_direct_results = self._handle_execute_response(
1026+
resp, cursor
1027+
)
10261028

10271029
t_row_set = None
10281030
if resp.directResults and resp.directResults.resultSet:
@@ -1040,7 +1042,7 @@ def execute_command(
10401042
ssl_options=self._ssl_options,
10411043
is_direct_results=is_direct_results,
10421044
session_id_hex=self._session_id_hex,
1043-
statement_id=statement_id,
1045+
statement_type=statement_type,
10441046
)
10451047

10461048
def get_catalogs(
@@ -1049,6 +1051,7 @@ def get_catalogs(
10491051
max_rows: int,
10501052
max_bytes: int,
10511053
cursor: "Cursor",
1054+
statement_type: StatementType,
10521055
) -> "ResultSet":
10531056
thrift_handle = session_id.to_thrift_handle()
10541057
if not thrift_handle:
@@ -1062,11 +1065,9 @@ def get_catalogs(
10621065
)
10631066
resp = self.make_request(self._client.GetCatalogs, req)
10641067

1065-
(
1066-
execute_response,
1067-
is_direct_results,
1068-
statement_id,
1069-
) = self._handle_execute_response(resp, cursor)
1068+
execute_response, is_direct_results = self._handle_execute_response(
1069+
resp, cursor
1070+
)
10701071

10711072
t_row_set = None
10721073
if resp.directResults and resp.directResults.resultSet:
@@ -1084,7 +1085,7 @@ def get_catalogs(
10841085
ssl_options=self._ssl_options,
10851086
is_direct_results=is_direct_results,
10861087
session_id_hex=self._session_id_hex,
1087-
statement_id=statement_id,
1088+
statement_id=statement_type,
10881089
)
10891090

10901091
def get_schemas(
@@ -1093,6 +1094,7 @@ def get_schemas(
10931094
max_rows: int,
10941095
max_bytes: int,
10951096
cursor: Cursor,
1097+
statement_type: StatementType,
10961098
catalog_name=None,
10971099
schema_name=None,
10981100
) -> "ResultSet":
@@ -1112,11 +1114,9 @@ def get_schemas(
11121114
)
11131115
resp = self.make_request(self._client.GetSchemas, req)
11141116

1115-
(
1116-
execute_response,
1117-
is_direct_results,
1118-
statement_id,
1119-
) = self._handle_execute_response(resp, cursor)
1117+
execute_response, is_direct_results = self._handle_execute_response(
1118+
resp, cursor
1119+
)
11201120

11211121
t_row_set = None
11221122
if resp.directResults and resp.directResults.resultSet:
@@ -1134,7 +1134,7 @@ def get_schemas(
11341134
ssl_options=self._ssl_options,
11351135
is_direct_results=is_direct_results,
11361136
session_id_hex=self._session_id_hex,
1137-
statement_id=statement_id,
1137+
statement_type=statement_type,
11381138
)
11391139

11401140
def get_tables(
@@ -1143,6 +1143,7 @@ def get_tables(
11431143
max_rows: int,
11441144
max_bytes: int,
11451145
cursor: Cursor,
1146+
statement_type: StatementType,
11461147
catalog_name=None,
11471148
schema_name=None,
11481149
table_name=None,
@@ -1166,11 +1167,9 @@ def get_tables(
11661167
)
11671168
resp = self.make_request(self._client.GetTables, req)
11681169

1169-
(
1170-
execute_response,
1171-
is_direct_results,
1172-
statement_id,
1173-
) = self._handle_execute_response(resp, cursor)
1170+
execute_response, is_direct_results = self._handle_execute_response(
1171+
resp, cursor
1172+
)
11741173

11751174
t_row_set = None
11761175
if resp.directResults and resp.directResults.resultSet:
@@ -1188,7 +1187,7 @@ def get_tables(
11881187
ssl_options=self._ssl_options,
11891188
is_direct_results=is_direct_results,
11901189
session_id_hex=self._session_id_hex,
1191-
statement_id=statement_id,
1190+
statement_type=statement_type,
11921191
)
11931192

11941193
def get_columns(
@@ -1197,6 +1196,7 @@ def get_columns(
11971196
max_rows: int,
11981197
max_bytes: int,
11991198
cursor: Cursor,
1199+
statement_type: StatementType,
12001200
catalog_name=None,
12011201
schema_name=None,
12021202
table_name=None,
@@ -1220,11 +1220,9 @@ def get_columns(
12201220
)
12211221
resp = self.make_request(self._client.GetColumns, req)
12221222

1223-
(
1224-
execute_response,
1225-
is_direct_results,
1226-
statement_id,
1227-
) = self._handle_execute_response(resp, cursor)
1223+
execute_response, is_direct_results = self._handle_execute_response(
1224+
resp, cursor
1225+
)
12281226

12291227
t_row_set = None
12301228
if resp.directResults and resp.directResults.resultSet:
@@ -1242,7 +1240,7 @@ def get_columns(
12421240
ssl_options=self._ssl_options,
12431241
is_direct_results=is_direct_results,
12441242
session_id_hex=self._session_id_hex,
1245-
statement_id=statement_id,
1243+
statement_type=statement_type,
12461244
)
12471245

12481246
def _handle_execute_response(self, resp, cursor):
@@ -1258,15 +1256,7 @@ def _handle_execute_response(self, resp, cursor):
12581256
resp.directResults and resp.directResults.operationStatus,
12591257
)
12601258

1261-
execute_response, is_direct_results = self._results_message_to_execute_response(
1262-
resp, final_operation_state
1263-
)
1264-
1265-
return (
1266-
execute_response,
1267-
is_direct_results,
1268-
cursor.active_command_id.to_hex_guid(),
1269-
)
1259+
return self._results_message_to_execute_response(resp, final_operation_state)
12701260

12711261
def _handle_execute_response_async(self, resp, cursor):
12721262
command_id = CommandId.from_thrift_handle(resp.operationHandle)
@@ -1285,8 +1275,9 @@ def fetch_results(
12851275
lz4_compressed: bool,
12861276
arrow_schema_bytes,
12871277
description,
1278+
statement_type,
1279+
chunk_id: int,
12881280
use_cloud_fetch=True,
1289-
statement_id=None,
12901281
):
12911282
thrift_handle = command_id.to_thrift_handle()
12921283
if not thrift_handle:
@@ -1324,10 +1315,16 @@ def fetch_results(
13241315
description=description,
13251316
ssl_options=self._ssl_options,
13261317
session_id_hex=self._session_id_hex,
1327-
statement_id=statement_id,
1318+
statement_id=command_id.to_hex_guid(),
1319+
statement_type=statement_type,
1320+
chunk_id=chunk_id,
13281321
)
13291322

1330-
return queue, resp.hasMoreRows
1323+
return (
1324+
queue,
1325+
resp.hasMoreRows,
1326+
len(resp.results.resultLinks) if resp.results.resultLinks else 0,
1327+
)
13311328

13321329
def cancel_command(self, command_id: CommandId) -> None:
13331330
thrift_handle = command_id.to_thrift_handle()

0 commit comments

Comments
 (0)