
Commit caa72a6 (parent: 5e7d468)
Author: yuzelin

Fix that field nullability affects write
File tree: 3 files changed, +126 −13 lines
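In short: the Python writer used to serialize every record batch with the table's Arrow schema, so a batch whose fields differed from the table only in nullability was rejected inside Arrow IPC before ever reaching Paimon. After this commit each batch is serialized with its own schema, and the Java bridge instead validates field names and types while deliberately ignoring nullability. A minimal pyarrow sketch of the old failure mode (the pre-fix behavior is inferred from the removed lines below):

import pyarrow as pa

# The table declares f0 as non-nullable; the incoming batch carries the
# default (nullable) flavor of the same field.
table_schema = pa.schema([pa.field('f0', pa.int32(), nullable=False)])
batch = pa.RecordBatch.from_pydict(
    {'f0': [1, 2, 3]}, schema=pa.schema([('f0', pa.int32())]))

# Opening the IPC stream writer with the *table* schema, as the old
# _write_arrow_batch did, rejects the batch: Arrow IPC requires an exact
# schema match, nullability included.
stream = pa.BufferOutputStream()
try:
    with pa.RecordBatchStreamWriter(stream, table_schema) as writer:
        writer.write(batch)
except pa.ArrowInvalid as e:
    print(e)  # schema mismatch despite identical names and types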

paimon_python_java/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesWriter.java

Lines changed: 41 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.python;
 
+import org.apache.paimon.arrow.ArrowUtils;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.TableWrite;
@@ -27,26 +28,43 @@
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.types.pojo.Field;
 
 import java.io.ByteArrayInputStream;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
 
 /** Write Arrow bytes to Paimon. */
 public class BytesWriter {
 
     private final TableWrite tableWrite;
     private final ArrowBatchReader arrowBatchReader;
     private final BufferAllocator allocator;
+    private final List<Field> arrowFields;
 
     public BytesWriter(TableWrite tableWrite, RowType rowType) {
         this.tableWrite = tableWrite;
         this.arrowBatchReader = new ArrowBatchReader(rowType);
         this.allocator = new RootAllocator();
+        arrowFields =
+                rowType.getFields().stream()
+                        .map(f -> ArrowUtils.toArrowField(f.name(), f.type()))
+                        .collect(Collectors.toList());
     }
 
-    public void write(byte[] bytes) throws Exception {
+    public void write(byte[] bytes, boolean needCheckSchema) throws Exception {
         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
         ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bais, allocator);
         VectorSchemaRoot vsr = arrowStreamReader.getVectorSchemaRoot();
+        if (needCheckSchema && !checkSchema(arrowFields, vsr.getSchema().getFields())) {
+            throw new RuntimeException(
+                    String.format(
+                            "Input schema isn't consistent with table schema.\n"
+                                    + "\tTable schema is: %s\n"
+                                    + "\tInput schema is: %s",
+                            arrowFields, vsr.getSchema().getFields()));
+        }
 
         while (arrowStreamReader.loadNextBatch()) {
             Iterable<InternalRow> rows = arrowBatchReader.readBatch(vsr);
@@ -60,4 +78,26 @@ public void write(byte[] bytes) throws Exception {
     public void close() {
         allocator.close();
     }
+
+    private boolean checkSchema(List<Field> expectedFields, List<Field> actualFields) {
+        if (expectedFields.size() != actualFields.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < expectedFields.size(); i++) {
+            Field expectedField = expectedFields.get(i);
+            Field actualField = actualFields.get(i);
+            if (!checkField(expectedField, actualField)
+                    || !checkSchema(expectedField.getChildren(), actualField.getChildren())) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    private boolean checkField(Field expected, Field actual) {
+        return Objects.equals(expected.getName(), actual.getName())
+                && Objects.equals(expected.getType(), actual.getType());
+    }
 }
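The important detail in checkField above is what it does not compare: Arrow Java's Field.getType() returns the ArrowType without nullability, so two fields that differ only in nullable vs. non-null pass the check, and checkSchema recurses into children for nested types. A rough Python equivalent using a hypothetical fields_match helper (pyarrow's own Field.equals would compare nullability too; only struct children are recursed here, for brevity):

import pyarrow as pa

def fields_match(expected, actual):
    # Hypothetical mirror of checkSchema/checkField: names and types must
    # agree; nullability is deliberately ignored.
    if len(expected) != len(actual):
        return False
    for e, a in zip(expected, actual):
        if e.name != a.name:
            return False
        if pa.types.is_struct(e.type) and pa.types.is_struct(a.type):
            if not fields_match(list(e.type), list(a.type)):
                return False
        elif e.type != a.type:
            return False
    return True

table_fields = [pa.field('f0', pa.int32(), nullable=False), pa.field('f1', pa.string())]
input_fields = [pa.field('f0', pa.int32()), pa.field('f1', pa.string())]
print(fields_match(table_fields, input_fields))  # True: only nullability differs
print(table_fields[0].equals(input_fields[0]))   # False: pyarrow also compares nullability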

paimon_python_java/pypaimon.py

Lines changed: 11 additions & 12 deletions
@@ -218,23 +218,22 @@ def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
 
     def write_arrow(self, table):
         for record_batch in table.to_reader():
-            # TODO: can we use a reusable stream?
-            stream = pa.BufferOutputStream()
-            with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
-                writer.write(record_batch)
-            arrow_bytes = stream.getvalue().to_pybytes()
-            self._j_bytes_writer.write(arrow_bytes)
+            # TODO: can we use a reusable stream in #_write_arrow_batch ?
+            self._write_arrow_batch(record_batch, True)
 
     def write_arrow_batch(self, record_batch):
-        stream = pa.BufferOutputStream()
-        with pa.RecordBatchStreamWriter(stream, self._arrow_schema) as writer:
-            writer.write(record_batch)
-        arrow_bytes = stream.getvalue().to_pybytes()
-        self._j_bytes_writer.write(arrow_bytes)
+        self._write_arrow_batch(record_batch, True)
 
     def write_pandas(self, dataframe: pd.DataFrame):
         record_batch = pa.RecordBatch.from_pandas(dataframe, schema=self._arrow_schema)
-        self.write_arrow_batch(record_batch)
+        self._write_arrow_batch(record_batch, False)
+
+    def _write_arrow_batch(self, record_batch, check_schema):
+        stream = pa.BufferOutputStream()
+        with pa.RecordBatchStreamWriter(stream, record_batch.schema) as writer:
+            writer.write(record_batch)
+        arrow_bytes = stream.getvalue().to_pybytes()
+        self._j_bytes_writer.write(arrow_bytes, check_schema)
 
     def prepare_commit(self) -> List['CommitMessage']:
         j_commit_messages = self._j_batch_table_write.prepareCommit()
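On the Python side, the refactored _write_arrow_batch opens the IPC stream with record_batch.schema rather than self._arrow_schema, so serialization can no longer fail on a nullability mismatch; the batch's real schema travels with the bytes and is validated (or not, for write_pandas, whose batch is built from the table schema anyway) on the Java side. A small self-contained sketch of that round-trip, independent of pypaimon:

import pyarrow as pa

batch = pa.RecordBatch.from_pydict(
    {'f0': [1, 2, 3], 'f1': ['a', 'b', 'c']},
    schema=pa.schema([('f0', pa.int32()), ('f1', pa.string())]))

# Serialize with the batch's own schema, as _write_arrow_batch now does.
stream = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(stream, batch.schema) as writer:
    writer.write(batch)
arrow_bytes = stream.getvalue().to_pybytes()

# The receiver (ArrowStreamReader on the Java side) sees the schema that
# was actually written: names, types, and nullability.
reader = pa.ipc.open_stream(arrow_bytes)
print(reader.schema)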

paimon_python_java/tests/test_write_and_read.py

Lines changed: 74 additions & 0 deletions
@@ -22,6 +22,7 @@
 import unittest
 import pandas as pd
 import pyarrow as pa
+from py4j.protocol import Py4JJavaError
 
 from paimon_python_api import Schema
 from paimon_python_java import Catalog
@@ -371,3 +372,76 @@ def test_overwrite(self):
         df2['f0'] = df2['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df2.reset_index(drop=True), df2.reset_index(drop=True))
+
+    def testWriteWrongSchema(self):
+        schema = Schema(self.simple_pa_schema)
+        self.catalog.create_table('default.test_wrong_schema', schema, False)
+        table = self.catalog.get_table('default.test_wrong_schema')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string())
+        ])
+        record_batch = pa.RecordBatch.from_pandas(df, schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+
+        with self.assertRaises(Py4JJavaError) as e:
+            table_write.write_arrow_batch(record_batch)
+        self.assertEqual(
+            str(e.exception.java_exception),
+            '''java.lang.RuntimeException: Input schema isn't consistent with table schema.
+\tTable schema is: [f0: Int(32, true), f1: Utf8]
+\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
+
+    def testIgnoreNullable(self):
+        pa_schema1 = pa.schema([
+            ('f0', pa.int32(), False),
+            ('f1', pa.string())
+        ])
+
+        pa_schema2 = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string())
+        ])
+
+        # write nullable to non-null
+        self._testIgnoreNullableImpl('test_ignore_nullable1', pa_schema1, pa_schema2)
+
+        # write non-null to nullable
+        self._testIgnoreNullableImpl('test_ignore_nullable2', pa_schema2, pa_schema1)
+
+    def _testIgnoreNullableImpl(self, table_name, table_schema, data_schema):
+        schema = Schema(table_schema)
+        self.catalog.create_table(f'default.{table_name}', schema, False)
+        table = self.catalog.get_table(f'default.{table_name}')
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+        }
+        df = pd.DataFrame(data)
+        record_batch = pa.RecordBatch.from_pandas(pd.DataFrame(data), data_schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        table_write.write_arrow_batch(record_batch)
+        table_commit.commit(table_write.prepare_commit())
+
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        actual_df = table_read.to_pandas(table_scan.plan().splits())
+        df['f0'] = df['f0'].astype('int32')
+        pd.testing.assert_frame_equal(
+            actual_df.reset_index(drop=True), df.reset_index(drop=True))
