
Commit 84d0d4e

Niall Egan authored and susodapop committed
Support and tests for large queries
Added support for large queries:
- Keep fetching results if more remain, even when the server takes the fast path
- Fixed a bug: we should always specify the offset we want to fetch from
- Set the gRPC message size properly

New DriverLocal test for large queries

Author: Niall Egan <niall.egan@databricks.com>
1 parent 15e9ced commit 84d0d4e
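
The change in a nutshell: the client now buffers results by size rather than by row count, and tracks a row offset so it can keep fetching past the server's fast path. A minimal usage sketch of the new surface (the connect() arguments and table name are illustrative assumptions, not part of this commit):

from command_exec_client import connect

# Hypothetical connection arguments; only cursor(buffer_size_bytes=...)
# and its 10 MB default are introduced by this commit.
conn = connect(host="localhost", port=15001)
cursor = conn.cursor(buffer_size_bytes=10 * 1024 * 1024)
cursor.execute("SELECT * FROM some_large_table")
conn.close()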

File tree

3 files changed: +57 -29 lines


cmdexec/clients/python/command_exec_client.py

Lines changed: 39 additions & 22 deletions
@@ -13,7 +13,7 @@
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_BUFFER_SIZE_ROWS = 1000
+DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
 
 
 def connect(**kwargs):
@@ -48,10 +48,11 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         self.close()
 
-    def cursor(self, buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
+    def cursor(self, buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
         if not self.open:
             raise Error("Cannot create cursor from closed connection")
-        cursor = Cursor(self, buffer_size_rows)
+
+        cursor = Cursor(self, buffer_size_bytes)
         self._cursors.append(cursor)
         return cursor

@@ -65,12 +66,12 @@ def close(self):
 
 
 class Cursor:
-    def __init__(self, connection, buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
+    def __init__(self, connection, result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
         self.connection = connection
         self.description = None
         self.rowcount = -1
         self.arraysize = 1
-        self.buffer_size_rows = buffer_size_rows
+        self.buffer_size_bytes = result_buffer_size_bytes
         self.active_result_set = None
         # Note that Cursor closed => active result set closed, but not vice versa
         self.open = True
@@ -92,10 +93,11 @@ def _response_to_result_set(self, execute_command_response, status):
         command_id = execute_command_response.command_id
         arrow_results = execute_command_response.results.arrow_ipc_stream
         has_been_closed_server_side = execute_command_response.closed
+        has_more_rows = execute_command_response.results.has_more_rows
         num_valid_rows = execute_command_response.results.num_valid_rows
 
         return ResultSet(self.connection, command_id, status, has_been_closed_server_side,
-                         arrow_results, num_valid_rows, self.buffer_size_rows)
+                         has_more_rows, arrow_results, num_valid_rows, self.buffer_size_bytes)
 
     def _close_and_clear_active_result_set(self):
         try:
@@ -127,15 +129,13 @@ def _poll_for_state(self, command_id):
 
     def _wait_until_command_done(self, command_id, initial_status):
         status = initial_status
-        print("initial status: %s" % status)
         while status in [command_pb2.PENDING, command_pb2.RUNNING]:
             resp = self._poll_for_state(command_id)
             status = resp.status.state
             self._check_response_for_error(resp, command_id)
-            print("status is: %s" % status)
 
             # TODO: Remove this sleep once we have long-polling on the server (SC-77653)
-            time.sleep(1)
+            time.sleep(0.1)
         return status
 
     def execute(self, operation, query_params=None, metadata=None):
@@ -197,23 +197,25 @@ def __init__(self,
                  command_id,
                  status,
                  has_been_closed_server_side,
+                 has_more_rows,
                  arrow_ipc_stream=None,
                  num_valid_rows=None,
-                 buffer_size_rows=DEFAULT_BUFFER_SIZE_ROWS):
+                 result_buffer_size_bytes=DEFAULT_RESULT_BUFFER_SIZE_BYTES):
         self.connection = connection
         self.command_id = command_id
         self.status = status
         self.has_been_closed_server_side = has_been_closed_server_side
-        self.buffer_size_rows = buffer_size_rows
+        self.has_more_rows = has_more_rows
+        self.buffer_size_bytes = result_buffer_size_bytes
+        self._row_index = 0
 
         assert (self.status not in [command_pb2.PENDING, command_pb2.RUNNING])
 
         if arrow_ipc_stream:
             # In the case we are passed in an initial result set, the server has taken the
             # fast path and has no more rows to send
             self.results = ArrowQueue(
-                pyarrow.ipc.open_stream(arrow_ipc_stream).read_all(), num_valid_rows)
-            self.has_more_rows = False
+                pyarrow.ipc.open_stream(arrow_ipc_stream).read_all(), num_valid_rows, 0)
         else:
             # In this case, there are results waiting on the server so we fetch now for simplicity
             self._fill_results_buffer()
@@ -230,15 +232,17 @@ def _fetch_and_deserialize_results(self):
         fetch_results_request = command_pb2.FetchCommandResultsRequest(
             id=self.command_id,
             options=command_pb2.CommandResultOptions(
-                max_rows=self.buffer_size_rows,
+                max_bytes=self.buffer_size_bytes,
+                row_offset=self._row_index,
                 include_metadata=True,
             ))
 
         result_message = self.connection.base_client.make_request(
             self.connection.base_client.stub.FetchCommandResults, fetch_results_request).results
         num_valid_rows = result_message.num_valid_rows
         arrow_table = pyarrow.ipc.open_stream(result_message.arrow_ipc_stream).read_all()
-        results = ArrowQueue(arrow_table, num_valid_rows)
+        results = ArrowQueue(arrow_table, num_valid_rows,
+                             self._row_index - result_message.start_row_offset)
         return results, result_message.has_more_rows
 
     def _fill_results_buffer(self):
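
The subtraction in the ArrowQueue constructor call is the offset fix in miniature: the server reports the global row at which the returned batch starts, and the queue must begin at the cursor's position within that batch. A worked example with made-up numbers:

# Suppose the cursor needs global row 1500 and the server returns a
# batch that starts at global row 1000 (illustrative values only):
row_index = 1500                            # next row the cursor wants
start_row_offset = 1000                     # batch's global starting row
local_start = row_index - start_row_offset  # = 500, index inside the batch
# ArrowQueue(arrow_table, num_valid_rows, local_start) then skips the 500
# rows of this batch that the cursor has already consumed.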
@@ -268,23 +272,29 @@ def fetchmany_arrow(self, n_rows):
             raise ValueError("n_rows argument for fetchmany is %s but must be >= 0", n_rows)
         results = self.results.next_n_rows(n_rows)
         n_remaining_rows = n_rows - results.num_rows
+        self._row_index += results.num_rows
 
         while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
             results = pyarrow.concat_tables([results, partial_results])
             n_remaining_rows -= partial_results.num_rows
+            self._row_index += partial_results.num_rows
 
         return results
 
     def fetchall_arrow(self):
         """
         Fetch all (remaining) rows of a query result, returning them as a PyArrow table.
         """
-        results = self.fetchmany_arrow(self.buffer_size_rows)
-        while self.has_more_rows:
-            # TODO: What's the optimal sequence of sizes to fetch?
-            results = pyarrow.concat_tables([results, self.fetchmany_arrow(self.buffer_size_rows)])
+        results = self.results.remaining_rows()
+        self._row_index += results.num_rows
+
+        while not self.has_been_closed_server_side and self.has_more_rows:
+            self._fill_results_buffer()
+            partial_results = self.results.remaining_rows()
+            results = pyarrow.concat_tables([results, partial_results])
+            self._row_index += partial_results.num_rows
 
         return results
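
Both fetch paths now follow the same pattern: drain the local ArrowQueue, then refill from the server while has_more_rows holds, advancing _row_index as rows are consumed. A short consumption sketch against a result_set from this module (sizes are illustrative):

# fetchmany_arrow may span several server round-trips transparently;
# fetchall_arrow then drains whatever remains.
first_chunk = result_set.fetchmany_arrow(100)
rest = result_set.fetchall_arrow()
total_rows = first_chunk.num_rows + rest.num_rows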

@@ -293,6 +303,7 @@ def fetchone(self):
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
+        self._row_index += 1
         res = self._convert_arrow_table(self.fetchmany_arrow(1))
         if len(res) > 0:
             return res[0]
@@ -329,8 +340,8 @@ def close(self):
 
 
 class ArrowQueue:
-    def __init__(self, arrow_table, n_valid_rows):
-        self.cur_row_index = 0
+    def __init__(self, arrow_table, n_valid_rows, start_row_index):
+        self.cur_row_index = start_row_index
         self.arrow_table = arrow_table
         self.n_valid_rows = n_valid_rows

@@ -343,6 +354,11 @@ def next_n_rows(self, num_rows):
         self.cur_row_index += slice.num_rows
         return slice
 
+    def remaining_rows(self):
+        slice = self.arrow_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index)
+        self.cur_row_index += slice.num_rows
+        return slice
+
 
 class CmdExecBaseHttpClient:
     """
@@ -352,7 +368,8 @@ class CmdExecBaseHttpClient:
     def __init__(self, host: str, port: int, http_headers: List[Tuple[str, str]]):
         self.host_url = host + ":" + str(port)
         self.http_headers = [(k.lower(), v) for (k, v) in http_headers]
-        self.channel = grpc.insecure_channel(self.host_url)
+        self.channel = grpc.insecure_channel(
+            self.host_url, options=[('grpc.max_receive_message_length', -1)])
         self.stub = SqlCommandServiceStub(self.channel)
 
     def make_request(self, method, request):
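
Setting grpc.max_receive_message_length to -1 removes gRPC's default inbound message cap (roughly 4 MB), which a single large Arrow batch could otherwise exceed. A standalone sketch, with a bounded alternative (the address is an assumption):

import grpc

channel = grpc.insecure_channel(
    "localhost:15001",
    options=[
        ("grpc.max_receive_message_length", -1),  # -1 disables the cap
        # Bounded alternative: allow up to 64 MB instead of "unlimited":
        # ("grpc.max_receive_message_length", 64 * 1024 * 1024),
    ],
)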

cmdexec/clients/python/tests/test_fetches.py

Lines changed: 4 additions & 3 deletions
@@ -32,7 +32,7 @@ def make_arrow_ipc_stream(batch):
     @staticmethod
     def make_arrow_queue(batch):
         table = FetchTests.make_arrow_table(batch)
-        queue = command_exec_client.ArrowQueue(table, len(batch))
+        queue = command_exec_client.ArrowQueue(table, len(batch), 0)
         return queue
 
     @staticmethod
@@ -45,7 +45,8 @@ def make_dummy_result_set_from_initial_results(initial_results):
             None,
             True,
             arrow_ipc_stream=arrow_ipc_stream,
-            num_valid_rows=len(initial_results))
+            num_valid_rows=len(initial_results),
+            has_more_rows=False)
 
     @staticmethod
     def make_dummy_result_set_from_batch_list(batch_list):
@@ -59,7 +60,7 @@ def _fetch_and_deserialize_results(self):
 
                 return results, batch_index < len(batch_list)
 
-        return SemiFakeResultSet(None, None, None, False)
+        return SemiFakeResultSet(None, None, None, False, False)
 
     def test_fetchmany_with_initial_results(self):
         # Fetch all in one go

cmdexec/clients/python/tests/tests.py

Lines changed: 14 additions & 4 deletions
@@ -100,7 +100,8 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(
             command_pb2.SUCCESS,
             False,
             arrow_ipc_stream=Mock(),
-            num_valid_rows=0)
+            num_valid_rows=0,
+            has_more_rows=False)
         mock_connection.open = False
 
         result_set.close()
@@ -115,9 +116,10 @@ def test_closing_result_set_hard_closes_commands(self, pyarrow_ipc_open_stream):
         mock_connection = Mock()
         mock_response = Mock()
         mock_response.id = b'\x22'
+        mock_response.results.start_row_offset = 0
         mock_connection.base_client.make_request.return_value = mock_response
-        result_set = command_exec_client.ResultSet(mock_connection, b'\x10', command_pb2.SUCCESS,
-                                                   False)
+        result_set = command_exec_client.ResultSet(
+            mock_connection, b'\x10', command_pb2.SUCCESS, False, has_more_rows=False)
         mock_connection.open = True
 
         result_set.close()
@@ -171,7 +173,15 @@ def test_closed_cursor_doesnt_allow_operations(self):
 
     @patch("pyarrow.ipc.open_stream")
     def test_negative_fetch_throws_exception(self, pyarrow_ipc_open_stream_mock):
-        result_set = command_exec_client.ResultSet(Mock(), b'\x22', command_pb2.SUCCESS, Mock())
+        mock_connection = Mock()
+        mock_response = Mock()
+        mock_response.id = b'\x22'
+        mock_response.results.start_row_offset = 0
+        mock_response.status.state = command_pb2.SUCCESS
+        mock_connection.base_client.make_request.return_value = mock_response
+
+        result_set = command_exec_client.ResultSet(
+            mock_connection, b'\x22', command_pb2.SUCCESS, Mock(), has_more_rows=False)
 
         with self.assertRaises(ValueError) as e:
             result_set.fetchmany(-1)
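
Note: a ResultSet constructed without an initial Arrow stream immediately calls _fill_results_buffer, which is why these mocks now have to supply results.start_row_offset (and, in the last test, a status.state) up front.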
