 
 logger = logging.getLogger(__name__)
 
-DEFAULT_BUFFER_SIZE_ROWS = 1000
+DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
 
 
 def connect(**kwargs):
@@ -48,10 +48,11 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         self.close()
 
-    def cursor(self, buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
+    def cursor(self, buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
         if not self.open:
             raise Error("Cannot create cursor from closed connection")
-        cursor = Cursor(self, buffer_size_rows)
+
+        cursor = Cursor(self, buffer_size_bytes)
         self._cursors.append(cursor)
         return cursor
 
@@ -65,12 +66,12 @@ def close(self):
 
 
 class Cursor:
-    def __init__(self, connection, buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
+    def __init__(self, connection, result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
         self.connection = connection
         self.description = None
         self.rowcount = -1
         self.arraysize = 1
-        self.buffer_size_rows = buffer_size_rows
+        self.buffer_size_bytes = result_buffer_size_bytes
         self.active_result_set = None
         # Note that Cursor closed => active result set closed, but not vice versa
         self.open = True
@@ -92,10 +93,11 @@ def _response_to_result_set(self, execute_command_response, status):
         command_id = execute_command_response.command_id
         arrow_results = execute_command_response.results.arrow_ipc_stream
         has_been_closed_server_side = execute_command_response.closed
+        has_more_rows = execute_command_response.results.has_more_rows
         num_valid_rows = execute_command_response.results.num_valid_rows
 
         return ResultSet(self.connection, command_id, status, has_been_closed_server_side,
-                         arrow_results, num_valid_rows, self.buffer_size_rows)
+                         has_more_rows, arrow_results, num_valid_rows, self.buffer_size_bytes)
 
     def _close_and_clear_active_result_set(self):
         try:
@@ -127,15 +129,13 @@ def _poll_for_state(self, command_id):
 
     def _wait_until_command_done(self, command_id, initial_status):
         status = initial_status
-        print("initial status: %s" % status)
         while status in [command_pb2.PENDING, command_pb2.RUNNING]:
             resp = self._poll_for_state(command_id)
             status = resp.status.state
             self._check_response_for_error(resp, command_id)
-            print("status is: %s" % status)
 
             # TODO: Remove this sleep once we have long-polling on the server (SC-77653)
-            time.sleep(1)
+            time.sleep(0.1)
         return status
 
     def execute(self, operation, query_params=None, metadata=None):
@@ -197,23 +197,25 @@ def __init__(self,
                  command_id,
                  status,
                  has_been_closed_server_side,
+                 has_more_rows,
                  arrow_ipc_stream=None,
                  num_valid_rows=None,
-                 buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
+                 result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
         self.connection = connection
         self.command_id = command_id
         self.status = status
         self.has_been_closed_server_side = has_been_closed_server_side
-        self.buffer_size_rows = buffer_size_rows
+        self.has_more_rows = has_more_rows
+        self.buffer_size_bytes = result_buffer_size_bytes
+        self._row_index = 0
 
         assert (self.status not in [command_pb2.PENDING, command_pb2.RUNNING])
 
         if arrow_ipc_stream:
             # In the case we are passed in an initial result set, the server has taken the
             # fast path and has no more rows to send
             self.results = ArrowQueue(
-                pyarrow.ipc.open_stream(arrow_ipc_stream).read_all(), num_valid_rows)
-            self.has_more_rows = False
+                pyarrow.ipc.open_stream(arrow_ipc_stream).read_all(), num_valid_rows, 0)
         else:
             # In this case, there are results waiting on the server so we fetch now for simplicity
             self._fill_results_buffer()
@@ -230,15 +232,17 @@ def _fetch_and_deserialize_results(self):
         fetch_results_request = command_pb2.FetchCommandResultsRequest(
             id=self.command_id,
             options=command_pb2.CommandResultOptions(
-                max_rows=self.buffer_size_rows,
+                max_bytes=self.buffer_size_bytes,
+                row_offset=self._row_index,
                 include_metadata=True,
             ))
 
         result_message = self.connection.base_client.make_request(
             self.connection.base_client.stub.FetchCommandResults, fetch_results_request).results
         num_valid_rows = result_message.num_valid_rows
         arrow_table = pyarrow.ipc.open_stream(result_message.arrow_ipc_stream).read_all()
-        results = ArrowQueue(arrow_table, num_valid_rows)
+        results = ArrowQueue(arrow_table, num_valid_rows,
+                             self._row_index - result_message.start_row_offset)
         return results, result_message.has_more_rows
 
     def _fill_results_buffer(self):
@@ -268,23 +272,29 @@ def fetchmany_arrow(self, n_rows):
             raise ValueError("n_rows argument for fetchmany is %s but must be >= 0", n_rows)
         results = self.results.next_n_rows(n_rows)
         n_remaining_rows = n_rows - results.num_rows
+        self._row_index += results.num_rows
 
         while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
             results = pyarrow.concat_tables([results, partial_results])
             n_remaining_rows -= partial_results.num_rows
+            self._row_index += partial_results.num_rows
 
         return results
 
     def fetchall_arrow(self):
         """
         Fetch all (remaining) rows of a query result, returning them as a PyArrow table.
         """
-        results = self.fetchmany_arrow(self.buffer_size_rows)
-        while self.has_more_rows:
-            # TODO: What's the optimal sequence of sizes to fetch?
-            results = pyarrow.concat_tables([results, self.fetchmany_arrow(self.buffer_size_rows)])
+        results = self.results.remaining_rows()
+        self._row_index += results.num_rows
+
+        while not self.has_been_closed_server_side and self.has_more_rows:
+            self._fill_results_buffer()
+            partial_results = self.results.remaining_rows()
+            results = pyarrow.concat_tables([results, partial_results])
+            self._row_index += partial_results.num_rows
 
         return results
 
@@ -293,6 +303,7 @@ def fetchone(self):
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
+        self._row_index += 1
         res = self._convert_arrow_table(self.fetchmany_arrow(1))
         if len(res) > 0:
             return res[0]
@@ -329,8 +340,8 @@ def close(self):
 
 
 class ArrowQueue:
-    def __init__(self, arrow_table, n_valid_rows):
-        self.cur_row_index = 0
+    def __init__(self, arrow_table, n_valid_rows, start_row_index):
+        self.cur_row_index = start_row_index
         self.arrow_table = arrow_table
         self.n_valid_rows = n_valid_rows
 
@@ -343,6 +354,11 @@ def next_n_rows(self, num_rows):
         self.cur_row_index += slice.num_rows
         return slice
 
+    def remaining_rows(self):
+        slice = self.arrow_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index)
+        self.cur_row_index += slice.num_rows
+        return slice
+
 
 class CmdExecBaseHttpClient:
     """
@@ -352,7 +368,8 @@ class CmdExecBaseHttpClient:
     def __init__(self, host: str, port: int, http_headers: List[Tuple[str, str]]):
         self.host_url = host + ":" + str(port)
         self.http_headers = [(k.lower(), v) for (k, v) in http_headers]
-        self.channel = grpc.insecure_channel(self.host_url)
+        self.channel = grpc.insecure_channel(
+            self.host_url, options=[('grpc.max_receive_message_length', -1)])
         self.stub = SqlCommandServiceStub(self.channel)
 
     def make_request(self, method, request):
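Note (not part of the diff): the central idea in this change is that result fetches are now addressed by absolute row position. The cursor-level _row_index counts how many rows the client has consumed overall, the server reports where its latest batch starts (start_row_offset), and the difference becomes the ArrowQueue's local starting index. The standalone sketch below illustrates that mapping; the body of next_n_rows is a plausible reconstruction (the hunk above only shows its tail), and the sample numbers are invented.

import pyarrow

class ArrowQueue:
    # Mirrors the ArrowQueue in this diff; next_n_rows is reconstructed, not copied.
    def __init__(self, arrow_table, n_valid_rows, start_row_index):
        self.cur_row_index = start_row_index
        self.arrow_table = arrow_table
        self.n_valid_rows = n_valid_rows

    def next_n_rows(self, num_rows):
        # Serve at most num_rows, never reading past the valid portion of the batch
        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
        batch_slice = self.arrow_table.slice(self.cur_row_index, length)
        self.cur_row_index += batch_slice.num_rows
        return batch_slice

    def remaining_rows(self):
        batch_slice = self.arrow_table.slice(
            self.cur_row_index, self.n_valid_rows - self.cur_row_index)
        self.cur_row_index += batch_slice.num_rows
        return batch_slice

# Hypothetical fetch: the client has consumed 103 rows so far (_row_index == 103)
# and the server's batch starts at absolute row 100 (start_row_offset == 100),
# so the queue starts serving from local index 103 - 100 = 3.
batch = pyarrow.table({"id": list(range(100, 110))})  # absolute rows 100..109
queue = ArrowQueue(batch, n_valid_rows=10, start_row_index=103 - 100)
print(queue.next_n_rows(4)["id"].to_pylist())    # [103, 104, 105, 106]
print(queue.remaining_rows()["id"].to_pylist())  # [107, 108, 109]

This is also why fetchall_arrow can drain the current buffer with remaining_rows() and then keep calling _fill_results_buffer(): each refill asks the server for up to buffer_size_bytes of data starting at row_offset=_row_index rather than an implicit "next page".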