 # limitations under the License.
 ################################################################################
 
+import pandas as pd
 import pyarrow as pa
 
 from paimon_python_java.java_gateway import get_gateway
@@ -59,23 +60,30 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
+        # init arrow schema
+        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
+        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+        self._arrow_schema = schema_reader.schema
+        schema_reader.close()
 
     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
+        return ReadBuilder(
+            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)
 
     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())
+        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)
 
 
 class ReadBuilder(read_builder.ReadBuilder):
 
-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
+        self._arrow_schema = arrow_schema
 
     def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
         self._j_read_builder.withProjection(projection)
@@ -91,7 +99,7 @@ def new_scan(self) -> 'TableScan':
 
     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options)
+        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)
 
 
 class TableScan(table_scan.TableScan):
@@ -125,20 +133,27 @@ def to_j_split(self):
 
 class TableRead(table_read.TableRead):
 
-    def __init__(self, j_table_read, j_row_type, catalog_options):
+    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
         self._j_table_read = j_table_read
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = None
+        self._arrow_schema = arrow_schema
 
-    def create_reader(self, splits):
+    def to_arrow(self, splits):
+        record_batch_reader = self.to_arrow_batch_reader(splits)
+        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+
+    def to_arrow_batch_reader(self, splits):
         self._init()
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
         return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
 
+    def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
+        return self.to_arrow(splits).to_pandas()
+
     def _init(self):
         if self._j_bytes_reader is None:
             # get thread num
@@ -153,12 +168,6 @@ def _init(self):
             self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
                 self._j_table_read, self._j_row_type, max_workers)
 
-            if self._arrow_schema is None:
-                schema_bytes = self._j_bytes_reader.serializeSchema()
-                schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-                self._arrow_schema = schema_reader.schema
-                schema_reader.close()
-
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
             next_bytes = self._j_bytes_reader.next()
@@ -171,17 +180,18 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
-    def __init__(self, j_batch_write_builder, j_row_type):
+    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
         self._j_batch_write_builder = j_batch_write_builder
         self._j_row_type = j_row_type
+        self._arrow_schema = arrow_schema
 
     def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
         self._j_batch_write_builder.withOverwrite(static_partition)
         return self
 
     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type)
+        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)
 
     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -190,19 +200,32 @@ def new_commit(self) -> 'BatchTableCommit':
 
 class BatchTableWrite(table_write.BatchTableWrite):
 
-    def __init__(self, j_batch_table_write, j_row_type):
+    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-
-    def write(self, record_batch: pa.RecordBatch):
+        self._arrow_schema = arrow_schema
+
+    def write_arrow(self, table):
+        for record_batch in table.to_reader():
+            # TODO: can we use a reusable stream?
+            stream = pa.BufferOutputStream()
+            with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
+                writer.write(record_batch)
+            arrow_bytes = stream.getvalue().to_pybytes()
+            self._j_bytes_writer.write(arrow_bytes)
+
+    def write_arrow_batch(self, record_batch):
         stream = pa.BufferOutputStream()
-        with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
+        with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
             writer.write(record_batch)
-            writer.close()
         arrow_bytes = stream.getvalue().to_pybytes()
         self._j_bytes_writer.write(arrow_bytes)
 
+    def write_pandas(self, dataframe: pd.DataFrame):
+        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
+        self.write_arrow_batch(record_batch)
+
     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()
         return list(map(lambda cm: CommitMessage(cm), j_commit_messages))
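
For reference, a minimal sketch of the read path this change introduces. The `table` object and the scan/plan calls are assumed from the existing paimon-python API and are not part of this diff:

    # Assumes `table` is an already-obtained paimon_python_java Table (hypothetical setup).
    read_builder = table.new_read_builder()
    table_scan = read_builder.new_scan()
    table_read = read_builder.new_read()
    splits = table_scan.plan().splits()

    pa_table = table_read.to_arrow(splits)                    # pyarrow.Table, schema fixed at Table init
    df = table_read.to_pandas(splits)                         # pandas.DataFrame via Arrow
    batch_reader = table_read.to_arrow_batch_reader(splits)   # streaming pa.RecordBatchReader

Since the Arrow schema is now read once from the JVM in `Table.__init__`, all three entry points share the same schema instead of deriving it lazily per reader.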
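And the corresponding write path. Serialization now always uses the table's own Arrow schema, so input batches and DataFrames must match it; `df` and the final `commit` call are assumptions based on the existing batch-write API, not shown in this diff:

    # Assumes `df` is a pandas DataFrame whose columns match the table schema (hypothetical data).
    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()

    table_write.write_pandas(df)    # converted to a RecordBatch with the table's Arrow schema
    # alternatively: table_write.write_arrow(pa_table) or table_write.write_arrow_batch(batch)

    commit_messages = table_write.prepare_commit()
    table_commit.commit(commit_messages)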