@@ -177,23 +177,20 @@ def file_paths(self) -> List[str]:
177177class TableRead (table_read .TableRead ):
178178
def __init__(self, j_table_read, j_read_type, catalog_options):
    """Set up the Arrow schema and the Java-side bytes reader for this read.

    Args:
        j_table_read: Java table-read handle, forwarded unchanged to
            ``InvocationUtil.createParallelBytesReader``.
        j_read_type: Java read type; converted to the Arrow schema and
            also forwarded to the reader factory.
        catalog_options: Catalog options from which the worker count is
            resolved via ``TableRead._get_max_workers``.

    Raises:
        ValueError: Propagated from ``_get_max_workers`` when the
            configured worker count is not greater than 0.
    """
    self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
    # The reader is created eagerly here, so an invalid worker count
    # fails at construction time.
    create_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader
    worker_count = TableRead._get_max_workers(catalog_options)
    self._j_bytes_reader = create_reader(j_table_read, j_read_type, worker_count)
185183
def to_arrow_batch_reader(self, splits) -> pa.RecordBatchReader:
    """Return a record-batch reader over the given splits.

    Args:
        splits: Iterable of split objects; each must provide
            ``to_j_split()``.

    Returns:
        A ``pa.RecordBatchReader`` that uses ``self._arrow_schema`` and
        pulls its batches from ``self._batch_generator()``.
    """
    java_splits = [split.to_j_split() for split in splits]
    # NOTE(review): the Java reader is shared by every call on this
    # instance; presumably setting new splits affects any reader returned
    # earlier that is still being consumed -- confirm.
    self._j_bytes_reader.setSplits(java_splits)
    return pa.RecordBatchReader.from_batches(
        self._arrow_schema, self._batch_generator())
196189
def to_arrow(self, splits) -> pa.Table:
    """Read the given splits into a single Arrow table.

    Args:
        splits: Splits to read; passed straight to
            ``to_arrow_batch_reader``.

    Returns:
        A ``pa.Table`` built from all batches of the reader, using
        ``self._arrow_schema`` as the table schema.
    """
    batches = self.to_arrow_batch_reader(splits)
    table = pa.Table.from_batches(batches, schema=self._arrow_schema)
    return table
193+
def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
    """Read the given splits and return them as a pandas DataFrame.

    The data is first read into an Arrow table with ``to_arrow`` and then
    converted with the table's ``to_pandas()``.
    """
    arrow_table = self.to_arrow(splits)
    return arrow_table.to_pandas()
199196
@@ -213,19 +210,13 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
213210
214211 return ray .data .from_arrow (self .to_arrow (splits ))
215212
@staticmethod
def _get_max_workers(catalog_options):
    """Return the validated worker count configured in ``catalog_options``.

    Args:
        catalog_options: Mapping of catalog options. The value stored
            under ``constants.MAX_WORKERS`` may be absent, ``None``, or
            any value accepted by ``int()``.

    Returns:
        The worker count as an ``int``; 1 when the option is absent or
        ``None``.

    Raises:
        ValueError: If the configured value is not greater than 0, or is
            a string that ``int()`` cannot parse.
    """
    configured = catalog_options.get(constants.MAX_WORKERS)
    if configured is None:
        # default is sequential; an explicit None is treated like a
        # missing option instead of failing inside int().
        return 1
    max_workers = int(configured)
    if max_workers <= 0:
        raise ValueError("max_workers must be greater than 0")
    return max_workers
229220
230221 def _batch_generator (self ) -> Iterator [pa .RecordBatch ]:
231222 while True :
0 commit comments