Skip to content

Commit 83c8c3e

Browse files
HonahX and kevinjqliu authored
[0.6.x] Backport PR #523 to cast data to iceberg table's pyarrow schema (#559)
* Cast data to Iceberg Table's pyarrow schema (#523) Backport to 0.6.1 * use schema_to_pyarrow directly for backporting * remove print in test --------- Co-authored-by: Kevin Liu <kevinjqliu@users.noreply.github.com>
1 parent b9362ee commit 83c8c3e

File tree

3 files changed

+70
-8
lines changed

3 files changed

+70
-8
lines changed

pyiceberg/table/__init__.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,15 @@
132132
_JAVA_LONG_MAX = 9223372036854775807
133133

134134

135-
def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
135+
def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
136+
"""
137+
Check if the `table_schema` is compatible with `other_schema`.
138+
139+
Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
140+
141+
Raises:
142+
ValueError: If the schemas are not compatible.
143+
"""
136144
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
137145

138146
name_mapping = table_schema.name_mapping
@@ -1044,7 +1052,12 @@ def append(self, df: pa.Table) -> None:
10441052
if len(self.spec().fields) > 0:
10451053
raise ValueError("Cannot write to partitioned tables")
10461054

1047-
_check_schema(self.schema(), other_schema=df.schema)
1055+
from pyiceberg.io.pyarrow import schema_to_pyarrow
1056+
1057+
_check_schema_compatible(self.schema(), other_schema=df.schema)
1058+
# cast if the two schemas are compatible but not equal
1059+
if schema_to_pyarrow(self.schema()) != df.schema:
1060+
df = df.cast(schema_to_pyarrow(self.schema()))
10481061

10491062
merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
10501063

@@ -1079,7 +1092,12 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
10791092
if len(self.spec().fields) > 0:
10801093
raise ValueError("Cannot write to partitioned tables")
10811094

1082-
_check_schema(self.schema(), other_schema=df.schema)
1095+
from pyiceberg.io.pyarrow import schema_to_pyarrow
1096+
1097+
_check_schema_compatible(self.schema(), other_schema=df.schema)
1098+
# cast if the two schemas are compatible but not equal
1099+
if schema_to_pyarrow(self.schema()) != df.schema:
1100+
df = df.cast(schema_to_pyarrow(self.schema()))
10831101

10841102
merge = _MergingSnapshotProducer(
10851103
operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,

tests/catalog/test_sql.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,36 @@ def test_create_table_with_pyarrow_schema(
191191
catalog.drop_table(random_identifier)
192192

193193

194+
@pytest.mark.parametrize(
195+
'catalog',
196+
[
197+
lazy_fixture('catalog_memory'),
198+
lazy_fixture('catalog_sqlite'),
199+
],
200+
)
201+
def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
202+
import pyarrow as pa
203+
204+
pyarrow_table = pa.Table.from_arrays(
205+
[
206+
pa.array([None, "A", "B", "C"]), # 'foo' column
207+
pa.array([1, 2, 3, 4]), # 'bar' column
208+
pa.array([True, None, False, True]), # 'baz' column
209+
pa.array([None, "A", "B", "C"]), # 'large' column
210+
],
211+
schema=pa.schema([
212+
pa.field('foo', pa.string(), nullable=True),
213+
pa.field('bar', pa.int32(), nullable=False),
214+
pa.field('baz', pa.bool_(), nullable=True),
215+
pa.field('large', pa.large_string(), nullable=True),
216+
]),
217+
)
218+
database_name, _table_name = random_identifier
219+
catalog.create_namespace(database_name)
220+
table = catalog.create_table(random_identifier, pyarrow_table.schema)
221+
table.overwrite(pyarrow_table)
222+
223+
194224
@pytest.mark.parametrize(
195225
'catalog',
196226
[

tests/table/test_init.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
Table,
5959
UpdateSchema,
6060
_apply_table_update,
61-
_check_schema,
61+
_check_schema_compatible,
6262
_generate_snapshot_id,
6363
_match_deletes_to_data_file,
6464
_TableMetadataUpdateContext,
@@ -1004,7 +1004,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
10041004
"""
10051005

10061006
with pytest.raises(ValueError, match=expected):
1007-
_check_schema(table_schema_simple, other_schema)
1007+
_check_schema_compatible(table_schema_simple, other_schema)
10081008

10091009

10101010
def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1025,7 +1025,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
10251025
"""
10261026

10271027
with pytest.raises(ValueError, match=expected):
1028-
_check_schema(table_schema_simple, other_schema)
1028+
_check_schema_compatible(table_schema_simple, other_schema)
10291029

10301030

10311031
def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1045,7 +1045,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
10451045
"""
10461046

10471047
with pytest.raises(ValueError, match=expected):
1048-
_check_schema(table_schema_simple, other_schema)
1048+
_check_schema_compatible(table_schema_simple, other_schema)
10491049

10501050

10511051
def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1059,4 +1059,18 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
10591059
expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
10601060

10611061
with pytest.raises(ValueError, match=expected):
1062-
_check_schema(table_schema_simple, other_schema)
1062+
_check_schema_compatible(table_schema_simple, other_schema)
1063+
1064+
1065+
def test_schema_downcast(table_schema_simple: Schema) -> None:
1066+
# large_string type is compatible with string type
1067+
other_schema = pa.schema((
1068+
pa.field("foo", pa.large_string(), nullable=True),
1069+
pa.field("bar", pa.int32(), nullable=False),
1070+
pa.field("baz", pa.bool_(), nullable=True),
1071+
))
1072+
1073+
try:
1074+
_check_schema_compatible(table_schema_simple, other_schema)
1075+
except Exception:
1076+
pytest.fail("Unexpected Exception raised when calling `_check_schema`")

0 commit comments

Comments (0)