@@ -23,6 +23,7 @@
 
 from pypaimon.py4j.java_gateway import get_gateway
 from pypaimon.py4j.util import java_utils, constants
+from pypaimon.py4j.util.java_utils import serialize_java_object, deserialize_java_object
 from pypaimon.api import \
     (catalog, table, read_builder, table_scan, split, row_type,
      table_read, write_builder, table_write, commit_message,
@@ -109,8 +110,9 @@ def new_scan(self) -> 'TableScan':
         return TableScan(j_table_scan)
 
     def new_read(self) -> 'TableRead':
-        j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)
+        j_table_read_bytes = serialize_java_object(self._j_read_builder.newRead().executeFilter())
+        j_read_type_bytes = serialize_java_object(self._j_read_builder.readType())
+        return TableRead(j_table_read_bytes, j_read_type_bytes, self._catalog_options)
 
     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -145,55 +147,66 @@ def __init__(self, j_splits):
         self._j_splits = j_splits
 
     def splits(self) -> List['Split']:
-        return list(map(lambda s: Split(s), self._j_splits))
+        return list(map(lambda s: self._build_single_split(s), self._j_splits))
+
+    def _build_single_split(self, j_split) -> 'Split':
+        j_split_bytes = serialize_java_object(j_split)
+        row_count = j_split.rowCount()
+        files_optional = j_split.convertToRawFiles()
+        if not files_optional.isPresent():
+            file_size = 0
+            file_paths = []
+        else:
+            files = files_optional.get()
+            file_size = sum(file.length() for file in files)
+            file_paths = [file.path() for file in files]
+        return Split(j_split_bytes, row_count, file_size, file_paths)
 
 
 class Split(split.Split):
 
-    def __init__(self, j_split):
-        self._j_split = j_split
+    def __init__(self, j_split_bytes, row_count: int, file_size: int, file_paths: List[str]):
+        self._j_split_bytes = j_split_bytes
+        self._row_count = row_count
+        self._file_size = file_size
+        self._file_paths = file_paths
 
     def to_j_split(self):
-        return self._j_split
+        return deserialize_java_object(self._j_split_bytes)
 
     def row_count(self) -> int:
-        return self._j_split.rowCount()
+        return self._row_count
 
     def file_size(self) -> int:
-        files_optional = self._j_split.convertToRawFiles()
-        if not files_optional.isPresent():
-            return 0
-        files = files_optional.get()
-        return sum(file.length() for file in files)
+        return self._file_size
 
     def file_paths(self) -> List[str]:
-        files_optional = self._j_split.convertToRawFiles()
-        if not files_optional.isPresent():
-            return []
-        files = files_optional.get()
-        return [file.path() for file in files]
+        return self._file_paths
 
 
 class TableRead(table_read.TableRead):
 
-    def __init__(self, j_table_read, j_read_type, catalog_options):
-        self._j_table_read = j_table_read
-        self._j_read_type = j_read_type
+    def __init__(self, j_table_read_bytes, j_read_type_bytes, catalog_options):
+        self._j_table_read_bytes = j_table_read_bytes
+        self._j_read_type_bytes = j_read_type_bytes
         self._catalog_options = catalog_options
-        self._j_bytes_reader = None
-        self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
 
-    def to_arrow(self, splits):
-        record_batch_reader = self.to_arrow_batch_reader(splits)
-        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+        self._j_table_read = None
+        self._j_read_type = None
+        self._arrow_schema = None
+        self._j_bytes_reader = None
 
-    def to_arrow_batch_reader(self, splits):
+    def to_arrow_batch_reader(self, splits) -> pa.RecordBatchReader:
         self._init()
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
         return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
 
+    def to_arrow(self, splits) -> pa.Table:
+        record_batch_reader = self.to_arrow_batch_reader(splits)
+        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+
     def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
         return self.to_arrow(splits).to_pandas()
 
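After this hunk a Split carries only a byte payload plus plain Python values, so its metadata accessors never touch the JVM and the object survives pickling; only to_j_split() needs a live gateway. A small illustration (the byte payload below is a stand-in, not a real serialized Java split):

    import pickle

    # Stand-in payload; a real Split would hold serialized Java bytes.
    s = Split(b'\x00', row_count=100, file_size=2048,
              file_paths=['bucket-0/data-0.orc'])
    clone = pickle.loads(pickle.dumps(s))
    assert clone.row_count() == 100
    assert clone.file_paths() == ['bucket-0/data-0.orc']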
@@ -214,6 +227,12 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
         return ray.data.from_arrow(self.to_arrow(splits))
 
     def _init(self):
+        if self._j_table_read is None:
+            self._j_table_read = deserialize_java_object(self._j_table_read_bytes)
+        if self._j_read_type is None:
+            self._j_read_type = deserialize_java_object(self._j_read_type_bytes)
+        if self._arrow_schema is None:
+            self._arrow_schema = java_utils.to_arrow_schema(self._j_read_type)
         if self._j_bytes_reader is None:
             # get thread num
             max_workers = self._catalog_options.get(constants.MAX_WORKERS)
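Taken together, TableRead now defers all gateway work to _init(): the Java reader and row type are rebuilt from bytes only on the first actual read call. A hedged end-to-end sketch, assuming an already-created pypaimon Table object (catalog setup omitted):

    # Usage sketch; `table` is an existing pypaimon Table (assumed).
    read_builder = table.new_read_builder()
    table_read = read_builder.new_read()              # holds only bytes so far
    splits = read_builder.new_scan().plan().splits()  # picklable splits

    # Splits and the TableRead can be shipped to worker processes (e.g.
    # via Ray) before any JVM proxies exist there; deserialization happens
    # lazily inside _init() on the first read.
    df = table_read.to_pandas(splits)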