11import logging
22from typing import Dict , Tuple , List
3+ from collections import deque
34
45import grpc
56import pyarrow
1213
1314logger = logging .getLogger (__name__ )
1415
16+ DEFAULT_BUFFER_SIZE_ROWS = 1000
17+
1518
def connect(**kwargs):
    """Open a new Connection, forwarding every keyword argument to its constructor."""
    return Connection(**kwargs)
@@ -39,10 +42,10 @@ def __init__(self, **kwargs):
3942 logger .info ("Successfully opened session " + str (self .session_id .hex ()))
4043 self ._cursors = []
4144
def cursor(self, buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
    """Create and register a new Cursor bound to this connection.

    :param buffer_size_rows: how many rows the cursor's result sets ask the
        server for per fetch round-trip (forwarded down to ResultSet).
    :raises Error: if the connection has already been closed.
    :return: the newly created Cursor.
    """
    if not self.open:
        raise Error("Cannot create cursor from closed connection")
    new_cursor = Cursor(self, buffer_size_rows)
    self._cursors.append(new_cursor)
    return new_cursor
4851
@@ -56,12 +59,12 @@ def close(self):
5659
5760
5861class Cursor :
def __init__(self, connection, buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
    """A DB-API-style cursor bound to *connection*.

    :param connection: the owning Connection.
    :param buffer_size_rows: row count requested per server fetch; passed
        through to the result sets this cursor creates.
    """
    self.connection = connection
    self.description = None
    self.rowcount = -1
    self.arraysize = 1
    self.buffer_size_rows = buffer_size_rows
    self.active_result_set = None
    # Note that Cursor closed => active result set closed, but not vice versa
    self.open = True
def _response_to_result_set(self, execute_command_response, status):
    """Build a ResultSet from an execute-command response.

    Pulls the command id, any inline Arrow IPC results, the server-side
    closed flag, and the count of valid rows out of the response, and wires
    them into a new ResultSet along with this cursor's buffer size.
    """
    command_id = execute_command_response.command_id
    arrow_results = execute_command_response.results.arrow_ipc_stream
    has_been_closed_server_side = execute_command_response.closed
    number_of_valid_rows = execute_command_response.results.number_of_valid_rows

    return ResultSet(self.connection, command_id, status, has_been_closed_server_side,
                     arrow_results, number_of_valid_rows, self.buffer_size_rows)
7680
7781 def _close_and_clear_active_result_set (self ):
7882 try :
@@ -148,6 +152,20 @@ def fetchall(self):
148152 else :
149153 raise Error ("There is no active result set" )
150154
def fetchone(self):
    """Delegate a single-row fetch to the active result set.

    :raises Error: if the cursor is closed or no result set is active.
    """
    self._check_not_closed()
    if not self.active_result_set:
        raise Error("There is no active result set")
    return self.active_result_set.fetchone()
161+
def fetchmany(self, n_rows):
    """Delegate a multi-row fetch to the active result set.

    :param n_rows: maximum number of rows to fetch.
    :raises Error: if the cursor is closed or no result set is active.
    """
    self._check_not_closed()
    if not self.active_result_set:
        raise Error("There is no active result set")
    return self.active_result_set.fetchmany(n_rows)
168+
151169 def close (self ):
152170 self .open = False
153171 if self .active_result_set :
def __init__(self,
             connection,
             command_id,
             status,
             has_been_closed_server_side,
             arrow_ipc_stream=None,
             number_of_valid_rows=None,
             buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
    """Result set for a completed command.

    :param connection: owning Connection (used for later server fetches).
    :param command_id: server-side id of the command.
    :param status: command status; must not be PENDING or RUNNING.
    :param has_been_closed_server_side: True when the server already closed
        the command, so no further fetches or an explicit close are needed.
    :param arrow_ipc_stream: optional inline Arrow IPC results returned with
        the execute response.
    :param number_of_valid_rows: server-reported count of valid rows in the
        inline stream; used to trim padding rows after deserialization.
    :param buffer_size_rows: rows requested per server round-trip.
    """
    self.connection = connection
    self.command_id = command_id
    self.status = status
    self.has_been_closed_server_side = has_been_closed_server_side
    self.buffer_size_rows = buffer_size_rows

    # Results can only be consumed once the command has left the queue.
    assert self.status not in [command_pb2.PENDING, command_pb2.RUNNING]

    if arrow_ipc_stream:
        # Inline results arrived with the response; trim to the valid rows
        # and treat them as the complete result.
        rows = self._deserialize_arrow_ipc_stream(arrow_ipc_stream)[:number_of_valid_rows]
        self.results = deque(rows)
        self.has_more_rows = False
    else:
        # Nothing inline: rows must be pulled from the server lazily.
        self.results = deque()
        self.has_more_rows = True
175199
176200 def _deserialize_arrow_ipc_stream (self , ipc_stream ):
177201 # TODO: Proper results deserialization, taking into account the logical schema, convert
@@ -182,40 +206,77 @@ def _deserialize_arrow_ipc_stream(self, ipc_stream):
182206 list_repr = [[col [i ] for col in dict_repr .values ()] for i in range (n_rows )]
183207 return list_repr
184208
def _fetch_and_deserialize_results(self):
    """Fetch the next batch of rows from the server and deserialize it.

    Issues a FetchCommandResults request limited to buffer_size_rows, trims
    the deserialized rows to the server-reported valid-row count, and
    returns ``(rows_deque, has_more_rows)``.
    """
    fetch_results_request = command_pb2.FetchCommandResultsRequest(
        id=self.command_id,
        options=command_pb2.CommandResultOptions(
            max_rows=self.buffer_size_rows,
            include_metadata=True,
        ))

    result_message = self.connection.base_client.make_request(
        self.connection.base_client.stub.FetchCommandResults, fetch_results_request).results

    valid_row_count = result_message.number_of_valid_rows
    # TODO: Make efficient with less copying (https://databricks.atlassian.net/browse/SC-77868)
    rows = self._deserialize_arrow_ipc_stream(result_message.arrow_ipc_stream)[:valid_row_count]
    return deque(rows), result_message.has_more_rows
225+
def _fill_results_buffer(self):
    """Replace the local row buffer with the next server batch.

    :raises Error: when the command has been closed.
    :raises DatabaseError: when the command failed.
    """
    # Guard clauses: a closed or failed command can never produce rows.
    if self.status == command_pb2.CLOSED:
        raise Error("Can't fetch results on closed command %s" % self.command_id)
    if self.status == command_pb2.ERROR:
        raise DatabaseError("Command %s failed" % self.command_id)

    results, has_more_rows = self._fetch_and_deserialize_results()
    self.results = results
    if not has_more_rows:
        self.has_more_rows = False
236+
237+ def _take_n_from_deque (self , deque , n ):
238+ arr = []
239+ for _ in range (n ):
240+ try :
241+ arr .append (deque .popleft ())
242+ except IndexError :
243+ break
244+ return arr
245+
def fetchmany(self, n_rows):
    """Fetch up to *n_rows* rows, refilling the buffer from the server as needed.

    Drains the local buffer first; while rows are still missing, the server
    has more rows, and the command is not closed server-side, refills the
    buffer and keeps draining. May return fewer than *n_rows* at end of data.

    :param n_rows: maximum number of rows to return; must be >= 0.
    :raises ValueError: if n_rows is negative.
    :return: list of rows.
    """
    # TODO: Make efficient with less copying
    if n_rows < 0:
        # Bug fix: the message was previously passed as a second exception
        # argument (comma instead of %), so %s was never interpolated.
        raise ValueError("n_rows argument for fetchmany is %s but must be >= 0" % n_rows)
    results = self._take_n_from_deque(self.results, n_rows)
    n_remaining_rows = n_rows - len(results)

    while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
        self._fill_results_buffer()
        partial_results = self._take_n_from_deque(self.results, n_remaining_rows)
        results += partial_results
        n_remaining_rows -= len(partial_results)

    return results
260+
def fetchone(self):
    """Fetch the next available row via fetchmany(1).

    NOTE(review): this returns a list of zero or one rows, not a single row
    or None as PEP 249 describes for fetchone — confirm against callers
    before changing; behavior is preserved here.
    """
    return self.fetchmany(1)
263+
def fetchall(self):
    """Fetch every remaining row by repeatedly calling fetchmany.

    Loops in buffer_size_rows-sized requests until a fetch comes back empty.
    """
    all_rows = []
    while True:
        # TODO: What's the optimal sequence of sizes to fetch?
        batch = self.fetchmany(self.buffer_size_rows)
        all_rows.extend(batch)
        if not batch:
            break
    return all_rows
214275
215276 def close (self ):
216277 try :
217278 if self .status != command_pb2 .CLOSED and not self .has_been_closed_server_side \
218- and not self .connection .closed :
279+ and self .connection .open :
219280 close_command_request = command_pb2 .CloseCommandRequest (id = self .command_id )
220281 self .connection .base_client .make_request (
221282 self .connection .base_client .stub .CloseCommand , close_command_request )
0 commit comments