@@ -73,10 +73,10 @@ def _response_to_result_set(self, execute_command_response, status):
         command_id = execute_command_response.command_id
         arrow_results = execute_command_response.results.arrow_ipc_stream
         has_been_closed_server_side = execute_command_response.closed
-        number_of_valid_rows = execute_command_response.results.number_of_valid_rows
+        num_valid_rows = execute_command_response.results.num_valid_rows
 
         return ResultSet(self.connection, command_id, status, has_been_closed_server_side,
-                         arrow_results, number_of_valid_rows, self.buffer_size_rows)
+                         arrow_results, num_valid_rows, self.buffer_size_rows)
 
     def _close_and_clear_active_result_set(self):
         try:
@@ -179,7 +179,7 @@ def __init__(self,
                  status,
                  has_been_closed_server_side,
                  arrow_ipc_stream=None,
-                 number_of_valid_rows=None,
+                 num_valid_rows=None,
                  buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
         self.connection = connection
         self.command_id = command_id
@@ -190,21 +190,14 @@ def __init__(self,
         assert (self.status not in [command_pb2.PENDING, command_pb2.RUNNING])
 
         if arrow_ipc_stream:
-            self.results = deque(
-                self._deserialize_arrow_ipc_stream(arrow_ipc_stream)[:number_of_valid_rows])
+            # In the case we are passed in an initial result set, the server has taken the
+            # fast path and has no more rows to send
+            self.results = ArrowQueue(
+                pyarrow.ipc.open_stream(arrow_ipc_stream).read_all(), num_valid_rows)
             self.has_more_rows = False
         else:
-            self.results = deque()
-            self.has_more_rows = True
-
-    def _deserialize_arrow_ipc_stream(self, ipc_stream):
-        # TODO: Proper results deserialization, taking into account the logical schema, convert
-        # via pd df for efficiency (SC-77871)
-        pyarrow_table = pyarrow.ipc.open_stream(ipc_stream).read_all()
-        dict_repr = pyarrow_table.to_pydict()
-        n_rows, n_cols = pyarrow_table.shape
-        list_repr = [[col[i] for col in dict_repr.values()] for i in range(n_rows)]
-        return list_repr
+            # In this case, there are results waiting on the server so we fetch now for simplicity
+            self._fill_results_buffer()
 
     def _fetch_and_deserialize_results(self):
         fetch_results_request = command_pb2.FetchCommandResultsRequest(
@@ -216,11 +209,9 @@ def _fetch_and_deserialize_results(self):
 
         result_message = self.connection.base_client.make_request(
             self.connection.base_client.stub.FetchCommandResults, fetch_results_request).results
-        number_of_valid_rows = result_message.number_of_valid_rows
-        # TODO: Make efficient with less copying (https://databricks.atlassian.net/browse/SC-77868)
-        results = deque(
-            self._deserialize_arrow_ipc_stream(
-                result_message.arrow_ipc_stream)[:number_of_valid_rows])
+        num_valid_rows = result_message.num_valid_rows
+        arrow_table = pyarrow.ipc.open_stream(result_message.arrow_ipc_stream).read_all()
+        results = ArrowQueue(arrow_table, num_valid_rows)
         return results, result_message.has_more_rows
 
     def _fill_results_buffer(self):
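
As an aside, both the fast-path branch in `__init__` and `_fetch_and_deserialize_results` above deserialize the server payload with `pyarrow.ipc.open_stream(...).read_all()`. A minimal sketch of that round trip, not part of the diff, with a made-up table standing in for `result_message.arrow_ipc_stream`:

```python
import pyarrow as pa

# Illustrative stand-in for the bytes the server sends in arrow_ipc_stream:
# write a small table in the Arrow IPC streaming format.
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
ipc_stream = sink.getvalue()

# The read side mirrors the deserialization used in the diff.
deserialized = pa.ipc.open_stream(ipc_stream).read_all()
assert deserialized.equals(table)
```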
@@ -231,49 +222,74 @@ def _fill_results_buffer(self):
         else:
             results, has_more_rows = self._fetch_and_deserialize_results()
             self.results = results
-            if not has_more_rows:
-                self.has_more_rows = False
-
-    def _take_n_from_deque(self, deque, n):
-        arr = []
-        for _ in range(n):
-            try:
-                arr.append(deque.popleft())
-            except IndexError:
-                break
-        return arr
+            self.has_more_rows = has_more_rows
 
-    def fetchmany(self, n_rows):
+    @staticmethod
+    def _convert_arrow_table(table):
+        dict_repr = table.to_pydict()
+        n_rows, n_cols = table.shape
+        list_repr = [[col[i] for col in dict_repr.values()] for i in range(n_rows)]
+        return list_repr
+
+    def fetchmany_arrow(self, n_rows):
+        """
+        Fetch the next set of rows of a query result, returning a PyArrow table.
+        An empty sequence is returned when no more rows are available.
+        """
         # TODO: Make efficient with less copying
         if n_rows < 0:
             raise ValueError("n_rows argument for fetchmany is %s but must be >= 0", n_rows)
-        results = self._take_n_from_deque(self.results, n_rows)
-        n_remaining_rows = n_rows - len(results)
+        results = self.results.next_n_rows(n_rows)
+        n_remaining_rows = n_rows - results.num_rows
 
         while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
-            partial_results = self._take_n_from_deque(self.results, n_remaining_rows)
-            results += partial_results
-            n_remaining_rows -= len(partial_results)
+            partial_results = self.results.next_n_rows(n_remaining_rows)
+            results = pyarrow.concat_tables([results, partial_results])
+            n_remaining_rows -= partial_results.num_rows
+
+        return results
+
+    def fetchall_arrow(self):
+        """
+        Fetch all (remaining) rows of a query result, returning them as a PyArrow table.
+        """
+        results = self.fetchmany_arrow(self.buffer_size_rows)
+        while self.has_more_rows:
+            # TODO: What's the optimal sequence of sizes to fetch?
+            results = pyarrow.concat_tables([results, self.fetchmany_arrow(self.buffer_size_rows)])
 
         return results
 
     def fetchone(self):
-        return self.fetchmany(1)
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+        """
+        res = self._convert_arrow_table(self.fetchmany_arrow(1))
+        if len(res) > 0:
+            return res[0]
+        else:
+            return None
 
     def fetchall(self):
-        results = []
-        while True:
-            partial_results = self.fetchmany(self.buffer_size_rows)
-            # TODO: What's the optimal sequence of sizes to fetch?
-            results += partial_results
-
-            if len(partial_results) == 0:
-                break
+        """
+        Fetch all (remaining) rows of a query result, returning them as a list of lists.
+        """
+        return self._convert_arrow_table(self.fetchall_arrow())
 
-        return results
+    def fetchmany(self, n_rows):
+        """
+        Fetch the next set of rows of a query result, returning a list of lists.
+        An empty sequence is returned when no more rows are available.
+        """
+        return self._convert_arrow_table(self.fetchmany_arrow(n_rows))
 
     def close(self):
+        """
+        Close the cursor. If the connection has not been closed, and the cursor has not already
+        been closed on the server for some other reason, issue a request to the server to close it.
+        """
         try:
             if self.status != command_pb2.CLOSED and not self.has_been_closed_server_side \
                     and self.connection.open:
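
Side note: after this change the DB-API methods (`fetchone`, `fetchmany`, `fetchall`) are thin wrappers over the Arrow variants plus `_convert_arrow_table`. A quick sketch, separate from the diff, of what that conversion yields for a toy table (values invented for illustration):

```python
import pyarrow as pa

table = pa.table({"id": [1, 2], "name": ["a", "b"]})

# Same column-major -> row-major conversion as _convert_arrow_table.
dict_repr = table.to_pydict()    # {'id': [1, 2], 'name': ['a', 'b']}
n_rows, n_cols = table.shape     # (2, 2)
rows = [[col[i] for col in dict_repr.values()] for i in range(n_rows)]
assert rows == [[1, "a"], [2, "b"]]
```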
@@ -285,6 +301,22 @@ def close(self):
             self.status = command_pb2.CLOSED
 
 
+class ArrowQueue:
+    def __init__(self, arrow_table, n_valid_rows):
+        self.cur_row_index = 0
+        self.arrow_table = arrow_table
+        self.n_valid_rows = n_valid_rows
+
+    def next_n_rows(self, num_rows):
+        """
+        Get up to the next n rows of the Arrow table
+        """
+        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        slice = self.arrow_table.slice(self.cur_row_index, length)
+        self.cur_row_index += slice.num_rows
+        return slice
+
+
 class CmdExecBaseHttpClient:
     """
     A thin wrapper around a gRPC channel that takes cares of headers etc.
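
Finally, a minimal usage sketch of the `ArrowQueue` added above (assuming it is imported from this module; the table and row counts are made up):

```python
import pyarrow as pa

table = pa.table({"n": list(range(10))})
queue = ArrowQueue(table, 7)     # only the first 7 rows are considered valid

first = queue.next_n_rows(5)     # rows 0-4
second = queue.next_n_rows(5)    # only rows 5-6 remain below n_valid_rows
third = queue.next_n_rows(5)     # exhausted: an empty slice

assert (first.num_rows, second.num_rows, third.num_rows) == (5, 2, 0)
```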