55import grpc
66import pyarrow
77
8- import cmdexec . clients . python . api . messages_pb2 as command_pb2
9- from cmdexec . clients . python . api . sql_cmd_service_pb2_grpc import SqlCommandServiceStub
10- from cmdexec . clients . python . errors import OperationalError , InterfaceError , DatabaseError , Error
8+ from . errors import OperationalError , InterfaceError , DatabaseError , Error
9+ from . api import messages_pb2
10+ from . api . sql_cmd_service_pb2_grpc import SqlCommandServiceStub
1111
1212import time
1313
1616DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
1717
1818
19- def connect (** kwargs ):
20- return Connection (** kwargs )
21-
22-
2319class Connection :
2420 def __init__ (self , ** kwargs ):
2521 try :
@@ -29,7 +25,7 @@ def __init__(self, **kwargs):
2925 raise InterfaceError ("Please include arguments HOST and PORT in kwargs for Connection" )
3026
3127 self .base_client = CmdExecBaseHttpClient (self .host , self .port , kwargs .get ("metadata" , []))
32- open_session_request = command_pb2 .OpenSessionRequest (
28+ open_session_request = messages_pb2 .OpenSessionRequest (
3329 configuration = {},
3430 client_session_id = None ,
3531 session_info_fields = None ,
@@ -57,7 +53,7 @@ def cursor(self, buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
5753 return cursor
5854
5955 def close (self ):
60- close_session_request = command_pb2 .CloseSessionRequest (id = self .session_id )
56+ close_session_request = messages_pb2 .CloseSessionRequest (id = self .session_id )
6157 self .base_client .make_request (self .base_client .stub .CloseSession , close_session_request )
6258 self .open = False
6359
@@ -112,14 +108,14 @@ def _check_not_closed(self):
112108
113109 def _check_response_for_error (self , resp , command_id ):
114110 status = resp .status .state
115- if status == command_pb2 .ERROR :
111+ if status == messages_pb2 .ERROR :
116112 raise DatabaseError (
117113 "Command %s failed with error message %s" % (command_id , resp .status .error_message ))
118- elif status == command_pb2 .CLOSED :
114+ elif status == messages_pb2 .CLOSED :
119115 raise DatabaseError ("Command %s closed before results could be fetched" % command_id )
120116
121117 def _poll_for_state (self , command_id ):
122- get_status_request = command_pb2 .GetCommandStatusRequest (id = command_id )
118+ get_status_request = messages_pb2 .GetCommandStatusRequest (id = command_id )
123119
124120 resp = self .connection .base_client .make_request (
125121 self .connection .base_client .stub .GetCommandStatus , get_status_request )
@@ -129,7 +125,7 @@ def _poll_for_state(self, command_id):
129125
130126 def _wait_until_command_done (self , command_id , initial_status ):
131127 status = initial_status
132- while status in [command_pb2 .PENDING , command_pb2 .RUNNING ]:
128+ while status in [messages_pb2 .PENDING , messages_pb2 .RUNNING ]:
133129 resp = self ._poll_for_state (command_id )
134130 status = resp .status .state
135131 self ._check_response_for_error (resp , command_id )
@@ -143,7 +139,7 @@ def execute(self, operation, query_params=None, metadata=None):
143139 self ._close_and_clear_active_result_set ()
144140
145141 # Execute the command
146- execute_command_request = command_pb2 .ExecuteCommandRequest (
142+ execute_command_request = messages_pb2 .ExecuteCommandRequest (
147143 session_id = self .connection .session_id ,
148144 client_command_id = None ,
149145 command = operation ,
@@ -209,7 +205,7 @@ def __init__(self,
209205 self .buffer_size_bytes = result_buffer_size_bytes
210206 self ._row_index = 0
211207
212- assert (self .status not in [command_pb2 .PENDING , command_pb2 .RUNNING ])
208+ assert (self .status not in [messages_pb2 .PENDING , messages_pb2 .RUNNING ])
213209
214210 if arrow_ipc_stream :
215211 # In the case we are passed in an initial result set, the server has taken the
@@ -229,9 +225,9 @@ def __iter__(self):
229225 break
230226
231227 def _fetch_and_deserialize_results (self ):
232- fetch_results_request = command_pb2 .FetchCommandResultsRequest (
228+ fetch_results_request = messages_pb2 .FetchCommandResultsRequest (
233229 id = self .command_id ,
234- options = command_pb2 .CommandResultOptions (
230+ options = messages_pb2 .CommandResultOptions (
235231 max_bytes = self .buffer_size_bytes ,
236232 row_offset = self ._row_index ,
237233 include_metadata = True ,
@@ -246,9 +242,9 @@ def _fetch_and_deserialize_results(self):
246242 return results , result_message .has_more_rows
247243
248244 def _fill_results_buffer (self ):
249- if self .status == command_pb2 .CLOSED :
245+ if self .status == messages_pb2 .CLOSED :
250246 raise Error ("Can't fetch results on closed command %s" % self .command_id )
251- elif self .status == command_pb2 .ERROR :
247+ elif self .status == messages_pb2 .ERROR :
252248 raise DatabaseError ("Command %s failed" % self .command_id )
253249 else :
254250 results , has_more_rows = self ._fetch_and_deserialize_results ()
@@ -329,14 +325,14 @@ def close(self):
329325 been closed on the server for some other reason, issue a request to the server to close it.
330326 """
331327 try :
332- if self .status != command_pb2 .CLOSED and not self .has_been_closed_server_side \
328+ if self .status != messages_pb2 .CLOSED and not self .has_been_closed_server_side \
333329 and self .connection .open :
334- close_command_request = command_pb2 .CloseCommandRequest (id = self .command_id )
330+ close_command_request = messages_pb2 .CloseCommandRequest (id = self .command_id )
335331 self .connection .base_client .make_request (
336332 self .connection .base_client .stub .CloseCommand , close_command_request )
337333 finally :
338334 self .has_been_closed_server_side = True
339- self .status = command_pb2 .CLOSED
335+ self .status = messages_pb2 .CLOSED
340336
341337
342338class ArrowQueue :
0 commit comments