Skip to content

Commit fd2602c

Browse files
NiallEgan authored and susodapop committed
Change result fetching to use fetch.next
Elsewhere I found that calling fetchall() after a create temp view resulted in the following stack trace: ``` Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/Users/niallegan/opt/miniconda3/lib/python3.7/site-packages/databricks/sql/client.py", line 318, in fetchall return self.active_result_set.fetchall() File "/Users/niallegan/opt/miniconda3/lib/python3.7/site-packages/databricks/sql/client.py", line 522, in fetchall return self._convert_arrow_table(self.fetchall_arrow()) File "/Users/niallegan/opt/miniconda3/lib/python3.7/site-packages/databricks/sql/client.py", line 496, in fetchall_arrow results = self.results.remaining_rows() File "/Users/niallegan/opt/miniconda3/lib/python3.7/site-packages/databricks/sql/utils.py", line 27, in remaining_rows slice = self.arrow_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index) File "pyarrow/table.pxi", line 1125, in pyarrow.lib.Table.slice IndexError: Offset must be non-negative ``` The problem was that the startRowOffset returned by the server was after what we requested. However, it turns out that on the client side we should not be setting `startRowOffset` at all, since Thrift uses FETCH_NEXT by default. This PR makes `FETCH_NEXT` explicit and stops tracking the row offset. I also took this opportunity to improve unit test coverage for dealing with the arrow batches and row counts. * Manually running SELECT * from temp view test against prod * New smoke test for SELECT * from temp view * Increase unit test coverage - Did you add usage logs or metrics? Please mention them here. - Create dashboards or monitoring notebooks? Please link them here. - See http://go/obs/user for docs on our observability tools.
1 parent ecd02d8 commit fd2602c

File tree

7 files changed

+91
-17
lines changed

7 files changed

+91
-17
lines changed

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -423,11 +423,11 @@ def __init__(self,
423423
self.has_been_closed_server_side = execute_response.has_been_closed_server_side
424424
self.has_more_rows = execute_response.has_more_rows
425425
self.buffer_size_bytes = result_buffer_size_bytes
426-
self._row_index = 0
427426
self.arraysize = arraysize
428427
self.thrift_backend = thrift_backend
429428
self.description = execute_response.description
430429
self._arrow_schema = execute_response.arrow_schema
430+
self._next_row_index = 0
431431

432432
if execute_response.arrow_queue:
433433
# In this case the server has taken the fast path and returned an initial batch of
@@ -447,8 +447,12 @@ def __iter__(self):
447447

448448
def _fill_results_buffer(self):
449449
results, has_more_rows = self.thrift_backend.fetch_results(
450-
self.command_id, self.arraysize, self.buffer_size_bytes, self._row_index,
451-
self._arrow_schema, self.description)
450+
op_handle=self.command_id,
451+
max_rows=self.arraysize,
452+
max_bytes=self.buffer_size_bytes,
453+
expected_row_start_offset=self._next_row_index,
454+
arrow_schema=self._arrow_schema,
455+
description=self.description)
452456
self.results = results
453457
self.has_more_rows = has_more_rows
454458

@@ -468,27 +472,27 @@ def fetchmany_arrow(self, n_rows: int) -> pyarrow.Table:
468472
raise ValueError("n_rows argument for fetchmany is %s but must be >= 0", n_rows)
469473
results = self.results.next_n_rows(n_rows)
470474
n_remaining_rows = n_rows - results.num_rows
471-
self._row_index += results.num_rows
475+
self._next_row_index += results.num_rows
472476

473477
while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
474478
self._fill_results_buffer()
475479
partial_results = self.results.next_n_rows(n_remaining_rows)
476480
results = pyarrow.concat_tables([results, partial_results])
477481
n_remaining_rows -= partial_results.num_rows
478-
self._row_index += partial_results.num_rows
482+
self._next_row_index += partial_results.num_rows
479483

480484
return results
481485

