Skip to content

Commit aba27ab

Browse files
account for new exec resp
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 40f15a7 commit aba27ab

File tree

7 files changed

+96
-486
lines changed

7 files changed

+96
-486
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
)
1919

2020
if TYPE_CHECKING:
21-
from databricks.sql.result_set import ResultSet, SeaResultSet
21+
from databricks.sql.result_set import ResultSet
22+
23+
from databricks.sql.result_set import SeaResultSet
2224

2325
logger = logging.getLogger(__name__)
2426

src/databricks/sql/backend/thrift_backend.py

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,20 @@
55
import time
66
import uuid
77
import threading
8-
from typing import List, Optional, Union, Any, TYPE_CHECKING
8+
from typing import List, Union, Any, TYPE_CHECKING
99

1010
if TYPE_CHECKING:
1111
from databricks.sql.client import Cursor
12-
from databricks.sql.result_set import ResultSet, ThriftResultSet
1312

1413
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1514
from databricks.sql.backend.types import (
1615
CommandState,
1716
SessionId,
1817
CommandId,
1918
BackendType,
19+
guid_to_hex_id,
20+
ExecuteResponse,
2021
)
21-
from databricks.sql.backend.utils import guid_to_hex_id
2222

2323
try:
2424
import pyarrow
@@ -42,7 +42,7 @@
4242
)
4343

4444
from databricks.sql.utils import (
45-
ExecuteResponse,
45+
ResultSetQueueFactory,
4646
_bound,
4747
RequestErrorInfo,
4848
NoRetryReason,
@@ -53,6 +53,7 @@
5353
)
5454
from databricks.sql.types import SSLOptions
5555
from databricks.sql.backend.databricks_client import DatabricksClient
56+
from databricks.sql.result_set import ResultSet, ThriftResultSet
5657

5758
logger = logging.getLogger(__name__)
5859

