22from typing import Dict , Tuple , List
33
44import grpc
5- from google . protobuf import message
5+ import pyarrow
66
77import cmdexec .clients .python .sql_command_service_pb2 as command_pb2
88from cmdexec .clients .python .sql_command_service_pb2_grpc import SqlCommandServiceStub
9- from cmdexec .clients .python .errors import OperationalError , InterfaceError
9+ from cmdexec .clients .python .errors import OperationalError , InterfaceError , DatabaseError , Error
10+
11+ import time
1012
1113logger = logging .getLogger (__name__ )
1214
@@ -23,31 +25,199 @@ def __init__(self, **kwargs):
2325 except KeyError :
2426 raise InterfaceError ("Please include arguments HOST and PORT in kwargs for Connection" )
2527
26- self ._base_client = CmdExecBaseHttpClient (self .host , self .port , kwargs .get ("metadata" , []))
28+ self .base_client = CmdExecBaseHttpClient (self .host , self .port , kwargs .get ("metadata" , []))
2729 open_session_request = command_pb2 .OpenSessionRequest (
2830 configuration = {},
2931 client_session_id = None ,
3032 session_info_fields = None ,
3133 )
3234
33- try :
34- resp = self ._base_client .make_request (self ._base_client .stub .OpenSession ,
35- open_session_request )
36- self .session_id = resp .id
37- logger .info ("Successfully opened session " + str (self .session_id .hex ()))
38- except grpc .RpcError as error :
39- raise OperationalError ("Error during database connection" , error )
35+ resp = self .base_client .make_request (self .base_client .stub .OpenSession ,
36+ open_session_request )
37+ self .session_id = resp .id
38+ self .open = True
39+ logger .info ("Successfully opened session " + str (self .session_id .hex ()))
40+ self ._cursors = []
4041
4142 def cursor (self ):
42- pass
43+ if not self .open :
44+ raise Error ("Cannot create cursor from closed connection" )
45+ cursor = Cursor (self )
46+ self ._cursors .append (cursor )
47+ return cursor
4348
4449 def close (self ):
50+ close_session_request = command_pb2 .CloseSessionRequest (id = self .session_id )
51+ self .base_client .make_request (self .base_client .stub .CloseSession , close_session_request )
52+ self .open = False
53+
54+ for cursor in self ._cursors :
55+ cursor .close ()
56+
57+
58+ class Cursor :
59+ def __init__ (self , connection ):
60+ self .connection = connection
61+ self .description = None
62+ self .rowcount = - 1
63+ self .arraysize = 1
64+
65+ self .active_result_set = None
66+ # Note that Cursor closed => active result set closed, but not vice versa
67+ self .open = True
68+
69+ def _response_to_result_set (self , execute_command_response , status ):
70+ command_id = execute_command_response .command_id
71+ arrow_results = execute_command_response .results .arrow_ipc_stream
72+ has_been_closed_server_side = execute_command_response .closed
73+
74+ return ResultSet (self .connection , command_id , status , has_been_closed_server_side ,
75+ arrow_results )
76+
77+ def _close_and_clear_active_result_set (self ):
4578 try :
46- close_session_request = command_pb2 .CloseSessionRequest (id = self .session_id )
47- self ._base_client .make_request (self ._base_client .stub .CloseSession ,
48- close_session_request )
49- except grpc .RpcError as error :
50- raise OperationalError ("Error during database connection close" , error )
79+ if self .active_result_set :
80+ self .active_result_set .close ()
81+ finally :
82+ self .active_result_set = None
83+
84+ def _check_not_closed (self ):
85+ if not self .open :
86+ raise Error ("Attempting operation on closed cursor" )
87+
88+ def _check_response_for_error (self , resp , command_id ):
89+ status = resp .status .state
90+ if status == command_pb2 .ERROR :
91+ raise DatabaseError (
92+ "Command %s failed with error message %s" % (command_id , resp .status .error_message ))
93+ elif status == command_pb2 .CLOSED :
94+ raise DatabaseError ("Command %s closed before results could be fetched" % command_id )
95+
96+ def _poll_for_state (self , command_id ):
97+ get_status_request = command_pb2 .GetCommandStatusRequest (id = command_id )
98+
99+ resp = self .connection .base_client .make_request (
100+ self .connection .base_client .stub .GetCommandStatus , get_status_request )
101+
102+ logger .info ("Status for command %s is: %s" % (command_id , resp .status ))
103+ return resp
104+
105+ def _wait_until_command_done (self , command_id , initial_status ):
106+ status = initial_status
107+ print ("initial status: %s" % status )
108+ while status in [command_pb2 .PENDING , command_pb2 .RUNNING ]:
109+ resp = self ._poll_for_state (command_id )
110+ status = resp .status .state
111+ self ._check_response_for_error (resp , command_id )
112+ print ("status is: %s" % status )
113+
114+ # TODO: Remove this sleep once we have long-polling on the server (SC-77653)
115+ time .sleep (1 )
116+ return status
117+
118+ def execute (self , operation , query_params = None , metadata = None ):
119+ self ._check_not_closed ()
120+ self ._close_and_clear_active_result_set ()
121+
122+ # Execute the command
123+ execute_command_request = command_pb2 .ExecuteCommandRequest (
124+ session_id = self .connection .session_id ,
125+ client_command_id = None ,
126+ command = operation ,
127+ conf_overlay = None ,
128+ row_limit = None ,
129+ result_options = None ,
130+ )
131+
132+ execute_command_response = self .connection .base_client .make_request (
133+ self .connection .base_client .stub .ExecuteCommand , execute_command_request )
134+ initial_status = execute_command_response .status .state
135+ command_id = execute_command_response .command_id
136+
137+ self ._check_response_for_error (execute_command_response , command_id )
138+ final_status = self ._wait_until_command_done (command_id , initial_status )
139+ self .active_result_set = self ._response_to_result_set (execute_command_response ,
140+ final_status )
141+
142+ return self
143+
144+ def fetchall (self ):
145+ self ._check_not_closed ()
146+ if self .active_result_set :
147+ return self .active_result_set .fetchall ()
148+ else :
149+ raise Error ("There is no active result set" )
150+
151+ def close (self ):
152+ self .open = False
153+ if self .active_result_set :
154+ self ._close_and_clear_active_result_set ()
155+
156+
157+ class ResultSet :
158+ def __init__ (self ,
159+ connection ,
160+ command_id ,
161+ status ,
162+ has_been_closed_server_side ,
163+ arrow_ipc_stream = None ):
164+ self .connection = connection
165+ self .command_id = command_id
166+ self .status = status
167+ self .has_been_closed_server_side = has_been_closed_server_side
168+
169+ assert (self .status not in [command_pb2 .PENDING , command_pb2 .RUNNING ])
170+
171+ if arrow_ipc_stream :
172+ self .results = self ._deserialize_arrow_ipc_stream (arrow_ipc_stream )
173+ else :
174+ self .results = None
175+
176+ def _deserialize_arrow_ipc_stream (self , ipc_stream ):
177+ # TODO: Proper results deserialization, taking into account the logical schema (SC-77871)
178+ reader = pyarrow .ipc .open_stream (ipc_stream )
179+ return reader .read_all ().to_pandas ().values .tolist ()
180+
181+ def _fetch_results (self ):
182+ # TODO: Offsets, number of rows (SC-77872)
183+
184+ fetch_results_request = command_pb2 .FetchCommandResultsRequest (
185+ id = self .command_id ,
186+ options = command_pb2 .CommandResultOptions (
187+ max_rows = 1 ,
188+ include_metadata = True ,
189+ ))
190+
191+ return self .connection .base_client .make_request (
192+ self .connection .base_client .stub .FetchCommandResults , fetch_results_request ).results
193+
194+ def fetchall (self ):
195+ # TODO: Check that these results are in the right place (SC-77872)
196+ if self .status == command_pb2 .SUCCESS :
197+ return self .results
198+ elif self .status in [command_pb2 .PENDING , command_pb2 .RUNNING ]:
199+ # TODO: Pre-fetch results (SC-77868)
200+ result_message = self ._fetch_results ()
201+ self .results = self ._deserialize_arrow_ipc_stream (result_message .arrow_ipc_stream )
202+ return self .results
203+ elif self .status == command_pb2 .CLOSED :
204+ raise Error ("Can't fetch results on closed command %s" % self .command_id )
205+ elif self .status == command_pb2 .ERROR :
206+ raise DatabaseError ("Command %s failed" % self .command_id )
207+ else :
208+ raise Error (
209+ "Command %s is in an unrecognised state: %s" % (self .command_id , self .status ))
210+
211+ def close (self ):
212+ try :
213+ if self .status != command_pb2 .CLOSED and not self .has_been_closed_server_side \
214+ and not self .connection .closed :
215+ close_command_request = command_pb2 .CloseCommandRequest (id = self .command_id )
216+ self .connection .base_client .make_request (
217+ self .connection .base_client .stub .CloseCommand , close_command_request )
218+ finally :
219+ self .has_been_closed_server_side = True
220+ self .status = command_pb2 .CLOSED
51221
52222
53223class CmdExecBaseHttpClient :
@@ -68,4 +238,4 @@ def make_request(self, method, request):
68238 return response
69239 except grpc .RpcError as error :
70240 logger .error ("Received error during gRPC request: %s" , error )
71- raise error
241+ raise OperationalError ( "Error during gRPC request" , error )
0 commit comments