66from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
77
88from databricks .sql .backend .sea .models .base import ResultManifest
9+ from databricks .sql .backend .sea .models .responses import GetStatementResponse
910from databricks .sql .backend .sea .utils .constants import (
1011 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
1112 ResultFormat ,
@@ -323,7 +324,7 @@ def _extract_description_from_manifest(
323324 return columns
324325
325326 def _results_message_to_execute_response (
326- self , response : ExecuteStatementResponse
327+ self , response : Union [ ExecuteStatementResponse , GetStatementResponse ]
327328 ) -> ExecuteResponse :
328329 """
329330 Convert a SEA response to an ExecuteResponse and extract result data.
@@ -358,7 +359,9 @@ def _results_message_to_execute_response(
358359 return execute_response
359360
360361 def _response_to_result_set (
361- self , response : ExecuteStatementResponse , cursor : Cursor
362+ self ,
363+ response : Union [ExecuteStatementResponse , GetStatementResponse ],
364+ cursor : Cursor ,
362365 ) -> SeaResultSet :
363366 """
364367 Convert a SEA response to a SeaResultSet.
@@ -399,22 +402,24 @@ def _check_command_not_in_failed_or_closed_state(
399402
400403 def _wait_until_command_done (
401404 self , response : ExecuteStatementResponse
402- ) -> ExecuteStatementResponse :
405+ ) -> Union [ ExecuteStatementResponse , GetStatementResponse ] :
403406 """
404407 Wait until a command is done.
405408 """
406409
407- state = response .status .state
408- command_id = CommandId .from_sea_statement_id (response .statement_id )
410+ final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
411+
412+ state = final_response .status .state
413+ command_id = CommandId .from_sea_statement_id (final_response .statement_id )
409414
410415 while state in [CommandState .PENDING , CommandState .RUNNING ]:
411416 time .sleep (self .POLL_INTERVAL_SECONDS )
412- response = self ._poll_query (command_id )
413- state = response .status .state
417+ final_response = self ._poll_query (command_id )
418+ state = final_response .status .state
414419
415420 self ._check_command_not_in_failed_or_closed_state (state , command_id )
416421
417- return response
422+ return final_response
418423
419424 def execute_command (
420425 self ,
@@ -516,12 +521,11 @@ def execute_command(
516521 if async_op :
517522 return None
518523
519- if response . status . state == CommandState . SUCCEEDED :
520- # if the response succeeded within the wait_timeout, return the results immediately
521- return self ._response_to_result_set (response , cursor )
524+ final_response : Union [ ExecuteStatementResponse , GetStatementResponse ] = response
525+ if response . status . state != CommandState . SUCCEEDED :
526+ final_response = self ._wait_until_command_done (response )
522527
523- response = self ._wait_until_command_done (response )
524- return self ._response_to_result_set (response , cursor )
528+ return self ._response_to_result_set (final_response , cursor )
525529
526530 def cancel_command (self , command_id : CommandId ) -> None :
527531 """
@@ -573,7 +577,7 @@ def close_command(self, command_id: CommandId) -> None:
573577 data = request .to_dict (),
574578 )
575579
576- def _poll_query (self , command_id : CommandId ) -> ExecuteStatementResponse :
580+ def _poll_query (self , command_id : CommandId ) -> GetStatementResponse :
577581 """
578582 Poll for the current command info.
579583 """
@@ -591,7 +595,7 @@ def _poll_query(self, command_id: CommandId) -> ExecuteStatementResponse:
591595 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
592596 data = request .to_dict (),
593597 )
594- response = ExecuteStatementResponse .from_dict (response_data )
598+ response = GetStatementResponse .from_dict (response_data )
595599
596600 return response
597601
0 commit comments