482486
def fetchall_arrow(self) -> pyarrow.Table:
483487
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
484488
results = self.results.remaining_rows()
485-
self._row_index += results.num_rows
489+
self._next_row_index += results.num_rows
486490

487491
while not self.has_been_closed_server_side and self.has_more_rows:
488492
self._fill_results_buffer()
489493
partial_results = self.results.remaining_rows()
490494
results = pyarrow.concat_tables([results, partial_results])
491-
self._row_index += partial_results.num_rows
495+
self._next_row_index += partial_results.num_rows
492496

493497
return results
494498

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,8 @@ def _handle_execute_response(self, resp, cursor):
639639

640640
return self._results_message_to_execute_response(resp, final_operation_state)
641641

642-
def fetch_results(self, op_handle, max_rows, max_bytes, row_offset, arrow_schema, description):
642+
def fetch_results(self, op_handle, max_rows, max_bytes, expected_row_start_offset, arrow_schema,
643+
description):
643644
assert (op_handle is not None)
644645

645646
req = ttypes.TFetchResultsReq(
@@ -651,12 +652,14 @@ def fetch_results(self, op_handle, max_rows, max_bytes, row_offset, arrow_schema
651652
),
652653
maxRows=max_rows,
653654
maxBytes=max_bytes,
654-
startRowOffset=row_offset,
655-
)
655+
orientation=ttypes.TFetchOrientation.FETCH_NEXT)
656656

657657
resp = self.make_request(self._client.FetchResults, req)
658+
if resp.results.startRowOffset > expected_row_start_offset:
659+
logger.warning("Expected results to start from {} but they instead start at {}".format(
660+
expected_row_start_offset, resp.results.startRowOffset))
658661
arrow_results, n_rows = self._create_arrow_table(resp.results, arrow_schema, description)
659-
arrow_queue = ArrowQueue(arrow_results, n_rows, row_offset - resp.results.startRowOffset)
662+
arrow_queue = ArrowQueue(arrow_results, n_rows)
660663

661664
return arrow_queue, resp.hasMoreRows
662665

