
Commit cbf45b6

Niall Egan authored and susodapop committed
Fetch Arrow methods
Adds an initial focus on providing Arrow dataframes rather than deserialising results completely to Python types: this PR adds fetch_arrow methods. Also, to make the client implementation simpler, the constructor now always fetches some results from the server if it is not passed an initial result set. (Also some minor renaming from the last PR, plus a small change to fetchone so that it returns None when no more rows are available, as defined in PEP 249.)

Note: this still doesn't work correctly for large results with multiple Arrow batches. That functionality is added in a follow-up PR, which also adds tests.

Modified `test_fetch` to use the new `ArrowQueue` class.

Author: Niall Egan <niall.egan@databricks.com>
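As a rough sketch of how the new surface might be used (the `connection.execute_command` setup below is hypothetical and not part of this diff; only the fetch calls come from this commit):

```python
# Hypothetical setup: how a ResultSet is obtained is outside this diff.
result_set = connection.execute_command("SELECT id, name FROM users")

# New in this commit: fetch directly as PyArrow tables instead of
# deserialising every value into Python objects.
first_batch = result_set.fetchmany_arrow(100)  # up to 100 rows, as a pyarrow.Table
the_rest = result_set.fetchall_arrow()         # all remaining rows, as a pyarrow.Table

# The PEP 249-style methods are now thin wrappers over the Arrow ones.
rows = result_set.fetchmany(10)  # a list of lists
row = result_set.fetchone()      # a single row, or None once results are exhausted
```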
1 parent 2d65181 · commit cbf45b6

File tree

4 files changed: +226 −144 lines changed

cmdexec/clients/python/command_exec_client.py

Lines changed: 80 additions & 48 deletions
@@ -73,10 +73,10 @@ def _response_to_result_set(self, execute_command_response, status):
         command_id = execute_command_response.command_id
         arrow_results = execute_command_response.results.arrow_ipc_stream
         has_been_closed_server_side = execute_command_response.closed
-        number_of_valid_rows = execute_command_response.results.number_of_valid_rows
+        num_valid_rows = execute_command_response.results.num_valid_rows

         return ResultSet(self.connection, command_id, status, has_been_closed_server_side,
-                         arrow_results, number_of_valid_rows, self.buffer_size_rows)
+                         arrow_results, num_valid_rows, self.buffer_size_rows)

     def _close_and_clear_active_result_set(self):
         try:
@@ -179,7 +179,7 @@ def __init__(self,
                  status,
                  has_been_closed_server_side,
                  arrow_ipc_stream=None,
-                 number_of_valid_rows=None,
+                 num_valid_rows=None,
                  buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
         self.connection = connection
         self.command_id = command_id
@@ -190,21 +190,14 @@ def __init__(self,
         assert (self.status not in [command_pb2.PENDING, command_pb2.RUNNING])

         if arrow_ipc_stream:
-            self.results = deque(
-                self._deserialize_arrow_ipc_stream(arrow_ipc_stream)[:number_of_valid_rows])
+            # In the case we are passed in an initial result set, the server has taken the
+            # fast path and has no more rows to send
+            self.results = ArrowQueue(
+                pyarrow.ipc.open_stream(arrow_ipc_stream).read_all(), num_valid_rows)
             self.has_more_rows = False
         else:
-            self.results = deque()
-            self.has_more_rows = True
-
-    def _deserialize_arrow_ipc_stream(self, ipc_stream):
-        # TODO: Proper results deserialization, taking into account the logical schema, convert
-        # via pd df for efficiency (SC-77871)
-        pyarrow_table = pyarrow.ipc.open_stream(ipc_stream).read_all()
-        dict_repr = pyarrow_table.to_pydict()
-        n_rows, n_cols = pyarrow_table.shape
-        list_repr = [[col[i] for col in dict_repr.values()] for i in range(n_rows)]
-        return list_repr
+            # In this case, there are results waiting on the server so we fetch now for simplicity
+            self._fill_results_buffer()

     def _fetch_and_deserialize_results(self):
         fetch_results_request = command_pb2.FetchCommandResultsRequest(
@@ -216,11 +209,9 @@ def _fetch_and_deserialize_results(self):

         result_message = self.connection.base_client.make_request(
             self.connection.base_client.stub.FetchCommandResults, fetch_results_request).results
-        number_of_valid_rows = result_message.number_of_valid_rows
-        # TODO: Make efficient with less copying (https://databricks.atlassian.net/browse/SC-77868)
-        results = deque(
-            self._deserialize_arrow_ipc_stream(
-                result_message.arrow_ipc_stream)[:number_of_valid_rows])
+        num_valid_rows = result_message.num_valid_rows
+        arrow_table = pyarrow.ipc.open_stream(result_message.arrow_ipc_stream).read_all()
+        results = ArrowQueue(arrow_table, num_valid_rows)
         return results, result_message.has_more_rows

     def _fill_results_buffer(self):
@@ -231,49 +222,74 @@ def _fill_results_buffer(self):
         else:
             results, has_more_rows = self._fetch_and_deserialize_results()
             self.results = results
-            if not has_more_rows:
-                self.has_more_rows = False
-
-    def _take_n_from_deque(self, deque, n):
-        arr = []
-        for _ in range(n):
-            try:
-                arr.append(deque.popleft())
-            except IndexError:
-                break
-        return arr
+            self.has_more_rows = has_more_rows

-    def fetchmany(self, n_rows):
+    @staticmethod
+    def _convert_arrow_table(table):
+        dict_repr = table.to_pydict()
+        n_rows, n_cols = table.shape
+        list_repr = [[col[i] for col in dict_repr.values()] for i in range(n_rows)]
+        return list_repr
+
+    def fetchmany_arrow(self, n_rows):
+        """
+        Fetch the next set of rows of a query result, returning a PyArrow table.
+        An empty sequence is returned when no more rows are available.
+        """
         # TODO: Make efficient with less copying
         if n_rows < 0:
             raise ValueError("n_rows argument for fetchmany is %s but must be >= 0", n_rows)
-        results = self._take_n_from_deque(self.results, n_rows)
-        n_remaining_rows = n_rows - len(results)
+        results = self.results.next_n_rows(n_rows)
+        n_remaining_rows = n_rows - results.num_rows

         while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
-            partial_results = self._take_n_from_deque(self.results, n_remaining_rows)
-            results += partial_results
-            n_remaining_rows -= len(partial_results)
+            partial_results = self.results.next_n_rows(n_remaining_rows)
+            results = pyarrow.concat_tables([results, partial_results])
+            n_remaining_rows -= partial_results.num_rows
+
+        return results
+
+    def fetchall_arrow(self):
+        """
+        Fetch all (remaining) rows of a query result, returning them as a PyArrow table.
+        """
+        results = self.fetchmany_arrow(self.buffer_size_rows)
+        while self.has_more_rows:
+            # TODO: What's the optimal sequence of sizes to fetch?
+            results = pyarrow.concat_tables([results, self.fetchmany_arrow(self.buffer_size_rows)])

         return results

     def fetchone(self):
-        return self.fetchmany(1)
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+        """
+        res = self._convert_arrow_table(self.fetchmany_arrow(1))
+        if len(res) > 0:
+            return res[0]
+        else:
+            return None

     def fetchall(self):
-        results = []
-        while True:
-            partial_results = self.fetchmany(self.buffer_size_rows)
-            # TODO: What's the optimal sequence of sizes to fetch?
-            results += partial_results
-
-            if len(partial_results) == 0:
-                break
+        """
+        Fetch all (remaining) rows of a query result, returning them as a list of lists.
+        """
+        return self._convert_arrow_table(self.fetchall_arrow())

-        return results
+    def fetchmany(self, n_rows):
+        """
+        Fetch the next set of rows of a query result, returning a list of lists.
+        An empty sequence is returned when no more rows are available.
+        """
+        return self._convert_arrow_table(self.fetchmany_arrow(n_rows))

     def close(self):
+        """
+        Close the cursor. If the connection has not been closed, and the cursor has not already
+        been closed on the server for some other reason, issue a request to the server to close it.
+        """
         try:
             if self.status != command_pb2.CLOSED and not self.has_been_closed_server_side \
                     and self.connection.open:
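The new `_convert_arrow_table` helper above is what keeps the PEP 249 methods working on top of the Arrow path: it pivots Arrow's columnar `to_pydict()` output into row-oriented lists. A standalone illustration of that pivot, using an invented two-row table:

```python
import pyarrow

# Invented sample data; in the client the table comes from the ArrowQueue.
table = pyarrow.table({"id": [1, 2], "name": ["a", "b"]})

# to_pydict() is columnar: {"id": [1, 2], "name": ["a", "b"]}
dict_repr = table.to_pydict()

# The comprehension in _convert_arrow_table transposes columns into rows.
rows = [[col[i] for col in dict_repr.values()] for i in range(table.num_rows)]
print(rows)  # [[1, 'a'], [2, 'b']]
```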
@@ -285,6 +301,22 @@ def close(self):
             self.status = command_pb2.CLOSED


+class ArrowQueue:
+    def __init__(self, arrow_table, n_valid_rows):
+        self.cur_row_index = 0
+        self.arrow_table = arrow_table
+        self.n_valid_rows = n_valid_rows
+
+    def next_n_rows(self, num_rows):
+        """
+        Get up to the next n rows of the Arrow dataframe
+        """
+        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        slice = self.arrow_table.slice(self.cur_row_index, length)
+        self.cur_row_index += slice.num_rows
+        return slice
+
+
 class CmdExecBaseHttpClient:
     """
     A thin wrapper around a gRPC channel that takes care of headers etc.
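For reference, `ArrowQueue` is a forward-only cursor over a single in-memory `pyarrow.Table`: `next_n_rows` hands back zero-copy slices until `n_valid_rows` is exhausted, after which it returns empty tables. A minimal sketch of that behaviour, assuming the import path below matches the file above and using invented sample data:

```python
import pyarrow

from command_exec_client import ArrowQueue  # import path assumed from the file above

# In the client this table comes from pyarrow.ipc.open_stream(...).read_all();
# here it is invented sample data.
table = pyarrow.table({"id": [1, 2, 3, 4, 5], "name": ["a", "b", "c", "d", "e"]})

# Treat only the first 4 rows as valid; the 5th stands in for padding
# in the Arrow batch.
queue = ArrowQueue(table, 4)

first = queue.next_n_rows(3)   # rows 0-2 as a zero-copy pyarrow.Table slice
second = queue.next_n_rows(3)  # only row 3 is left within n_valid_rows
third = queue.next_n_rows(3)   # nothing left: an empty (0-row) slice

print(first.num_rows, second.num_rows, third.num_rows)  # 3 1 0
```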
