55import time
66import uuid
77import threading
8- from typing import List , Optional , Union , Any , TYPE_CHECKING
8+ from typing import List , Union , Any , TYPE_CHECKING
99
1010if TYPE_CHECKING :
1111 from databricks .sql .client import Cursor
12- from databricks .sql .result_set import ResultSet , ThriftResultSet
1312
1413from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
1514from databricks .sql .backend .types import (
1615 CommandState ,
1716 SessionId ,
1817 CommandId ,
1918 BackendType ,
19+ guid_to_hex_id ,
20+ ExecuteResponse ,
2021)
21- from databricks .sql .backend .utils import guid_to_hex_id
2222
2323try :
2424 import pyarrow
4242)
4343
4444from databricks .sql .utils import (
45- ExecuteResponse ,
45+ ResultSetQueueFactory ,
4646 _bound ,
4747 RequestErrorInfo ,
4848 NoRetryReason ,
5353)
5454from databricks .sql .types import SSLOptions
5555from databricks .sql .backend .databricks_client import DatabricksClient
56+ from databricks .sql .result_set import ResultSet , ThriftResultSet
5657
5758logger = logging .getLogger (__name__ )
5859
@@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True):
351352 Will stop retry attempts if total elapsed time + next retry delay would exceed
352353 _retry_stop_after_attempts_duration.
353354 """
354-
355355 # basic strategy: build range iterator rep'ing number of available
356356 # retries. bounds can be computed from there. iterate over it with
357357 # retries until success or final failure achieved.
@@ -797,23 +797,24 @@ def _results_message_to_execute_response(self, resp, operation_state):
797797
798798 command_id = CommandId .from_thrift_handle (resp .operationHandle )
799799
800+ status = CommandState .from_thrift_state (operation_state )
801+ if status is None :
802+ raise ValueError (f"Invalid operation state: { operation_state } " )
803+
800804 return ExecuteResponse (
801- arrow_queue = arrow_queue_opt ,
802- status = CommandState . from_thrift_state ( operation_state ) ,
803- has_been_closed_server_side = has_been_closed_server_side ,
805+ command_id = command_id ,
806+ status = status ,
807+ description = description ,
804808 has_more_rows = has_more_rows ,
809+ results_queue = arrow_queue_opt ,
810+ has_been_closed_server_side = has_been_closed_server_side ,
805811 lz4_compressed = lz4_compressed ,
806812 is_staging_operation = is_staging_operation ,
807- command_id = command_id ,
808- description = description ,
809- arrow_schema_bytes = schema_bytes ,
810- )
813+ ), schema_bytes
811814
812815 def get_execution_result (
813816 self , command_id : CommandId , cursor : "Cursor"
814817 ) -> "ResultSet" :
815- from databricks .sql .result_set import ThriftResultSet
816-
817818 thrift_handle = command_id .to_thrift_handle ()
818819 if not thrift_handle :
819820 raise ValueError ("Not a valid Thrift command ID" )
@@ -863,15 +864,14 @@ def get_execution_result(
863864 )
864865
865866 execute_response = ExecuteResponse (
866- arrow_queue = queue ,
867- status = CommandState . from_thrift_state ( resp .status ) ,
868- has_been_closed_server_side = False ,
867+ command_id = command_id ,
868+ status = resp .status ,
869+ description = description ,
869870 has_more_rows = has_more_rows ,
871+ results_queue = queue ,
872+ has_been_closed_server_side = False ,
870873 lz4_compressed = lz4_compressed ,
871874 is_staging_operation = is_staging_operation ,
872- command_id = command_id ,
873- description = description ,
874- arrow_schema_bytes = schema_bytes ,
875875 )
876876
877877 return ThriftResultSet (
@@ -881,6 +881,7 @@ def get_execution_result(
881881 buffer_size_bytes = cursor .buffer_size_bytes ,
882882 arraysize = cursor .arraysize ,
883883 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
884+ arrow_schema_bytes = schema_bytes
884885 )
885886
886887 def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -909,10 +910,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
909910 poll_resp = self ._poll_for_status (thrift_handle )
910911 operation_state = poll_resp .operationState
911912 self ._check_command_not_in_error_or_closed_state (thrift_handle , poll_resp )
912- state = CommandState .from_thrift_state (operation_state )
913- if state is None :
914- raise ValueError (f"Unknown command state: { operation_state } " )
915- return state
913+ return CommandState .from_thrift_state (operation_state )
916914
917915 @staticmethod
918916 def _check_direct_results_for_error (t_spark_direct_results ):
@@ -947,8 +945,6 @@ def execute_command(
947945 async_op = False ,
948946 enforce_embedded_schema_correctness = False ,
949947 ) -> Union ["ResultSet" , None ]:
950- from databricks .sql .result_set import ThriftResultSet
951-
952948 thrift_handle = session_id .to_thrift_handle ()
953949 if not thrift_handle :
954950 raise ValueError ("Not a valid Thrift session ID" )
@@ -995,7 +991,7 @@ def execute_command(
995991 self ._handle_execute_response_async (resp , cursor )
996992 return None
997993 else :
998- execute_response = self ._handle_execute_response (resp , cursor )
994+ execute_response , arrow_schema_bytes = self ._handle_execute_response (resp , cursor )
999995
1000996 return ThriftResultSet (
1001997 connection = cursor .connection ,
@@ -1004,6 +1000,7 @@ def execute_command(
10041000 buffer_size_bytes = max_bytes ,
10051001 arraysize = max_rows ,
10061002 use_cloud_fetch = use_cloud_fetch ,
1003+ arrow_schema_bytes = arrow_schema_bytes
10071004 )
10081005
10091006 def get_catalogs (
@@ -1013,8 +1010,6 @@ def get_catalogs(
10131010 max_bytes : int ,
10141011 cursor : "Cursor" ,
10151012 ) -> "ResultSet" :
1016- from databricks .sql .result_set import ThriftResultSet
1017-
10181013 thrift_handle = session_id .to_thrift_handle ()
10191014 if not thrift_handle :
10201015 raise ValueError ("Not a valid Thrift session ID" )
@@ -1027,7 +1022,7 @@ def get_catalogs(
10271022 )
10281023 resp = self .make_request (self ._client .GetCatalogs , req )
10291024
1030- execute_response = self ._handle_execute_response (resp , cursor )
1025+ execute_response , arrow_schema_bytes = self ._handle_execute_response (resp , cursor )
10311026
10321027 return ThriftResultSet (
10331028 connection = cursor .connection ,
@@ -1036,6 +1031,7 @@ def get_catalogs(
10361031 buffer_size_bytes = max_bytes ,
10371032 arraysize = max_rows ,
10381033 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1034+ arrow_schema_bytes = arrow_schema_bytes
10391035 )
10401036
10411037 def get_schemas (
@@ -1047,8 +1043,6 @@ def get_schemas(
10471043 catalog_name = None ,
10481044 schema_name = None ,
10491045 ) -> "ResultSet" :
1050- from databricks .sql .result_set import ThriftResultSet
1051-
10521046 thrift_handle = session_id .to_thrift_handle ()
10531047 if not thrift_handle :
10541048 raise ValueError ("Not a valid Thrift session ID" )
@@ -1063,7 +1057,7 @@ def get_schemas(
10631057 )
10641058 resp = self .make_request (self ._client .GetSchemas , req )
10651059
1066- execute_response = self ._handle_execute_response (resp , cursor )
1060+ execute_response , arrow_schema_bytes = self ._handle_execute_response (resp , cursor )
10671061
10681062 return ThriftResultSet (
10691063 connection = cursor .connection ,
@@ -1072,6 +1066,7 @@ def get_schemas(
10721066 buffer_size_bytes = max_bytes ,
10731067 arraysize = max_rows ,
10741068 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1069+ arrow_schema_bytes = arrow_schema_bytes
10751070 )
10761071
10771072 def get_tables (
@@ -1085,8 +1080,6 @@ def get_tables(
10851080 table_name = None ,
10861081 table_types = None ,
10871082 ) -> "ResultSet" :
1088- from databricks .sql .result_set import ThriftResultSet
1089-
10901083 thrift_handle = session_id .to_thrift_handle ()
10911084 if not thrift_handle :
10921085 raise ValueError ("Not a valid Thrift session ID" )
@@ -1103,7 +1096,7 @@ def get_tables(
11031096 )
11041097 resp = self .make_request (self ._client .GetTables , req )
11051098
1106- execute_response = self ._handle_execute_response (resp , cursor )
1099+ execute_response , arrow_schema_bytes = self ._handle_execute_response (resp , cursor )
11071100
11081101 return ThriftResultSet (
11091102 connection = cursor .connection ,
@@ -1112,6 +1105,7 @@ def get_tables(
11121105 buffer_size_bytes = max_bytes ,
11131106 arraysize = max_rows ,
11141107 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1108+ arrow_schema_bytes = arrow_schema_bytes
11151109 )
11161110
11171111 def get_columns (
@@ -1125,8 +1119,6 @@ def get_columns(
11251119 table_name = None ,
11261120 column_name = None ,
11271121 ) -> "ResultSet" :
1128- from databricks .sql .result_set import ThriftResultSet
1129-
11301122 thrift_handle = session_id .to_thrift_handle ()
11311123 if not thrift_handle :
11321124 raise ValueError ("Not a valid Thrift session ID" )
@@ -1143,7 +1135,7 @@ def get_columns(
11431135 )
11441136 resp = self .make_request (self ._client .GetColumns , req )
11451137
1146- execute_response = self ._handle_execute_response (resp , cursor )
1138+ execute_response , arrow_schema_bytes = self ._handle_execute_response (resp , cursor )
11471139
11481140 return ThriftResultSet (
11491141 connection = cursor .connection ,
@@ -1152,6 +1144,7 @@ def get_columns(
11521144 buffer_size_bytes = max_bytes ,
11531145 arraysize = max_rows ,
11541146 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1147+ arrow_schema_bytes = arrow_schema_bytes
11551148 )
11561149
11571150 def _handle_execute_response (self , resp , cursor ):
@@ -1165,11 +1158,11 @@ def _handle_execute_response(self, resp, cursor):
11651158 resp .directResults and resp .directResults .operationStatus ,
11661159 )
11671160
1168- execute_response = self ._results_message_to_execute_response (
1161+ execute_response , arrow_schema_bytes = self ._results_message_to_execute_response (
11691162 resp , final_operation_state
11701163 )
1171- execute_response = execute_response . _replace ( command_id = command_id )
1172- return execute_response
1164+ execute_response . command_id = command_id
1165+ return execute_response , arrow_schema_bytes
11731166
11741167 def _handle_execute_response_async (self , resp , cursor ):
11751168 command_id = CommandId .from_thrift_handle (resp .operationHandle )
@@ -1230,7 +1223,7 @@ def cancel_command(self, command_id: CommandId) -> None:
12301223 if not thrift_handle :
12311224 raise ValueError ("Not a valid Thrift command ID" )
12321225
1233- logger .debug ("Cancelling command {}" .format (guid_to_hex_id ( command_id .guid ) ))
1226+ logger .debug ("Cancelling command {}" .format (command_id .guid ))
12341227 req = ttypes .TCancelOperationReq (thrift_handle )
12351228 self .make_request (self ._client .CancelOperation , req )
12361229
0 commit comments