cmdexec/clients/python/src/databricks/sql/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class ArrowQueue:
8-
def __init__(self, arrow_table: pyarrow.Table, n_valid_rows: int, start_row_index: int):
8+
def __init__(self, arrow_table: pyarrow.Table, n_valid_rows: int, start_row_index: int = 0):
99
"""
1010
A queue-like wrapper over an Arrow table
1111
@@ -20,6 +20,8 @@ def __init__(self, arrow_table: pyarrow.Table, n_valid_rows: int, start_row_inde
2020
def next_n_rows(self, num_rows: int) -> pyarrow.Table:
2121
"""Get upto the next n rows of the Arrow dataframe"""
2222
length = min(num_rows, self.n_valid_rows - self.cur_row_index)
23+
# Note that the table.slice API is not the same as Python's slice
24+
# The second argument should be length, not end index
2325
slice = self.arrow_table.slice(self.cur_row_index, length)
2426
self.cur_row_index += slice.num_rows
2527
return slice
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import unittest
2+
3+
import pyarrow as pa
4+
5+
from databricks.sql.utils import ArrowQueue
6+
7+
8+
class ArrowQueueSuite(unittest.TestCase):
9+
@staticmethod
10+
def make_arrow_table(batch):
11+
n_cols = len(batch[0]) if batch else 0
12+
schema = pa.schema({"col%s" % i: pa.uint32() for i in range(n_cols)})
13+
cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)]
14+
return pa.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema)
15+
16+
def test_fetchmany_respects_n_rows(self):
17+
arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
18+
aq = ArrowQueue(arrow_table, 3)
19+
self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[0, 1, 2], [3, 4, 5]]))
20+
self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[6, 7, 8]]))
21+
22+
def test_fetch_remaining_rows_respects_n_rows(self):
23+
arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
24+
aq = ArrowQueue(arrow_table, 3)
25+
self.assertEqual(aq.next_n_rows(1), self.make_arrow_table([[0, 1, 2]]))
26+
self.assertEqual(aq.remaining_rows(), self.make_arrow_table([[3, 4, 5], [6, 7, 8]]))

cmdexec/clients/python/tests/test_fetches.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def make_arrow_table(batch):
2424
@staticmethod
2525
def make_arrow_queue(batch):
2626
_, table = FetchTests.make_arrow_table(batch)
27-
queue = ArrowQueue(table, len(batch), 0)
27+
queue = ArrowQueue(table, len(batch))
2828
return queue
2929

3030
@staticmethod
@@ -51,7 +51,8 @@ def make_dummy_result_set_from_initial_results(initial_results):
5151
def make_dummy_result_set_from_batch_list(batch_list):
5252
batch_index = 0
5353

54-
def fetch_results(op_handle, max_rows, max_bytes, row_offset, arrow_schema, description):
54+
def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, arrow_schema,
55+
description):
5556
nonlocal batch_index
5657
results = FetchTests.make_arrow_queue(batch_list[batch_index])
5758
batch_index += 1

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,11 +498,48 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
498498
thrift_backend = self._make_fake_thrift_backend()
499499

500500
thrift_backend._handle_execute_response(execute_resp, Mock())
501-
_, has_more_rows_resp = thrift_backend.fetch_results(Mock(), 1, 1, 0, Mock(),
502-
Mock())
501+
_, has_more_rows_resp = thrift_backend.fetch_results(
502+
op_handle=Mock(),
503+
max_rows=1,
504+
max_bytes=1,
505+
expected_row_start_offset=0,
506+
arrow_schema=Mock(),
507+
description=Mock())
503508

504509
self.assertEqual(has_more_rows, has_more_rows_resp)
505510

511+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
512+
def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
513+
# make some semi-real arrow batches and check the number of rows is correct in the queue
514+
tcli_service_instance = tcli_service_class.return_value
515+
t_fetch_results_resp = ttypes.TFetchResultsResp(
516+
status=self.okay_status,
517+
hasMoreRows=False,
518+
results=ttypes.TRowSet(
519+
startRowOffset=0,
520+
rows=[],
521+
arrowBatches=[
522+
ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) for _ in range(10)
523+
]))
524+
tcli_service_instance.FetchResults.return_value = t_fetch_results_resp
525+
schema = pyarrow.schema([
526+
pyarrow.field("column1", pyarrow.int32()),
527+
pyarrow.field("column2", pyarrow.string()),
528+
pyarrow.field("column3", pyarrow.float64()),
529+
pyarrow.field("column3", pyarrow.binary())
530+
])
531+
532+
thrift_backend = ThriftBackend("foobar", 443, "path", [])
533+
arrow_queue, has_more_results = thrift_backend.fetch_results(
534+
op_handle=Mock(),
535+
max_rows=1,
536+
max_bytes=1,
537+
expected_row_start_offset=0,
538+
arrow_schema=schema,
539+
description=MagicMock())
540+
541+
self.assertEqual(arrow_queue.n_valid_rows, 15 * 10)
542+
506543
@patch("databricks.sql.thrift_backend.TCLIService.Client")
507544
def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_service_class):
508545
tcli_service_instance = tcli_service_class.return_value

cmdexec/clients/python/tests/tests.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from test_fetches import FetchTests
1414
from test_thrift_backend import ThriftBackendTestSuite
15+
from test_arrow_queue import ArrowQueueSuite
1516

1617

1718
class ClientTestSuite(unittest.TestCase):
@@ -342,7 +343,7 @@ def test_version_is_canonical(self):
342343
if __name__ == '__main__':
343344
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
344345
loader = unittest.TestLoader()
345-
test_classes = [ClientTestSuite, FetchTests, ThriftBackendTestSuite]
346+
test_classes = [ClientTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite]
346347
suites_list = []
347348
for test_class in test_classes:
348349
suite = loader.loadTestsFromTestCase(test_class)

0 commit comments

Comments (0)