Commit f599ebc

committed
WORKING fetchall_arrow
1 parent c7492cc commit f599ebc

File tree

2 files changed, +156 -56 lines changed


src/databricks/sql/client.py

Lines changed: 14 additions & 13 deletions
@@ -1456,7 +1456,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
-        partial_result_chunks = [results]
+        # partial_result_chunks = [results]

         TOTAL_SIZE = results.num_rows
         while (
@@ -1467,12 +1467,13 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
             print(f"TOTAL DATA ROWS {TOTAL_SIZE}")
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            partial_result_chunks.append(partial_results)
+            results.append(partial_results)
+            # partial_result_chunks.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
             TOTAL_SIZE += partial_results.num_rows

-        return concat_chunked_tables(partial_result_chunks)
+        return results.to_arrow_table()



@@ -1506,29 +1507,29 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows

-        partial_result_chunks = [results]
+        # partial_result_chunks = [results]
         print("Server side has more rows", self.has_more_rows)
         TOTAL_SIZE = results.num_rows

         while not self.has_been_closed_server_side and self.has_more_rows:
             print(f"TOTAL DATA ROWS {TOTAL_SIZE}")
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            partial_result_chunks.append(partial_results)
+            results.append(partial_results)
             self._next_row_index += partial_results.num_rows
             TOTAL_SIZE += partial_results.num_rows

-        results = concat_chunked_tables(partial_result_chunks)
+        # results = concat_chunked_tables(partial_result_chunks)

         # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
         # Valid only for metadata commands result set
-        if isinstance(results, ColumnTable) and pyarrow:
-            data = {
-                name: col
-                for name, col in zip(results.column_names, results.column_table)
-            }
-            return pyarrow.Table.from_pydict(data)
-        return results
+        # if isinstance(results, ColumnTable) and pyarrow:
+        #     data = {
+        #         name: col
+        #         for name, col in zip(results.column_names, results.column_table)
+        #     }
+        #     return pyarrow.Table.from_pydict(data)
+        return results.to_arrow_table()

     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""

src/databricks/sql/utils.py

Lines changed: 142 additions & 43 deletions
@@ -123,6 +123,23 @@ def num_columns(self):

     def get_item(self, col_index, row_index):
         return self.column_table[col_index][row_index]
+
+    def append(self, other: ColumnTable):
+        if self.column_names != other.column_names:
+            raise ValueError("The columns in the results don't match")
+
+        merged_result = [
+            self.column_table[i] + other.column_table[i]
+            for i in range(self.num_columns)
+        ]
+        return ColumnTable(merged_result, self.column_names)
+
+    def to_arrow_table(self):
+        data = {
+            name: col
+            for name, col in zip(self.column_names, self.column_table)
+        }
+        return pyarrow.Table.from_pydict(data)

     def slice(self, curr_index, length):
         sliced_column_table = [
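A short usage sketch of the two ColumnTable helpers added above (the trimmed ColumnTable here is a stand-in limited to the fields these methods touch, and the sample data is invented): append merges matching columns element-wise and returns a new ColumnTable, while to_arrow_table builds a pyarrow.Table from the name/column pairs.

import pyarrow

class ColumnTable:  # trimmed stand-in; the real class lives in utils.py
    def __init__(self, column_table, column_names):
        self.column_table = column_table
        self.column_names = column_names

    @property
    def num_columns(self):
        return len(self.column_names)

    def append(self, other):
        if self.column_names != other.column_names:
            raise ValueError("The columns in the results don't match")
        merged = [
            self.column_table[i] + other.column_table[i]
            for i in range(self.num_columns)
        ]
        return ColumnTable(merged, self.column_names)  # returns a new table, not in-place

    def to_arrow_table(self):
        return pyarrow.Table.from_pydict(dict(zip(self.column_names, self.column_table)))

a = ColumnTable([[1, 2], ["x", "y"]], ["id", "name"])
b = ColumnTable([[3], ["z"]], ["id", "name"])
print(a.append(b).to_arrow_table().num_rows)  # -> 3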
@@ -138,10 +155,72 @@ def __eq__(self, other):


 class ArrowStreamTable:
-    def __init__(self, arrow_stream, num_rows):
-        self.arrow_stream = arrow_stream
+    def __init__(self, record_batches, num_rows, column_description):
+        self.record_batches = record_batches
         self.num_rows = num_rows
-
+        self.column_description = column_description
+        self.curr_batch_index = 0
+
+    def append(self, other: ArrowStreamTable):
+        if self.column_description != other.column_description:
+            raise ValueError("ArrowStreamTable: Column descriptions do not match for the tables to be appended")
+
+        self.record_batches.extend(other.record_batches)
+        self.num_rows += other.num_rows
+
+    def next_n_rows(self, req_num_rows: int):
+        consumed_batches = []
+        consumed_num_rows = 0
+        while req_num_rows > 0 and self.record_batches:
+            current = self.record_batches[0]
+            if current.num_rows <= req_num_rows:
+                consumed_batches.append(current)
+                req_num_rows -= current.num_rows
+                consumed_num_rows += current.num_rows
+                self.num_rows -= current.num_rows
+                self.record_batches.pop(0)
+            else:
+                consumed_batches.append(current.slice(0, req_num_rows))
+                self.record_batches[0] = current.slice(req_num_rows)
+                self.num_rows -= req_num_rows
+                consumed_num_rows += req_num_rows
+                req_num_rows = 0
+
+        return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description)
+
+
+    def convert_decimals_in_record_batch(self, batch: "pyarrow.RecordBatch") -> "pyarrow.RecordBatch":
+        new_columns = []
+        new_fields = []
+
+        for i, col in enumerate(batch.columns):
+            field = batch.schema.field(i)
+
+            if self.column_description[i][1] == "decimal":
+                precision, scale = self.column_description[i][4], self.column_description[i][5]
+                assert scale is not None and precision is not None
+                dtype = pyarrow.decimal128(precision, scale)
+
+                new_col = col.cast(dtype)
+                new_field = field.with_type(dtype)
+
+                new_columns.append(new_col)
+                new_fields.append(new_field)
+            else:
+                new_columns.append(col)
+                new_fields.append(field)
+
+        new_schema = pyarrow.schema(new_fields)
+        return pyarrow.RecordBatch.from_arrays(new_columns, schema=new_schema)
+
+    def to_arrow_table(self) -> "pyarrow.Table":
+        def batch_generator():
+            for batch in self.record_batches:
+                yield self.convert_decimals_in_record_batch(batch)
+
+        return pyarrow.Table.from_batches(batch_generator())
+
+
 class ColumnQueue(ResultSetQueue):
     def __init__(self, column_table: ColumnTable):
         self.column_table = column_table
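A rough, self-contained illustration of the record-batch mechanics ArrowStreamTable builds on (the in-memory stream below is constructed purely for the example): pyarrow.ipc.open_stream yields RecordBatches that can be consumed whole from the front, sliced at the boundary, and stitched back into a Table.

import pyarrow
import pyarrow.ipc

# Build a small IPC stream in memory, then read it back as record batches,
# roughly what _create_next_table does with downloaded_file.file_bytes.
table = pyarrow.table({"id": list(range(10))})
sink = pyarrow.BufferOutputStream()
with pyarrow.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table, max_chunksize=4)   # forces several record batches
stream_bytes = sink.getvalue()

batches = list(pyarrow.ipc.open_stream(stream_bytes))
print(len(batches))                              # -> 3 batches of 4, 4, 2 rows

# Consuming "the next n rows" means taking whole batches from the front and
# slicing the first batch that straddles the boundary, as next_n_rows does.
first_five = [batches[0], batches[1].slice(0, 1)]
batches[1] = batches[1].slice(1)                 # the remainder stays queued
print(pyarrow.Table.from_batches(first_five).num_rows)   # -> 5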
@@ -250,9 +329,9 @@ def __init__(
         )

         self.table = self._create_next_table()
-        self.table_row_index = 0
+        # self.table_row_index = 0

-    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
+    def next_n_rows(self, num_rows: int):
         """
         Get up to the next n rows of the cloud fetch Arrow dataframes.

@@ -262,55 +341,62 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         Returns:
             pyarrow.Table
         """
+        results = self._create_empty_table()
         if not self.table:
             logger.debug("CloudFetchQueue: no more rows available")
             # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
+            return results
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
-        results = self.table.slice(0, 0)
-        partial_result_chunks = [results]
+
+        # results = self.table.slice(0, 0)
+        # partial_result_chunks = [results]
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
-            length = min(num_rows, self.table.num_rows - self.table_row_index)
-            table_slice = self.table.slice(self.table_row_index, length)
-            partial_result_chunks.append(table_slice)
-            self.table_row_index += table_slice.num_rows
+            length = min(num_rows, self.table.num_rows)
+            nxt_result = self.table.next_n_rows(length)
+            results.append(nxt_result)
+            num_rows -= nxt_result.num_rows
+            # table_slice = self.table.slice(self.table_row_index, length)
+            # partial_result_chunks.append(table_slice)
+            # self.table_row_index += table_slice.num_rows

             # Replace current table with the next table if we are at the end of the current table
-            if self.table_row_index == self.table.num_rows:
+            if self.table.num_rows == 0:
                 self.table = self._create_next_table()
-                self.table_row_index = 0
-            num_rows -= table_slice.num_rows
+                # self.table_row_index = 0
+            # num_rows -= table_slice.num_rows

         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
-        return concat_chunked_tables(partial_result_chunks)
+        return results

-    def remaining_rows(self) -> "pyarrow.Table":
+    def remaining_rows(self):
         """
         Get all remaining rows of the cloud fetch Arrow dataframes.

         Returns:
             pyarrow.Table
         """
+        result = self._create_empty_table()
         if not self.table:
             # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
-        results = self.table.slice(0, 0)
-        partial_result_chunks = [results]
+            return result
+        # results = self.table.slice(0, 0)
+        # result = self._create_empty_table()
+
         print("remaining_rows call")
         print(f"self.table.num_rows - {self.table.num_rows}")
         while self.table:
-            table_slice = self.table.slice(
-                self.table_row_index, self.table.num_rows - self.table_row_index
-            )
-            partial_result_chunks.append(table_slice)
-            self.table_row_index += table_slice.num_rows
+            # table_slice = self.table.slice(
+            #     self.table_row_index, self.table.num_rows - self.table_row_index
+            # )
+            result.append(self.table)
+            # self.table_row_index += table_slice.num_rows
             self.table = self._create_next_table()
-            self.table_row_index = 0
-            print(f"results.num_rows - {results.num_rows}")
-        return concat_chunked_tables(partial_result_chunks)
+            # self.table_row_index = 0
+            print(f"result.num_rows - {result.num_rows}")
+        return result

-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> ArrowStreamTable:
         logger.debug(
             "CloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
@@ -328,32 +414,41 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
             )
             # None signals no more Arrow tables can be built from the remaining handlers if any remain
             return None
-        arrow_table = create_arrow_table_from_arrow_file(
-            downloaded_file.file_bytes, self.description
-        )
+
+        arrow_stream_table = ArrowStreamTable(
+            list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)),
+            downloaded_file.row_count,
+            self.description)
+        # arrow_table = create_arrow_table_from_arrow_file(
+        #     downloaded_file.file_bytes, self.description
+        # )

         # The server rarely prepares the exact number of rows requested by the client in cloud fetch.
         # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested
-        if arrow_table.num_rows > downloaded_file.row_count:
-            arrow_table = arrow_table.slice(0, downloaded_file.row_count)
+        # if arrow_table.num_rows > downloaded_file.row_count:
+        #     arrow_table = arrow_table.slice(0, downloaded_file.row_count)

         # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows
-        assert downloaded_file.row_count == arrow_table.num_rows
-        self.start_row_index += arrow_table.num_rows
+        # assert downloaded_file.row_count == arrow_table.num_rows
+        # self.start_row_index += arrow_table.num_rows
+        self.start_row_index += arrow_stream_table.num_rows

         logger.debug(
             "CloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
-                arrow_table.num_rows, self.start_row_index
+                arrow_stream_table.num_rows, self.start_row_index
             )
         )

         print("_create_next_table")
-        print(f"arrow_table.num_rows - {arrow_table.num_rows}")
-        return arrow_table
+        print(f"arrow_stream_table.num_rows - {arrow_stream_table.num_rows}")
+        return arrow_stream_table

-    def _create_empty_table(self) -> "pyarrow.Table":
+    def _create_empty_table(self) -> ArrowStreamTable:
         # Create a 0-row table with just the schema bytes
-        return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
+        return ArrowStreamTable(
+            list(pyarrow.ipc.open_stream(self.schema_bytes)),
+            0,
+            self.description)


 ExecuteResponse = namedtuple(
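For reference, a standalone sketch of the per-batch decimal cast that convert_decimals_in_record_batch performs on these record batches (the column names, values, and the precision/scale pair are invented for the example; the real code keys off description[i][1] == "decimal"):

import pyarrow

# A batch whose "price" column arrives as strings, as decimal columns do in the
# Arrow stream, plus a sibling column that should pass through untouched.
batch = pyarrow.RecordBatch.from_pydict({"id": [1, 2], "price": ["1.50", "2.25"]})
dtype = pyarrow.decimal128(10, 2)            # precision/scale from the column description

new_columns, new_fields = [], []
for i, col in enumerate(batch.columns):
    field = batch.schema.field(i)
    if field.name == "price":                # hypothetical check standing in for the description lookup
        new_columns.append(col.cast(dtype))
        new_fields.append(field.with_type(dtype))
    else:
        new_columns.append(col)
        new_fields.append(field)

fixed = pyarrow.RecordBatch.from_arrays(new_columns, schema=pyarrow.schema(new_fields))
print(fixed.schema)                          # price: decimal128(10, 2)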
@@ -612,7 +707,6 @@ def create_arrow_table_from_arrow_file(
     arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
     return convert_decimals_in_arrow_table(arrow_table, description)

-
 def convert_arrow_based_file_to_arrow_table(file_bytes: bytes):
     try:
         return pyarrow.ipc.open_stream(file_bytes).read_all()
@@ -779,12 +873,17 @@ def _create_python_tuple(t_col_value_wrapper):
     return tuple(result)


-def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable]]) -> Union["pyarrow.Table", ColumnTable]:
+def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable, ArrowStreamTable]]) -> Union["pyarrow.Table", ColumnTable, ArrowStreamTable]:
     if isinstance(tables[0], ColumnTable):
         base_table = tables[0]
         for table in tables[1:]:
             base_table = merge_columnar(base_table, table)
         return base_table
+    elif isinstance(tables[0], ArrowStreamTable):
+        base_table = tables[0]
+        for table in tables[1:]:
+            base_table = base_table.append(table)
+        return base_table
     else:
         return pyarrow.concat_tables(tables)
