@@ -123,6 +123,23 @@ def num_columns(self):
 
     def get_item(self, col_index, row_index):
         return self.column_table[col_index][row_index]
+
+    def append(self, other: "ColumnTable"):
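+        # Row-wise merge: requires identical column names and returns a new
+        # ColumnTable, leaving both operands unchanged.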
+        if self.column_names != other.column_names:
+            raise ValueError("The columns in the results don't match")
+
+        merged_result = [
+            self.column_table[i] + other.column_table[i]
+            for i in range(self.num_columns)
+        ]
+        return ColumnTable(merged_result, self.column_names)
+
+    def to_arrow_table(self):
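+        # Convert the columnar data to a pyarrow.Table; pyarrow infers the
+        # Arrow types from the Python values in each column.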
+        data = {
+            name: col
+            for name, col in zip(self.column_names, self.column_table)
+        }
+        return pyarrow.Table.from_pydict(data)
 
     def slice(self, curr_index, length):
         sliced_column_table = [
@@ -138,10 +155,72 @@ def __eq__(self, other):
 
 
 class ArrowStreamTable:
-    def __init__(self, arrow_stream, num_rows):
-        self.arrow_stream = arrow_stream
+    def __init__(self, record_batches, num_rows, column_description):
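+        # record_batches is a FIFO list of pyarrow.RecordBatch objects that
+        # is consumed from the front as rows are read off this table.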
+        self.record_batches = record_batches
         self.num_rows = num_rows
-
+        self.column_description = column_description
+        self.curr_batch_index = 0
+
+    def append(self, other: "ArrowStreamTable"):
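+        # In-place merge: extends this table's batch list with the other
+        # table's batches (contrast with ColumnTable.append, which copies).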
+        if self.column_description != other.column_description:
+            raise ValueError("ArrowStreamTable: Column descriptions do not match for the tables to be appended")
+
+        self.record_batches.extend(other.record_batches)
+        self.num_rows += other.num_rows
+        return self
+
+    def next_n_rows(self, req_num_rows: int) -> "ArrowStreamTable":
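+        # Destructively take up to req_num_rows from the front of the batch
+        # list, slicing a batch in two when it straddles the boundary, and
+        # hand the consumed rows back as a new ArrowStreamTable.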
+        consumed_batches = []
+        consumed_num_rows = 0
+        while req_num_rows > 0 and self.record_batches:
+            current = self.record_batches[0]
+            if current.num_rows <= req_num_rows:
+                consumed_batches.append(current)
+                req_num_rows -= current.num_rows
+                consumed_num_rows += current.num_rows
+                self.num_rows -= current.num_rows
+                self.record_batches.pop(0)
+            else:
+                consumed_batches.append(current.slice(0, req_num_rows))
+                self.record_batches[0] = current.slice(req_num_rows)
+                self.num_rows -= req_num_rows
+                consumed_num_rows += req_num_rows
+                req_num_rows = 0
+
+        return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description)
+
+    def convert_decimals_in_record_batch(self, batch: "pyarrow.RecordBatch") -> "pyarrow.RecordBatch":
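+        # Recast decimal columns to decimal128 using the precision (index 4)
+        # and scale (index 5) carried in the cursor description.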
+        new_columns = []
+        new_fields = []
+
+        for i, col in enumerate(batch.columns):
+            field = batch.schema.field(i)
+
+            if self.column_description[i][1] == "decimal":
+                precision, scale = self.column_description[i][4], self.column_description[i][5]
+                assert scale is not None and precision is not None
+                dtype = pyarrow.decimal128(precision, scale)
+
+                new_col = col.cast(dtype)
+                new_field = field.with_type(dtype)
+
+                new_columns.append(new_col)
+                new_fields.append(new_field)
+            else:
+                new_columns.append(col)
+                new_fields.append(field)
+
+        new_schema = pyarrow.schema(new_fields)
+        return pyarrow.RecordBatch.from_arrays(new_columns, schema=new_schema)
+
+    def to_arrow_table(self) -> "pyarrow.Table":
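+        # Materialize every batch into a single pyarrow.Table, applying the
+        # decimal conversion lazily. Note that Table.from_batches needs a
+        # schema or at least one batch, so this assumes a non-empty stream.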
+        def batch_generator():
+            for batch in self.record_batches:
+                yield self.convert_decimals_in_record_batch(batch)
+
+        return pyarrow.Table.from_batches(batch_generator())
+
+
 class ColumnQueue(ResultSetQueue):
     def __init__(self, column_table: ColumnTable):
         self.column_table = column_table
@@ -250,9 +329,9 @@ def __init__(
         )
 
         self.table = self._create_next_table()
-        self.table_row_index = 0
 
-    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
+    def next_n_rows(self, num_rows: int) -> ArrowStreamTable:
         """
         Get up to the next n rows of the cloud fetch Arrow dataframes.
 
@@ -262,55 +341,62 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         Returns:
-            pyarrow.Table
+            ArrowStreamTable
         """
+        results = self._create_empty_table()
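+        # Rows consumed below accumulate into this empty ArrowStreamTable.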
         if not self.table:
             logger.debug("CloudFetchQueue: no more rows available")
             # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
+            return results
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
-        results = self.table.slice(0, 0)
-        partial_result_chunks = [results]
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
-            length = min(num_rows, self.table.num_rows - self.table_row_index)
-            table_slice = self.table.slice(self.table_row_index, length)
-            partial_result_chunks.append(table_slice)
-            self.table_row_index += table_slice.num_rows
+            length = min(num_rows, self.table.num_rows)
+            nxt_result = self.table.next_n_rows(length)
+            results.append(nxt_result)
+            num_rows -= nxt_result.num_rows
 
             # Replace current table with the next table if we are at the end of the current table
-            if self.table_row_index == self.table.num_rows:
+            if self.table.num_rows == 0:
                 self.table = self._create_next_table()
-                self.table_row_index = 0
-            num_rows -= table_slice.num_rows
 
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
-        return concat_chunked_tables(partial_result_chunks)
+        return results
 
-    def remaining_rows(self) -> "pyarrow.Table":
+    def remaining_rows(self) -> ArrowStreamTable:
         """
         Get all remaining rows of the cloud fetch Arrow dataframes.
 
         Returns:
-            pyarrow.Table
+            ArrowStreamTable
         """
+        result = self._create_empty_table()
         if not self.table:
             # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
-        results = self.table.slice(0, 0)
-        partial_result_chunks = [results]
+            return result
         print("remaining_rows call")
         print(f"self.table.num_rows - {self.table.num_rows}")
         while self.table:
-            table_slice = self.table.slice(
-                self.table_row_index, self.table.num_rows - self.table_row_index
-            )
-            partial_result_chunks.append(table_slice)
-            self.table_row_index += table_slice.num_rows
+            result.append(self.table)
             self.table = self._create_next_table()
-            self.table_row_index = 0
-        print(f"results.num_rows - {results.num_rows}")
-        return concat_chunked_tables(partial_result_chunks)
+        print(f"result.num_rows - {result.num_rows}")
+        return result
 
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> Union[ArrowStreamTable, None]:
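+        # Wrap the next downloaded file's record batches in an ArrowStreamTable
+        # rather than materializing the whole file as a pyarrow.Table up front.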
         logger.debug(
             "CloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
@@ -328,32 +414,41 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
         )
         # None signals no more Arrow tables can be built from the remaining handlers if any remain
         return None
-        arrow_table = create_arrow_table_from_arrow_file(
-            downloaded_file.file_bytes, self.description
-        )
+
+        arrow_stream_table = ArrowStreamTable(
+            list(pyarrow.ipc.open_stream(downloaded_file.file_bytes)),
+            downloaded_file.row_count,
+            self.description,
+        )
 
-        # The server rarely prepares the exact number of rows requested by the client in cloud fetch.
-        # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested
-        if arrow_table.num_rows > downloaded_file.row_count:
-            arrow_table = arrow_table.slice(0, downloaded_file.row_count)
 
         # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows
-        assert downloaded_file.row_count == arrow_table.num_rows
-        self.start_row_index += arrow_table.num_rows
+        self.start_row_index += arrow_stream_table.num_rows
 
         logger.debug(
             "CloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
-                arrow_table.num_rows, self.start_row_index
+                arrow_stream_table.num_rows, self.start_row_index
             )
         )
 
         print("_create_next_table")
-        print(f"arrow_table.num_rows - {arrow_table.num_rows}")
-        return arrow_table
+        print(f"arrow_stream_table.num_rows - {arrow_stream_table.num_rows}")
+        return arrow_stream_table
 
-    def _create_empty_table(self) -> "pyarrow.Table":
+    def _create_empty_table(self) -> ArrowStreamTable:
         # Create a 0-row table with just the schema bytes
-        return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
+        return ArrowStreamTable(
+            list(pyarrow.ipc.open_stream(self.schema_bytes)),
+            0,
+            self.description,
+        )
 
 
 ExecuteResponse = namedtuple(
@@ -612,7 +707,6 @@ def create_arrow_table_from_arrow_file(
     arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
     return convert_decimals_in_arrow_table(arrow_table, description)
 
-
 def convert_arrow_based_file_to_arrow_table(file_bytes: bytes):
     try:
         return pyarrow.ipc.open_stream(file_bytes).read_all()
@@ -779,12 +873,17 @@ def _create_python_tuple(t_col_value_wrapper):
     return tuple(result)
 
 
-def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable]]) -> Union["pyarrow.Table", ColumnTable]:
+def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable, ArrowStreamTable]]) -> Union["pyarrow.Table", ColumnTable, ArrowStreamTable]:
     if isinstance(tables[0], ColumnTable):
         base_table = tables[0]
         for table in tables[1:]:
             base_table = merge_columnar(base_table, table)
         return base_table
+    elif isinstance(tables[0], ArrowStreamTable):
+        base_table = tables[0]
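+        # ArrowStreamTable.append merges in place and returns self, so the
+        # reassignment below stays valid across iterations.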
+        for table in tables[1:]:
+            base_table = base_table.append(table)
+        return base_table
     else:
         return pyarrow.concat_tables(tables)
 