
Commit 953f30f

Refactor reader and writer (#16)
1 parent 6581f65 commit 953f30f


8 files changed: +114 -116 lines


paimon_python_api/table_read.py

Lines changed: 11 additions & 2 deletions
@@ -16,6 +16,7 @@
 # limitations under the License.
 #################################################################################

+import pandas as pd
 import pyarrow as pa

 from abc import ABC, abstractmethod
@@ -27,5 +28,13 @@ class TableRead(ABC):
     """To read data from data splits."""

     @abstractmethod
-    def create_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
-        """Return a reader containing batches of pyarrow format."""
+    def to_arrow(self, splits: List[Split]) -> pa.Table:
+        """Read data from splits and convert to pyarrow.Table format."""
+
+    @abstractmethod
+    def to_arrow_batch_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
+        """Read data from splits and convert to pyarrow.RecordBatchReader format."""
+
+    @abstractmethod
+    def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
+        """Read data from splits and convert to pandas.DataFrame format."""

paimon_python_api/table_write.py

Lines changed: 11 additions & 2 deletions
@@ -16,6 +16,7 @@
 # limitations under the License.
 #################################################################################

+import pandas as pd
 import pyarrow as pa

 from abc import ABC, abstractmethod
@@ -27,8 +28,16 @@ class BatchTableWrite(ABC):
     """A table write for batch processing. Recommended for one-time committing."""

     @abstractmethod
-    def write(self, record_batch: pa.RecordBatch):
-        """ Write a batch to the writer. */"""
+    def write_arrow(self, table: pa.Table):
+        """Write an arrow table to the writer."""
+
+    @abstractmethod
+    def write_arrow_batch(self, record_batch: pa.RecordBatch):
+        """Write an arrow record batch to the writer."""
+
+    @abstractmethod
+    def write_pandas(self, dataframe: pd.DataFrame):
+        """Write a pandas dataframe to the writer."""

     @abstractmethod
     def prepare_commit(self) -> List[CommitMessage]:
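Mirroring the reader, the writer now accepts three input formats. A hedged sketch of a one-time batch write; the new_commit()/commit() calls come from the builder API elsewhere in this repo, and the DataFrame contents are made up:

    import pandas as pd

    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()

    table_write.write_pandas(pd.DataFrame({'f0': [1, 2], 'f1': ['a', 'b']}))
    # or: table_write.write_arrow(pa_table) / table_write.write_arrow_batch(batch)

    table_commit = write_builder.new_commit()
    table_commit.commit(table_write.prepare_commit())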

paimon_python_java/java_gateway.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def import_paimon_view(gateway):
     java_import(gateway.jvm, "org.apache.paimon.catalog.*")
     java_import(gateway.jvm, "org.apache.paimon.schema.Schema*")
     java_import(gateway.jvm, 'org.apache.paimon.types.*')
-    java_import(gateway.jvm, 'org.apache.paimon.python.InvocationUtil')
+    java_import(gateway.jvm, 'org.apache.paimon.python.*')
     java_import(gateway.jvm, "org.apache.paimon.data.*")

paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java

Lines changed: 0 additions & 40 deletions
@@ -18,7 +18,6 @@

 package org.apache.paimon.python;

-import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.TableWrite;
@@ -28,43 +27,26 @@
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
-import org.apache.arrow.vector.types.pojo.Field;

 import java.io.ByteArrayInputStream;
-import java.util.List;
-import java.util.Objects;
-import java.util.stream.Collectors;

 /** Write Arrow bytes to Paimon. */
 public class BytesWriter {

     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
-    private final List<Field> arrowFields;

     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
-        arrowFields =
-                rowType.getFields().stream()
-                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
-                        .collect(Collectors.toList());
     }

     public void write(byte[] bytes) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
-        if (!checkSchema(arrowFields, vsr.getSchema().getFields())) {
-            throw new RuntimeException(
-                    String.format(
-                            "Input schema isn't consistent with table schema.\n"
-                                    + "\tTable schema is: %s\n"
-                                    + "\tInput schema is: %s",
-                            arrowFields, vsr.getSchema().getFields()));
-        }

         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
@@ -78,26 +60,4 @@ public void write(byte[] bytes) throws Exception {
     public void close() {
         allocator.close();
     }
-
-    private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
-        if (expectedFields.size() != actualFields.size()) {
-            return false;
-        }
-
-        for (int i = 0; i < expectedFields.size(); i++) {
-            Field expectedField = expectedFields.get(i);
-            Field actualField = actualFields.get(i);
-            if (!checkField(expectedField, actualField)
-                    || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
-                return false;
-            }
-        }
-
-        return true;
-    }
-
-    private boolean checkField(Field expected, Field actual) {
-        return Objects.equals(expected.getName(), actual.getName())
-                && Objects.equals(expected.getType(), actual.getType());
-    }
 }
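With the input schema now fixed on the Python side (every payload is serialized against the table's own Arrow schema; see pypaimon.py below), the Java-side consistency check appears redundant: pyarrow refuses to serialize a mismatched batch before any bytes reach BytesWriter. A standalone sketch of that behavior, with made-up column names:

    import pyarrow as pa

    table_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
    batch = pa.record_batch(
        [pa.array([1], pa.int32()), pa.array(['x'], pa.string())],
        schema=table_schema)

    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, table_schema) as writer:
        writer.write(batch)  # a batch with a different schema would raise here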

paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/ParallelBytesReader.java

Lines changed: 2 additions & 18 deletions
@@ -18,7 +18,6 @@

 package org.apache.paimon.python;

-import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.vector.ArrowFormatWriter;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.reader.RecordReader;
@@ -30,11 +29,8 @@

 import org.apache.paimon.shade.guava30.com.google.common.collect.Iterators;

-import org.apache.arrow.vector.VectorSchemaRoot;
-
 import javax.annotation.Nullable;

-import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
@@ -77,15 +73,6 @@ public void setSplits(List<Split> splits) {
         bytesIterator = randomlyExecute(getExecutor(), makeProcessor(), splits);
     }

-    public byte[] serializeSchema() {
-        ArrowFormatWriter arrowFormatWriter = newWriter();
-        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
-        ByteArrayOutputStream out = new ByteArrayOutputStream();
-        ArrowUtils.serializeToIpc(vsr, out);
-        arrowFormatWriter.close();
-        return out.toByteArray();
-    }
-
     @Nullable
     public byte[] next() {
         if (bytesIterator.hasNext()) {
@@ -110,7 +97,8 @@ private Function<Split, Iterator<byte[]>> makeProcessor() {
                 RecordReaderIterator<InternalRow> iterator =
                         new RecordReaderIterator<>(recordReader);
                 iterators.add(iterator);
-                ArrowFormatWriter arrowFormatWriter = newWriter();
+                ArrowFormatWriter arrowFormatWriter =
+                        new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
                 arrowFormatWriters.add(arrowFormatWriter);
                 return new RecordBytesIterator(iterator, arrowFormatWriter);
             } catch (IOException e) {
@@ -164,8 +152,4 @@ private void closeResources() {
         }
         arrowFormatWriters.clear();
     }
-
-    private ArrowFormatWriter newWriter() {
-        return new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
-    }
 }
paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/SchemaUtil.java

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.python;
+
+import org.apache.paimon.arrow.ArrowUtils;
+import org.apache.paimon.types.RowType;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+
+import java.io.ByteArrayOutputStream;
+
+/** Util to get arrow schema from row type. */
+public class SchemaUtil {
+    public static byte[] getArrowSchema(RowType rowType) {
+        BufferAllocator allocator = new RootAllocator();
+        VectorSchemaRoot emptyRoot = ArrowUtils.createVectorSchemaRoot(rowType, allocator, true);
+        ByteArrayOutputStream out = new ByteArrayOutputStream();
+        ArrowUtils.serializeToIpc(emptyRoot, out);
+        emptyRoot.close();
+        allocator.close();
+        return out.toByteArray();
+    }
+}
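On the Python side these IPC bytes are read straight back into a pa.Schema (see the Table.__init__ change below). A self-contained sketch of the round trip, with a locally serialized schema-only stream standing in for the JVM call:

    import pyarrow as pa

    # Stand-in for SchemaUtil.getArrowSchema: an IPC stream with no batches.
    schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, schema) as writer:
        pass  # the stream carries only the schema message
    schema_bytes = sink.getvalue().to_pybytes()

    # What Table.__init__ does with the bytes returned by the gateway:
    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
    arrow_schema = schema_reader.schema
    schema_reader.close()
    assert arrow_schema.equals(schema)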

paimon_python_java/pypaimon.py

Lines changed: 43 additions & 20 deletions
@@ -16,6 +16,7 @@
 # limitations under the License.
 ################################################################################

+import pandas as pd
 import pyarrow as pa

 from paimon_python_java.java_gateway import get_gateway
@@ -59,23 +60,30 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
+        # init arrow schema
+        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
+        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+        self._arrow_schema = schema_reader.schema
+        schema_reader.close()

     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
+        return ReadBuilder(
+            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)

     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())
+        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)


 class ReadBuilder(read_builder.ReadBuilder):

-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
+        self._arrow_schema = arrow_schema

     def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
         self._j_read_builder.withProjection(projection)
@@ -91,7 +99,7 @@ def new_scan(self) -> 'TableScan':

     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options)
+        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)


 class TableScan(table_scan.TableScan):
@@ -125,20 +133,27 @@ def to_j_split(self):

 class TableRead(table_read.TableRead):

-    def __init__(self, j_table_read, j_row_type, catalog_options):
+    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
         self._j_table_read = j_table_read
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = None
+        self._arrow_schema = arrow_schema

-    def create_reader(self, splits):
+    def to_arrow(self, splits):
+        record_batch_reader = self.to_arrow_batch_reader(splits)
+        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+
+    def to_arrow_batch_reader(self, splits):
         self._init()
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
         return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)

+    def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
+        return self.to_arrow(splits).to_pandas()
+
     def _init(self):
         if self._j_bytes_reader is None:
             # get thread num
@@ -153,12 +168,6 @@ def _init(self):
             self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
                 self._j_table_read, self._j_row_type, max_workers)

-            if self._arrow_schema is None:
-                schema_bytes = self._j_bytes_reader.serializeSchema()
-                schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-                self._arrow_schema = schema_reader.schema
-                schema_reader.close()
-
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
             next_bytes = self._j_bytes_reader.next()
@@ -171,17 +180,18 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:

 class BatchWriteBuilder(write_builder.BatchWriteBuilder):

-    def __init__(self, j_batch_write_builder, j_row_type):
+    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
         self._j_batch_write_builder = j_batch_write_builder
         self._j_row_type = j_row_type
+        self._arrow_schema = arrow_schema

     def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
         self._j_batch_write_builder.withOverwrite(static_partition)
         return self

     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type)
+        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)

     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -190,19 +200,32 @@ def new_commit(self) -> 'BatchTableCommit':

 class BatchTableWrite(table_write.BatchTableWrite):

-    def __init__(self, j_batch_table_write, j_row_type):
+    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-
-    def write(self, record_batch: pa.RecordBatch):
+        self._arrow_schema = arrow_schema
+
+    def write_arrow(self, table):
+        for record_batch in table.to_reader():
+            # TODO: can we use a reusable stream?
+            stream = pa.BufferOutputStream()
+            with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
+                writer.write(record_batch)
+            arrow_bytes = stream.getvalue().to_pybytes()
+            self._j_bytes_writer.write(arrow_bytes)
+
+    def write_arrow_batch(self, record_batch):
         stream = pa.BufferOutputStream()
-        with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
+        with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
             writer.write(record_batch)
-            writer.close()
         arrow_bytes = stream.getvalue().to_pybytes()
         self._j_bytes_writer.write(arrow_bytes)

+    def write_pandas(self, dataframe: pd.DataFrame):
+        record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
+        self.write_arrow_batch(record_batch)
+
     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()
         return list(map(lambda cm: CommitMessage(cm), j_commit_messages))
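The read path is the mirror image of the write path: every payload handed over by ParallelBytesReader.next() is a complete Arrow IPC stream carrying one batch, and _batch_generator stitches them into a single RecordBatchReader. A standalone sketch of that mechanism, with locally built payloads standing in for the JVM calls:

    import pyarrow as pa

    schema = pa.schema([('f0', pa.int32())])

    def make_payload(values):
        # One self-contained IPC stream per batch, as the Java bridge produces.
        batch = pa.record_batch([pa.array(values, pa.int32())], schema=schema)
        sink = pa.BufferOutputStream()
        with pa.RecordBatchStreamWriter(sink, schema) as writer:
            writer.write(batch)
        return sink.getvalue().to_pybytes()

    payloads = [make_payload([1, 2]), make_payload([3])]

    def batch_generator():
        # Mirrors TableRead._batch_generator, minus the None sentinel check.
        for raw in payloads:
            reader = pa.RecordBatchStreamReader(pa.BufferReader(raw))
            yield from reader

    reader = pa.RecordBatchReader.from_batches(schema, batch_generator())
    assert reader.read_all().num_rows == 3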