@@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True):
351352
Will stop retry attempts if total elapsed time + next retry delay would exceed
352353
_retry_stop_after_attempts_duration.
353354
"""
354-
355355
# basic strategy: build range iterator rep'ing number of available
356356
# retries. bounds can be computed from there. iterate over it with
357357
# retries until success or final failure achieved.
@@ -797,23 +797,24 @@ def _results_message_to_execute_response(self, resp, operation_state):
797797

798798
command_id = CommandId.from_thrift_handle(resp.operationHandle)
799799

800+
status = CommandState.from_thrift_state(operation_state)
801+
if status is None:
802+
raise ValueError(f"Invalid operation state: {operation_state}")
803+
800804
return ExecuteResponse(
801-
arrow_queue=arrow_queue_opt,
802-
status=CommandState.from_thrift_state(operation_state),
803-
has_been_closed_server_side=has_been_closed_server_side,
805+
command_id=command_id,
806+
status=status,
807+
description=description,
804808
has_more_rows=has_more_rows,
809+
results_queue=arrow_queue_opt,
810+
has_been_closed_server_side=has_been_closed_server_side,
805811
lz4_compressed=lz4_compressed,
806812
is_staging_operation=is_staging_operation,
807-
command_id=command_id,
808-
description=description,
809-
arrow_schema_bytes=schema_bytes,
810-
)
813+
), schema_bytes
811814

812815
def get_execution_result(
813816
self, command_id: CommandId, cursor: "Cursor"
814817
) -> "ResultSet":
815-
from databricks.sql.result_set import ThriftResultSet
816-
817818
thrift_handle = command_id.to_thrift_handle()
818819
if not thrift_handle:
819820
raise ValueError("Not a valid Thrift command ID")
@@ -863,15 +864,14 @@ def get_execution_result(
863864
)
864865

865866
execute_response = ExecuteResponse(
866-
arrow_queue=queue,
867-
status=CommandState.from_thrift_state(resp.status),
868-
has_been_closed_server_side=False,
867+
command_id=command_id,
868+
status=resp.status,
869+
description=description,
869870
has_more_rows=has_more_rows,
871+
results_queue=queue,
872+
has_been_closed_server_side=False,
870873
lz4_compressed=lz4_compressed,
871874
is_staging_operation=is_staging_operation,
872-
command_id=command_id,
873-
description=description,
874-
arrow_schema_bytes=schema_bytes,
875875
)
876876

877877
return ThriftResultSet(
@@ -881,6 +881,7 @@ def get_execution_result(
881881
buffer_size_bytes=cursor.buffer_size_bytes,
882882
arraysize=cursor.arraysize,
883883
use_cloud_fetch=cursor.connection.use_cloud_fetch,
884+
arrow_schema_bytes=schema_bytes
884885
)
885886

886887
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -909,10 +910,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
909910
poll_resp = self._poll_for_status(thrift_handle)
910911
operation_state = poll_resp.operationState
911912
self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp)
912-
state = CommandState.from_thrift_state(operation_state)
913-
if state is None:
914-
raise ValueError(f"Unknown command state: {operation_state}")
915-
return state
913+
return CommandState.from_thrift_state(operation_state)
916914

917915
@staticmethod
918916
def _check_direct_results_for_error(t_spark_direct_results):
@@ -947,8 +945,6 @@ def execute_command(
947945
async_op=False,
948946
enforce_embedded_schema_correctness=False,
949947
) -> Union["ResultSet", None]:
950-
from databricks.sql.result_set import ThriftResultSet
951-
952948
thrift_handle = session_id.to_thrift_handle()
953949
if not thrift_handle:
954950
raise ValueError("Not a valid Thrift session ID")
@@ -995,7 +991,7 @@ def execute_command(
995991
self._handle_execute_response_async(resp, cursor)
996992
return None
997993
else:
998-
execute_response = self._handle_execute_response(resp, cursor)
994+
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
999995

1000996
return ThriftResultSet(
1001997
connection=cursor.connection,
@@ -1004,6 +1000,7 @@ def execute_command(
10041000
buffer_size_bytes=max_bytes,
10051001
arraysize=max_rows,
10061002
use_cloud_fetch=use_cloud_fetch,
1003+
arrow_schema_bytes=arrow_schema_bytes
10071004
)
10081005

10091006
def get_catalogs(
@@ -1013,8 +1010,6 @@ def get_catalogs(
10131010
max_bytes: int,
10141011
cursor: "Cursor",
10151012
) -> "ResultSet":
1016-
from databricks.sql.result_set import ThriftResultSet
1017-
10181013
thrift_handle = session_id.to_thrift_handle()
10191014
if not thrift_handle:
10201015
raise ValueError("Not a valid Thrift session ID")
@@ -1027,7 +1022,7 @@ def get_catalogs(
10271022
)
10281023
resp = self.make_request(self._client.GetCatalogs, req)
10291024

1030-
execute_response = self._handle_execute_response(resp, cursor)
1025+
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
10311026

10321027
return ThriftResultSet(
10331028
connection=cursor.connection,
@@ -1036,6 +1031,7 @@ def get_catalogs(
10361031
buffer_size_bytes=max_bytes,
10371032
arraysize=max_rows,
10381033
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1034+
arrow_schema_bytes=arrow_schema_bytes
10391035
)
10401036

10411037
def get_schemas(
@@ -1047,8 +1043,6 @@ def get_schemas(
10471043
catalog_name=None,
10481044
schema_name=None,
10491045
) -> "ResultSet":
1050-
from databricks.sql.result_set import ThriftResultSet
1051-
10521046
thrift_handle = session_id.to_thrift_handle()
10531047
if not thrift_handle:
10541048
raise ValueError("Not a valid Thrift session ID")
@@ -1063,7 +1057,7 @@ def get_schemas(
10631057
)
10641058
resp = self.make_request(self._client.GetSchemas, req)
10651059

1066-
execute_response = self._handle_execute_response(resp, cursor)
1060+
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
10671061

10681062
return ThriftResultSet(
10691063
connection=cursor.connection,
@@ -1072,6 +1066,7 @@ def get_schemas(
10721066
buffer_size_bytes=max_bytes,
10731067
arraysize=max_rows,
10741068
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1069+
arrow_schema_bytes=arrow_schema_bytes
10751070
)
10761071

10771072
def get_tables(
@@ -1085,8 +1080,6 @@ def get_tables(
10851080
table_name=None,
10861081
table_types=None,
10871082
) -> "ResultSet":
1088-
from databricks.sql.result_set import ThriftResultSet
1089-
10901083
thrift_handle = session_id.to_thrift_handle()
10911084
if not thrift_handle:
10921085
raise ValueError("Not a valid Thrift session ID")
@@ -1103,7 +1096,7 @@ def get_tables(
11031096
)
11041097
resp = self.make_request(self._client.GetTables, req)
11051098

1106-
execute_response = self._handle_execute_response(resp, cursor)
1099+
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
11071100

11081101
return ThriftResultSet(
11091102
connection=cursor.connection,
@@ -1112,6 +1105,7 @@ def get_tables(
11121105
buffer_size_bytes=max_bytes,
11131106
arraysize=max_rows,
11141107
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1108+
arrow_schema_bytes=arrow_schema_bytes
11151109
)
11161110

11171111
def get_columns(
@@ -1125,8 +1119,6 @@ def get_columns(
11251119
table_name=None,
11261120
column_name=None,
11271121
) -> "ResultSet":
1128-
from databricks.sql.result_set import ThriftResultSet
1129-
11301122
thrift_handle = session_id.to_thrift_handle()
11311123
if not thrift_handle:
11321124
raise ValueError("Not a valid Thrift session ID")
@@ -1143,7 +1135,7 @@ def get_columns(
11431135
)
11441136
resp = self.make_request(self._client.GetColumns, req)
11451137

1146-
execute_response = self._handle_execute_response(resp, cursor)
1138+
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
11471139

11481140
return ThriftResultSet(
11491141
connection=cursor.connection,
@@ -1152,6 +1144,7 @@ def get_columns(
11521144
buffer_size_bytes=max_bytes,
11531145
arraysize=max_rows,
11541146
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1147+
arrow_schema_bytes=arrow_schema_bytes
11551148
)
11561149

11571150
def _handle_execute_response(self, resp, cursor):
@@ -1165,11 +1158,11 @@ def _handle_execute_response(self, resp, cursor):
11651158
resp.directResults and resp.directResults.operationStatus,
11661159
)
11671160

1168-
execute_response = self._results_message_to_execute_response(
1161+
execute_response, arrow_schema_bytes = self._results_message_to_execute_response(
11691162
resp, final_operation_state
11701163
)
1171-
execute_response = execute_response._replace(command_id=command_id)
1172-
return execute_response
1164+
execute_response.command_id = command_id
1165+
return execute_response, arrow_schema_bytes
11731166

11741167
def _handle_execute_response_async(self, resp, cursor):
11751168
command_id = CommandId.from_thrift_handle(resp.operationHandle)
@@ -1230,7 +1223,7 @@ def cancel_command(self, command_id: CommandId) -> None:
12301223
if not thrift_handle:
12311224
raise ValueError("Not a valid Thrift command ID")
12321225

1233-
logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid)))
1226+
logger.debug("Cancelling command {}".format(command_id.guid))
12341227
req = ttypes.TCancelOperationReq(thrift_handle)
12351228
self.make_request(self._client.CancelOperation, req)
12361229

tests/unit/test_fetches.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,13 @@ def make_dummy_result_set_from_initial_results(initial_results):
4343
rs = ThriftResultSet(
4444
connection=Mock(),
4545
execute_response=ExecuteResponse(
46+
command_id=None,
4647
status=None,
4748
has_been_closed_server_side=True,
4849
has_more_rows=False,
4950
description=Mock(),
5051
lz4_compressed=Mock(),
51-
command_id=None,
5252
results_queue=arrow_queue,
53-
arrow_schema_bytes=schema.serialize().to_pybytes(),
5453
is_staging_operation=False,
5554
),
5655
thrift_client=None,
@@ -89,6 +88,7 @@ def fetch_results(
8988
rs = ThriftResultSet(
9089
connection=Mock(),
9190
execute_response=ExecuteResponse(
91+
command_id=None,
9292
status=None,
9393
has_been_closed_server_side=False,
9494
has_more_rows=True,
@@ -97,9 +97,7 @@ def fetch_results(
9797
for col_id in range(num_cols)
9898
],
9999
lz4_compressed=Mock(),
100-
command_id=None,
101100
results_queue=None,
102-
arrow_schema_bytes=None,
103101
is_staging_operation=False,
104102
),
105103
thrift_client=mock_thrift_backend,

0 commit comments

Comments
 (0)