From 245acdab5c6f450996d54ce3c44c264687a83841 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Fri, 12 Jul 2024 21:02:50 +0000 Subject: [PATCH 01/13] merge --- pyiceberg/io/pyarrow.py | 99 ++++++++++------ pyiceberg/schema.py | 5 + pyiceberg/table/__init__.py | 7 +- tests/integration/test_add_files.py | 13 +-- tests/integration/test_writes/test_writes.py | 24 +++- tests/io/test_pyarrow.py | 116 +++++++++++++++---- 6 files changed, 195 insertions(+), 69 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 199133f794..dcc2f48544 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -120,6 +120,7 @@ Schema, SchemaVisitorPerPrimitiveType, SchemaWithPartnerVisitor, + assign_fresh_schema_ids, pre_order_visit, promote, prune_columns, @@ -1450,14 +1451,17 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st except ValueError: return None - if isinstance(partner_struct, pa.StructArray): - return partner_struct.field(name) - elif isinstance(partner_struct, pa.Table): - return partner_struct.column(name).combine_chunks() - elif isinstance(partner_struct, pa.RecordBatch): - return partner_struct.column(name) - else: - raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") + try: + if isinstance(partner_struct, pa.StructArray): + return partner_struct.field(name) + elif isinstance(partner_struct, pa.Table): + return partner_struct.column(name).combine_chunks() + elif isinstance(partner_struct, pa.RecordBatch): + return partner_struct.column(name) + else: + raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") + except KeyError: + return None return None @@ -2079,36 +2083,63 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down Raises: ValueError: If the schemas are not compatible. """ - name_mapping = table_schema.name_mapping - try: - task_schema = pyarrow_to_schema( - other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us - ) - except ValueError as e: - other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) - additional_names = set(other_schema.column_names) - set(table_schema.column_names) - raise ValueError( - f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." 
- ) from e - - if table_schema.as_struct() != task_schema.as_struct(): - from rich.console import Console - from rich.table import Table as RichTable + task_schema = assign_fresh_schema_ids( + _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) + ) - console = Console(record=True) + extra_fields = task_schema.field_names - table_schema.field_names + missing_fields = table_schema.field_names - task_schema.field_names + fields_in_both = task_schema.field_names.intersection(table_schema.field_names) + + from rich.console import Console + from rich.table import Table as RichTable + + console = Console(record=True) + + rich_table = RichTable(show_header=True, header_style="bold") + rich_table.add_column("Field Name") + rich_table.add_column("Category") + rich_table.add_column("Table field") + rich_table.add_column("Dataframe field") + + def print_nullability(required: bool) -> str: + return "required" if required else "optional" + + for field_name in fields_in_both: + lhs = table_schema.find_field(field_name) + rhs = task_schema.find_field(field_name) + # Check nullability + if lhs.required != rhs.required: + rich_table.add_row( + field_name, + "Nullability", + f"{print_nullability(lhs.required)} {str(lhs.field_type)}", + f"{print_nullability(rhs.required)} {str(rhs.field_type)}", + ) + # Check if type is consistent + if any( + (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)) + for container_type in {StructType, MapType, ListType} + ): + continue + elif lhs.field_type != rhs.field_type: + rich_table.add_row( + field_name, + "Type", + f"{print_nullability(lhs.required)} {str(lhs.field_type)}", + f"{print_nullability(rhs.required)} {str(rhs.field_type)}", + ) - rich_table = RichTable(show_header=True, header_style="bold") - rich_table.add_column("") - rich_table.add_column("Table field") - rich_table.add_column("Dataframe field") + for field_name in extra_fields: + rhs = task_schema.find_field(field_name) + rich_table.add_row(field_name, "Extra Fields", "", f"{print_nullability(rhs.required)} {str(rhs.field_type)}") - for lhs in table_schema.fields: - try: - rhs = task_schema.find_field(lhs.field_id) - rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) - except ValueError: - rich_table.add_row("❌", str(lhs), "Missing") + for field_name in missing_fields: + lhs = table_schema.find_field(field_name) + if lhs.required: + rich_table.add_row(field_name, "Missing Fields", f"{print_nullability(lhs.required)} {str(lhs.field_type)}", "") + if rich_table.row_count: console.print(rich_table) raise ValueError(f"Mismatch in fields:\n{console.export_text()}") diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index 77f1addbf5..dc5725295d 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -324,6 +324,11 @@ def field_ids(self) -> Set[int]: """Return the IDs of the current schema.""" return set(self._name_to_id.values()) + @property + def field_names(self) -> Set[str]: + """Return the Names of the current schema.""" + return set(self._name_to_id.keys()) + def _validate_identifier_field(self, field_id: int) -> None: """Validate that the field with the given ID is a valid identifier field. 
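An aside on the rewritten check above: it is driven by plain Python set arithmetic over the two schemas' field names, using the `field_names` property added to `Schema` in the pyiceberg/schema.py hunk just above. A minimal, self-contained sketch of that diffing step — the literal field names below are hypothetical, chosen only for illustration:

    # Hypothetical stand-ins for the two Schema.field_names sets.
    table_fields = {"foo", "bar", "baz"}
    df_fields = {"foo", "bar", "new_field"}

    extra_fields = df_fields - table_fields                # {"new_field"}: only in the dataframe
    missing_fields = table_fields - df_fields              # {"baz"}: only in the table
    fields_in_both = df_fields.intersection(table_fields)  # compared field by field

    assert extra_fields == {"new_field"}
    assert missing_fields == {"baz"}
    assert fields_in_both == {"foo", "bar"}

Note that, per the hunk above, a missing field is only reported as an error when the table field is required; an optional field may simply be absent from the dataframe.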
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b43dc3206b..ef88224a18 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -73,7 +73,6 @@ manifest_evaluator, ) from pyiceberg.io import FileIO, OutputFile, load_file_io -from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table from pyiceberg.manifest import ( POSITIONAL_DELETE_SCHEMA, DataFile, @@ -471,6 +470,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files + if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") @@ -528,6 +529,8 @@ def overwrite( except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files + if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") @@ -566,6 +569,8 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti delete_filter: A boolean expression to delete rows from a table snapshot_properties: Custom properties to be added to the snapshot summary """ + from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table + if ( self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT) == TableProperties.DELETE_MODE_MERGE_ON_READ diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index b8fd6d0926..f485ac3ebd 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -501,14 +501,11 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog ) expected = """Mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │ -| ✅ │ 2: bar: optional string │ 2: bar: optional string │ -│ ❌ │ 3: baz: optional int │ 3: baz: optional string │ -│ ✅ │ 4: qux: optional date │ 4: qux: optional date │ -└────┴──────────────────────────┴──────────────────────────┘ +┏━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ +┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ +┡━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ +│ baz │ Type │ optional int │ optional string │ +└────────────┴──────────┴──────────────┴─────────────────┘ """ with pytest.raises(ValueError, match=expected): diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 41bc6fb5bf..a6f38d23f3 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -964,11 +964,14 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None: assert len(tbl.scan().to_arrow()) == 22 +@pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) -def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: - identifier = "default.table_append_subset_of_schema" +def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) 
-> None: + identifier = "default.test_table_write_subset_of_schema" tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0]) + print(arrow_table_without_some_columns.schema) + print(arrow_table_with_null.schema) assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns) tbl.overwrite(arrow_table_without_some_columns) tbl.append(arrow_table_without_some_columns) @@ -976,6 +979,23 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2 +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: + identifier = "default.test_table_write_out_of_order_schema" + # rotate the schema fields by 1 + fields = list(arrow_table_with_null.schema) + rotated_fields = fields[1:] + fields[:1] + rotated_schema = pa.schema(rotated_fields) + assert arrow_table_with_null.schema != rotated_schema + tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema) + + tbl.overwrite(arrow_table_with_null) + tbl.append(arrow_table_with_null) + # overwrite and then append should produce twice the data + assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2 + + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_write_all_timestamp_precision( diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 37198b7edb..1d58ebe316 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -1732,13 +1732,11 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None: )) expected = r"""Mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │ -│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -└────┴──────────────────────────┴─────────────────────────────────┘ +┏━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ +┡━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ bar │ Type │ required int │ required decimal\(18, 6\) │ +└────────────┴──────────┴──────────────┴─────────────────────────┘ """ with pytest.raises(ValueError, match=expected): @@ -1753,13 +1751,11 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: )) expected = """Mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ 2: bar: optional int │ -│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -└────┴──────────────────────────┴──────────────────────────┘ +┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ +┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ +┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ +│ bar │ Nullability │ required int │ optional int │ 
+└────────────┴─────────────┴──────────────┴─────────────────┘ """ with pytest.raises(ValueError, match=expected): @@ -1773,33 +1769,105 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: )) expected = """Mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ Missing │ -│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -└────┴──────────────────────────┴──────────────────────────┘ +┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ +┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ +┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ +│ bar │ Missing Fields │ required int │ │ +└────────────┴────────────────┴──────────────┴─────────────────┘ """ with pytest.raises(ValueError, match=expected): _check_schema_compatible(table_schema_simple, other_schema) +def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None: + schema = table_schema_nested.as_arrow() + schema = schema.remove(6).insert( + 6, + pa.field( + "person", + pa.struct([ + pa.field("age", pa.int32(), nullable=False), + ]), + nullable=True, + ), + ) + try: + _check_schema_compatible(table_schema_nested, schema) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + + +def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None: + other_schema = table_schema_nested.as_arrow() + other_schema = other_schema.remove(6).insert( + 6, + pa.field( + "person", + pa.struct([ + pa.field("name", pa.string(), nullable=True), + ]), + nullable=True, + ), + ) + expected = """Mismatch in fields: +┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ +┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ +┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ +│ person.age │ Missing Fields │ required int │ │ +└────────────┴────────────────┴──────────────┴─────────────────┘ +""" + + with pytest.raises(ValueError, match=expected): + _check_schema_compatible(table_schema_nested, other_schema) + + +def test_schema_compatible_nested(table_schema_nested: Schema) -> None: + try: + _check_schema_compatible(table_schema_nested, table_schema_nested.as_arrow()) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + + def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), pa.field("baz", pa.bool_(), nullable=True), pa.field("new_field", pa.date32(), nullable=True), )) - expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." 
+ expected = """Mismatch in fields: +┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ +┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ +┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ +│ new_field │ Extra Fields │ │ optional date │ +└────────────┴──────────────┴─────────────┴─────────────────┘ +""" with pytest.raises(ValueError, match=expected): _check_schema_compatible(table_schema_simple, other_schema) +def test_schema_compatible(table_schema_simple: Schema) -> None: + try: + _check_schema_compatible(table_schema_simple, table_schema_simple.as_arrow()) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + + +def test_schema_projection(table_schema_simple: Schema) -> None: + # remove optional `baz` field from `table_schema_simple` + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + )) + try: + _check_schema_compatible(table_schema_simple, other_schema) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + + def test_schema_downcast(table_schema_simple: Schema) -> None: # large_string type is compatible with string type other_schema = pa.schema(( @@ -1811,7 +1879,7 @@ def test_schema_downcast(table_schema_simple: Schema) -> None: try: _check_schema_compatible(table_schema_simple, other_schema) except Exception: - pytest.fail("Unexpected Exception raised when calling `_check_schema`") + pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") def test_partition_for_demo() -> None: From 0118f2a7cff987e895d20648cdf798ba59b5155a Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Fri, 12 Jul 2024 19:42:47 -0400 Subject: [PATCH 02/13] thanks @HonahX :) Co-authored-by: Honah J. 
--- tests/integration/test_writes/test_writes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index a6f38d23f3..91590cc63f 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -970,8 +970,6 @@ def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with identifier = "default.test_table_write_subset_of_schema" tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0]) - print(arrow_table_without_some_columns.schema) - print(arrow_table_with_null.schema) assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns) tbl.overwrite(arrow_table_without_some_columns) tbl.append(arrow_table_without_some_columns) From e75e0adbfae43754c9f4e31070997febe12bab11 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 13 Jul 2024 00:12:34 +0000 Subject: [PATCH 03/13] support promote --- pyiceberg/io/pyarrow.py | 15 ++++---- tests/integration/test_writes/test_writes.py | 38 +++++++++++++++++++- tests/io/test_pyarrow.py | 18 ++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index dcc2f48544..6096ec4844 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2123,12 +2123,15 @@ def print_nullability(required: bool) -> str: ): continue elif lhs.field_type != rhs.field_type: - rich_table.add_row( - field_name, - "Type", - f"{print_nullability(lhs.required)} {str(lhs.field_type)}", - f"{print_nullability(rhs.required)} {str(rhs.field_type)}", - ) + try: + promote(rhs.field_type, lhs.field_type) + except ResolveError: + rich_table.add_row( + field_name, + "Type", + f"{print_nullability(lhs.required)} {str(lhs.field_type)}", + f"{print_nullability(rhs.required)} {str(rhs.field_type)}", + ) for field_name in extra_fields: rhs = task_schema.find_field(field_name) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 91590cc63f..e55f3861f5 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -43,7 +43,7 @@ from pyiceberg.schema import Schema from pyiceberg.table import TableProperties from pyiceberg.transforms import IdentityTransform -from pyiceberg.types import IntegerType, NestedField +from pyiceberg.types import BooleanType, IntegerType, LongType, NestedField from utils import _create_table @@ -964,6 +964,42 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None: assert len(tbl.scan().to_arrow()) == 22 +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_table_write_schema_with_valid_upcast( + session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_table_write_with_valid_upcast" + table_schema = Schema( + NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=True), + NestedField(field_id=2, name="integer", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="long", field_type=LongType(), required=True), + ) + other_schema = pa.schema(( + pa.field("boolean", pa.bool_(), nullable=False), + pa.field("integer", pa.int32(), nullable=False), + 
pa.field("long", pa.int32(), nullable=False), # IntegerType can be cast to LongType + )) + arrow_table = pa.Table.from_pydict( + { + "bool": [False, True], + "integer": [1, 9], + "long": [1, 9], + }, + schema=other_schema, + ) + tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema) + tbl.append(arrow_table) + # table's long field should cast to long on read + assert tbl.scan().to_arrow() == arrow_table.cast( + pa.schema(( + pa.field("boolean", pa.bool_(), nullable=False), + pa.field("integer", pa.int32(), nullable=False), + pa.field("long", pa.int64(), nullable=False), # IntegerType can be cast to LongType + )) + ) + + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 1d58ebe316..e5c6fbcbd8 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -1856,6 +1856,24 @@ def test_schema_compatible(table_schema_simple: Schema) -> None: pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") +def test_schema_compatible_upcast() -> None: + table_schema = Schema( + NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=True), + NestedField(field_id=2, name="integer", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="long", field_type=LongType(), required=True), + ) + other_schema = pa.schema(( + pa.field("boolean", pa.bool_(), nullable=False), + pa.field("integer", pa.int32(), nullable=False), + pa.field("long", pa.int32(), nullable=False), # IntegerType can be cast to LongType + )) + + try: + _check_schema_compatible(table_schema, other_schema) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + + def test_schema_projection(table_schema_simple: Schema) -> None: # remove optional `baz` field from `table_schema_simple` other_schema = pa.schema(( From b6e34103924d6717b95f6c6d97bb1f327773308a Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 13 Jul 2024 04:20:30 +0000 Subject: [PATCH 04/13] revert promote --- pyiceberg/io/pyarrow.py | 16 ++++----- tests/integration/test_writes/test_writes.py | 38 +------------------- tests/io/test_pyarrow.py | 18 ---------- 3 files changed, 7 insertions(+), 65 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 6096ec4844..b1c99b7fbb 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2003,7 +2003,6 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT def write_parquet(task: WriteTask) -> DataFile: table_schema = task.schema - # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly # otherwise use the original schema if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema: @@ -2123,15 +2122,12 @@ def print_nullability(required: bool) -> str: ): continue elif lhs.field_type != rhs.field_type: - try: - promote(rhs.field_type, lhs.field_type) - except ResolveError: - rich_table.add_row( - field_name, - "Type", - f"{print_nullability(lhs.required)} {str(lhs.field_type)}", - f"{print_nullability(rhs.required)} {str(rhs.field_type)}", - ) + rich_table.add_row( + field_name, + "Type", + f"{print_nullability(lhs.required)} {str(lhs.field_type)}", + 
f"{print_nullability(rhs.required)} {str(rhs.field_type)}", + ) for field_name in extra_fields: rhs = task_schema.find_field(field_name) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index e55f3861f5..91590cc63f 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -43,7 +43,7 @@ from pyiceberg.schema import Schema from pyiceberg.table import TableProperties from pyiceberg.transforms import IdentityTransform -from pyiceberg.types import BooleanType, IntegerType, LongType, NestedField +from pyiceberg.types import IntegerType, NestedField from utils import _create_table @@ -964,42 +964,6 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None: assert len(tbl.scan().to_arrow()) == 22 -@pytest.mark.integration -@pytest.mark.parametrize("format_version", [1, 2]) -def test_table_write_schema_with_valid_upcast( - session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int -) -> None: - identifier = "default.test_table_write_with_valid_upcast" - table_schema = Schema( - NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=True), - NestedField(field_id=2, name="integer", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="long", field_type=LongType(), required=True), - ) - other_schema = pa.schema(( - pa.field("boolean", pa.bool_(), nullable=False), - pa.field("integer", pa.int32(), nullable=False), - pa.field("long", pa.int32(), nullable=False), # IntegerType can be cast to LongType - )) - arrow_table = pa.Table.from_pydict( - { - "bool": [False, True], - "integer": [1, 9], - "long": [1, 9], - }, - schema=other_schema, - ) - tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema) - tbl.append(arrow_table) - # table's long field should cast to long on read - assert tbl.scan().to_arrow() == arrow_table.cast( - pa.schema(( - pa.field("boolean", pa.bool_(), nullable=False), - pa.field("integer", pa.int32(), nullable=False), - pa.field("long", pa.int64(), nullable=False), # IntegerType can be cast to LongType - )) - ) - - @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index e5c6fbcbd8..1d58ebe316 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -1856,24 +1856,6 @@ def test_schema_compatible(table_schema_simple: Schema) -> None: pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") -def test_schema_compatible_upcast() -> None: - table_schema = Schema( - NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=True), - NestedField(field_id=2, name="integer", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="long", field_type=LongType(), required=True), - ) - other_schema = pa.schema(( - pa.field("boolean", pa.bool_(), nullable=False), - pa.field("integer", pa.int32(), nullable=False), - pa.field("long", pa.int32(), nullable=False), # IntegerType can be cast to LongType - )) - - try: - _check_schema_compatible(table_schema, other_schema) - except Exception: - pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") - - def test_schema_projection(table_schema_simple: Schema) -> None: # remove optional `baz` field from 
`table_schema_simple` other_schema = pa.schema(( From 6b774c6278e826a76fe8b6bccecd718db4de9baf Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sun, 14 Jul 2024 18:09:39 +0000 Subject: [PATCH 05/13] use a visitor --- pyiceberg/io/pyarrow.py | 93 ++++-------- pyiceberg/schema.py | 142 +++++++++++++++++++ pyiceberg/table/__init__.py | 12 +- tests/integration/test_writes/test_writes.py | 38 ++++- tests/io/test_pyarrow.py | 128 +++++++++++------ 5 files changed, 294 insertions(+), 119 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index b1c99b7fbb..7e60e67f78 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -120,7 +120,7 @@ Schema, SchemaVisitorPerPrimitiveType, SchemaWithPartnerVisitor, - assign_fresh_schema_ids, + _check_schema_compatible, pre_order_visit, promote, prune_columns, @@ -2002,7 +2002,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT ) def write_parquet(task: WriteTask) -> DataFile: - table_schema = task.schema + table_schema = table_metadata.schema() # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly # otherwise use the original schema if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema: @@ -2014,7 +2014,7 @@ def write_parquet(task: WriteTask) -> DataFile: batches = [ _to_requested_schema( requested_schema=file_schema, - file_schema=table_schema, + file_schema=task.schema, batch=batch, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, include_field_ids=True, @@ -2073,74 +2073,30 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[ return bin_packed_record_batches -def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None: +def _check_pyarrow_schema_compatible( + requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False +) -> None: """ - Check if the `table_schema` is compatible with `other_schema`. + Check if the `requested_schema` is compatible with `provided_schema`. Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type. Raises: ValueError: If the schemas are not compatible. 
""" - task_schema = assign_fresh_schema_ids( - _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) - ) - - extra_fields = task_schema.field_names - table_schema.field_names - missing_fields = table_schema.field_names - task_schema.field_names - fields_in_both = task_schema.field_names.intersection(table_schema.field_names) - - from rich.console import Console - from rich.table import Table as RichTable - - console = Console(record=True) - - rich_table = RichTable(show_header=True, header_style="bold") - rich_table.add_column("Field Name") - rich_table.add_column("Category") - rich_table.add_column("Table field") - rich_table.add_column("Dataframe field") - - def print_nullability(required: bool) -> str: - return "required" if required else "optional" - - for field_name in fields_in_both: - lhs = table_schema.find_field(field_name) - rhs = task_schema.find_field(field_name) - # Check nullability - if lhs.required != rhs.required: - rich_table.add_row( - field_name, - "Nullability", - f"{print_nullability(lhs.required)} {str(lhs.field_type)}", - f"{print_nullability(rhs.required)} {str(rhs.field_type)}", - ) - # Check if type is consistent - if any( - (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)) - for container_type in {StructType, MapType, ListType} - ): - continue - elif lhs.field_type != rhs.field_type: - rich_table.add_row( - field_name, - "Type", - f"{print_nullability(lhs.required)} {str(lhs.field_type)}", - f"{print_nullability(rhs.required)} {str(rhs.field_type)}", - ) - - for field_name in extra_fields: - rhs = task_schema.find_field(field_name) - rich_table.add_row(field_name, "Extra Fields", "", f"{print_nullability(rhs.required)} {str(rhs.field_type)}") - - for field_name in missing_fields: - lhs = table_schema.find_field(field_name) - if lhs.required: - rich_table.add_row(field_name, "Missing Fields", f"{print_nullability(lhs.required)} {str(lhs.field_type)}", "") + name_mapping = requested_schema.name_mapping + try: + provided_schema = pyarrow_to_schema( + provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) + except ValueError as e: + provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) + additional_names = provided_schema.field_names - requested_schema.field_names + raise ValueError( + f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." + ) from e - if rich_table.row_count: - console.print(rich_table) - raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + _check_schema_compatible(requested_schema, provided_schema) def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]: @@ -2154,7 +2110,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_ f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids" ) schema = table_metadata.schema() - _check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema()) + _check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema()) statistics = data_file_statistics_from_parquet_metadata( parquet_metadata=parquet_metadata, @@ -2235,7 +2191,7 @@ def _dataframe_to_data_files( Returns: An iterable that supplies datafiles that represent the table. 
""" - from pyiceberg.table import PropertyUtil, TableProperties, WriteTask + from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask counter = counter or itertools.count(0) write_uuid = write_uuid or uuid.uuid4() @@ -2244,13 +2200,16 @@ def _dataframe_to_data_files( property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, ) + name_mapping = table_metadata.schema().name_mapping + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) if table_metadata.spec().is_unpartitioned(): yield from write_file( io=io, table_metadata=table_metadata, tasks=iter([ - WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema()) + WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema) for batches in bin_pack_arrow_table(df, target_file_size) ]), ) @@ -2265,7 +2224,7 @@ def _dataframe_to_data_files( task_id=next(counter), record_batches=batches, partition_key=partition.partition_key, - schema=table_metadata.schema(), + schema=task_schema, ) for partition in partitions for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size) diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index dc5725295d..dd0817d682 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1621,3 +1621,145 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType: return read_type else: raise ResolveError(f"Cannot promote {file_type} to {read_type}") + + +def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None: + """ + Check if the `provided_schema` is compatible with `requested_schema`. + + Both Schemas must have valid IDs and share the same ID for the same field names. + + Two schemas are considered compatible when: + 1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema` + 2. Field Types are consistent for fields that are present in both schemas. I.e. the field type + in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema` + + Raises: + ValueError: If the schemas are not compatible. 
+ """ + visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema)) + + # from rich.console import Console + # from rich.table import Table as RichTable + + # console = Console(record=True) + + # rich_table = RichTable(show_header=True, header_style="bold") + # rich_table.add_column("") + # rich_table.add_column("Table field") + # rich_table.add_column("Dataframe field") + + # is_compatible = True + + # for field_id in requested_schema.field_ids: + # lhs = requested_schema.find_field(field_id) + # try: + # rhs = provided_schema.find_field(field_id) + # except ValueError: + # if lhs.required: + # rich_table.add_row("❌", str(lhs), "Missing") + # is_compatible = False + # else: + # rich_table.add_row("✅", str(lhs), "Missing") + # continue + + # if lhs.required and not rhs.required: + # rich_table.add_row("❌", str(lhs), "Missing") + # is_compatible = False + + # if lhs.field_type == rhs.field_type: + # rich_table.add_row("✅", str(lhs), str(rhs)) + # continue + # elif any( + # (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)) + # for container_type in {StructType, MapType, ListType} + # ): + # rich_table.add_row("✅", str(lhs), str(rhs)) + # continue + # else: + # try: + # promote(rhs.field_type, lhs.field_type) + # rich_table.add_row("✅", str(lhs), str(rhs)) + # except ResolveError: + # rich_table.add_row("❌", str(lhs), str(rhs)) + # is_compatible = False + + # if not is_compatible: + # console.print(rich_table) + # raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + + +class _SchemaCompatibilityVisitor(SchemaVisitor[bool]): + provided_schema: Schema + + def __init__(self, provided_schema: Schema): + from rich.console import Console + from rich.table import Table as RichTable + + self.provided_schema = provided_schema + self.rich_table = RichTable(show_header=True, header_style="bold") + self.rich_table.add_column("") + self.rich_table.add_column("Table field") + self.rich_table.add_column("Dataframe field") + self.console = Console(record=True) + + def _is_field_compatible(self, lhs: NestedField) -> bool: + # Check required field exists as required field first + try: + rhs = self.provided_schema.find_field(lhs.field_id) + except ValueError: + if lhs.required: + self.rich_table.add_row("❌", str(lhs), "Missing") + return False + else: + self.rich_table.add_row("✅", str(lhs), "Missing") + return True + + if lhs.required and not rhs.required: + self.rich_table.add_row("❌", str(lhs), "Missing") + return False + + # Check type compatibility + if lhs.field_type == rhs.field_type: + self.rich_table.add_row("✅", str(lhs), str(rhs)) + return True + elif any( + (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)) + for container_type in {StructType, MapType, ListType} + ): + self.rich_table.add_row("✅", str(lhs), str(rhs)) + return True + else: + try: + promote(rhs.field_type, lhs.field_type) + self.rich_table.add_row("✅", str(lhs), str(rhs)) + return True + except ResolveError: + self.rich_table.add_row("❌", str(lhs), str(rhs)) + return False + + def schema(self, schema: Schema, struct_result: bool) -> bool: + if not struct_result: + self.console.print(self.rich_table) + raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}") + return struct_result + + def struct(self, struct: StructType, field_results: List[bool]) -> bool: + return all(field_results) + + def field(self, field: NestedField, field_result: bool) -> bool: + return all([self._is_field_compatible(field), field_result]) + 
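A note on the `all([self._is_field_compatible(field), field_result])` pattern above: because the argument list is fully built before `all()` inspects it, both the field check and the child result are evaluated even when the first has already failed, so every mismatch gets the chance to add its row to the Rich table rather than only the first failure being reported. A chained `and` would short-circuit instead. A small standalone sketch with hypothetical check functions:

    def check_a() -> bool:
        print("ran a")
        return False

    def check_b() -> bool:
        print("ran b")
        return True

    # Both calls run: the list is evaluated before all() sees the values.
    assert all([check_a(), check_b()]) is False  # prints "ran a" then "ran b"

    # `and` short-circuits: check_b() never runs once check_a() returns False.
    assert (check_a() and check_b()) is False    # prints only "ran a"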
+ def list(self, list_type: ListType, element_result: bool) -> bool: + return element_result and self._is_field_compatible(list_type.element_field) + + def map(self, map_type: MapType, key_result: bool, value_result: bool) -> bool: + return all([ + self._is_field_compatible(map_type.key_field), + self._is_field_compatible(map_type.value_field), + key_result, + value_result, + ]) + + def primitive(self, primitive: PrimitiveType) -> bool: + return True diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ef88224a18..0b211e673d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -470,7 +470,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") @@ -482,8 +482,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." ) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - _check_schema_compatible( - self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + _check_pyarrow_schema_compatible( + self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) manifest_merge_enabled = PropertyUtil.property_as_bool( @@ -529,7 +529,7 @@ def overwrite( except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") @@ -541,8 +541,8 @@ def overwrite( f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." 
) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - _check_schema_compatible( - self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + _check_pyarrow_schema_compatible( + self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 91590cc63f..3cd1876219 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -43,7 +43,7 @@ from pyiceberg.schema import Schema from pyiceberg.table import TableProperties from pyiceberg.transforms import IdentityTransform -from pyiceberg.types import IntegerType, NestedField +from pyiceberg.types import BooleanType, IntegerType, LongType, NestedField from utils import _create_table @@ -994,6 +994,42 @@ def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_w assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2 +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_table_write_schema_with_valid_nullability_diff( + session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_table_write_with_valid_nullability_diff" + table_schema = Schema( + NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="integer", field_type=IntegerType(), required=False), + NestedField(field_id=3, name="long", field_type=LongType(), required=False), + ) + other_schema = pa.schema(( + pa.field("boolean", pa.bool_(), nullable=True), + pa.field("integer", pa.int32(), nullable=True), + pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field + )) + arrow_table = pa.Table.from_pydict( + { + "boolean": [False, True], + "integer": [1, 9], + "long": [1, 9], + }, + schema=other_schema, + ) + tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema) + tbl.overwrite(arrow_table) + # table's long field should cast to long on read + assert tbl.scan().to_arrow() == arrow_table.cast( + pa.schema(( + pa.field("boolean", pa.bool_(), nullable=True), + pa.field("integer", pa.int32(), nullable=True), + pa.field("long", pa.int64(), nullable=True), + )) + ) + + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_write_all_timestamp_precision( diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 1d58ebe316..645804b4cd 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -60,7 +60,7 @@ PyArrowFile, PyArrowFileIO, StatsAggregator, - _check_schema_compatible, + _check_pyarrow_schema_compatible, _ConvertToArrowSchema, _determine_partitions, _primitive_to_physical, @@ -1732,15 +1732,17 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None: )) expected = r"""Mismatch in fields: -┏━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ -┡━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ bar │ Type │ required int │ required decimal\(18, 6\) │ -└────────────┴──────────┴──────────────┴─────────────────────────┘ 
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ +└────┴──────────────────────────┴─────────────────────────────────┘ """ with pytest.raises(ValueError, match=expected): - _check_schema_compatible(table_schema_simple, other_schema) + _check_pyarrow_schema_compatible(table_schema_simple, other_schema) def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: @@ -1751,15 +1753,30 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: )) expected = """Mismatch in fields: -┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ -┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ -┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ -│ bar │ Nullability │ required int │ optional int │ -└────────────┴─────────────┴──────────────┴─────────────────┘ +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ Missing │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ +└────┴──────────────────────────┴──────────────────────────┘ """ with pytest.raises(ValueError, match=expected): - _check_schema_compatible(table_schema_simple, other_schema) + _check_pyarrow_schema_compatible(table_schema_simple, other_schema) + + +def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=False), + )) + + try: + _check_pyarrow_schema_compatible(table_schema_simple, other_schema) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: @@ -1769,15 +1786,17 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: )) expected = """Mismatch in fields: -┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ -┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ -┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ -│ bar │ Missing Fields │ required int │ │ -└────────────┴────────────────┴──────────────┴─────────────────┘ +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ Missing │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ +└────┴──────────────────────────┴──────────────────────────┘ """ with pytest.raises(ValueError, match=expected): - _check_schema_compatible(table_schema_simple, other_schema) + _check_pyarrow_schema_compatible(table_schema_simple, other_schema) def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None: @@ -1793,9 +1812,9 @@ def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Sc ), ) try: - _check_schema_compatible(table_schema_nested, schema) + _check_pyarrow_schema_compatible(table_schema_nested, 
schema)
     except Exception:
-        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")


 def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None:
@@ -1811,22 +1830,47 @@ def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Sche
         ),
     )
     expected = """Mismatch in fields:
-┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
-┃ Field Name ┃ Category       ┃ Table field  ┃ Dataframe field ┃
-┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
-│ person.age │ Missing Fields │ required int │                 │
-└────────────┴────────────────┴──────────────┴─────────────────┘
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field                        ┃ Dataframe field                    ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string            │ 1: foo: optional string            │
+│ ✅ │ 2: bar: required int               │ 2: bar: required int               │
+│ ✅ │ 3: baz: optional boolean           │ 3: baz: optional boolean           │
+│ ✅ │ 5: element: required string        │ 5: element: required string        │
+│ ✅ │ 4: qux: required list<string>      │ 4: qux: required list<string>      │
+│ ✅ │ 9: key: required string            │ 9: key: required string            │
+│ ✅ │ 10: value: required int            │ 10: value: required int            │
+│ ✅ │ 7: key: required string            │ 7: key: required string            │
+│ ✅ │ 8: value: required map<string,     │ 8: value: required map<string,     │
+│    │ int>                               │ int>                               │
+│ ✅ │ 6: quux: required map<string,      │ 6: quux: required map<string,      │
+│    │ map<string, int>>                  │ map<string, int>>                  │
+│ ✅ │ 13: latitude: optional float       │ 13: latitude: optional float       │
+│ ✅ │ 14: longitude: optional float      │ 14: longitude: optional float      │
+│ ✅ │ 12: element: required struct<13:   │ 12: element: required struct<13:   │
+│    │ latitude: optional float, 14:      │ latitude: optional float, 14:      │
+│    │ longitude: optional float>         │ longitude: optional float>         │
+│ ✅ │ 11: location: required             │ 11: location: required             │
+│    │ list<struct<13: latitude: optional │ list<struct<13: latitude: optional │
+│    │ float, 14: longitude: optional     │ float, 14: longitude: optional     │
+│    │ float>>                            │ float>>                            │
+│ ✅ │ 16: name: optional string          │ 16: name: optional string          │
+│ ❌ │ 17: age: required int              │ Missing                            │
+│ ✅ │ 15: person: optional struct<16:    │ 15: person: optional struct<16:    │
+│    │ name: optional string, 17: age:    │ name: optional string>             │
+│    │ required int>                      │                                    │
+└────┴────────────────────────────────────┴────────────────────────────────────┘
 """

     with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_nested, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_nested, other_schema)


 def test_schema_compatible_nested(table_schema_nested: Schema) -> None:
     try:
-        _check_schema_compatible(table_schema_nested, table_schema_nested.as_arrow())
+        _check_pyarrow_schema_compatible(table_schema_nested, table_schema_nested.as_arrow())
     except Exception:
-        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")


 def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
     other_schema = pa.schema((
         pa.field("foo", pa.string(), nullable=True),
         pa.field("bar", pa.int32(), nullable=False),
         pa.field("baz", pa.bool_(), nullable=True),
         pa.field("new_field", pa.date32(), nullable=True),
     ))

-    expected = """Mismatch in fields:
-┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
-┃ Field Name ┃ Category     ┃ Table field ┃ Dataframe field ┃
-┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
-│ new_field  │ Extra Fields │             │ optional date   │
-└────────────┴──────────────┴─────────────┴─────────────────┘
-"""
-
-    with pytest.raises(ValueError, match=expected):
-
_check_schema_compatible(table_schema_simple, other_schema) + with pytest.raises( + ValueError, match=r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." + ): + _check_pyarrow_schema_compatible(table_schema_simple, other_schema) def test_schema_compatible(table_schema_simple: Schema) -> None: try: - _check_schema_compatible(table_schema_simple, table_schema_simple.as_arrow()) + _check_pyarrow_schema_compatible(table_schema_simple, table_schema_simple.as_arrow()) except Exception: - pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") def test_schema_projection(table_schema_simple: Schema) -> None: @@ -1863,9 +1901,9 @@ def test_schema_projection(table_schema_simple: Schema) -> None: pa.field("bar", pa.int32(), nullable=False), )) try: - _check_schema_compatible(table_schema_simple, other_schema) + _check_pyarrow_schema_compatible(table_schema_simple, other_schema) except Exception: - pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") def test_schema_downcast(table_schema_simple: Schema) -> None: @@ -1877,9 +1915,9 @@ def test_schema_downcast(table_schema_simple: Schema) -> None: )) try: - _check_schema_compatible(table_schema_simple, other_schema) + _check_pyarrow_schema_compatible(table_schema_simple, other_schema) except Exception: - pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") + pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") def test_partition_for_demo() -> None: From e26eb23c258e83780471dd681bb509d975136af4 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 15 Jul 2024 01:33:17 +0000 Subject: [PATCH 06/13] support promotion on write --- pyiceberg/io/pyarrow.py | 6 +- pyiceberg/schema.py | 97 +++++--------------- tests/integration/test_add_files.py | 13 ++- tests/integration/test_writes/test_writes.py | 85 ++++++++++++++--- tests/io/test_pyarrow.py | 26 +++--- 5 files changed, 122 insertions(+), 105 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 7e60e67f78..7664da3fe9 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1408,7 +1408,7 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: # This can be removed once this has been fixed: # https://github.com/apache/arrow/issues/38809 list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array) - + value_array = self._cast_if_needed(list_type.element_field, value_array) arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type)) return list_array.cast(arrow_field) else: @@ -1418,6 +1418,8 @@ def map( self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array] ) -> Optional[pa.Array]: if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None: + key_result = self._cast_if_needed(map_type.key_field, key_result) + value_result = self._cast_if_needed(map_type.value_field, value_result) arrow_field = pa.map_( self._construct_field(map_type.key_field, key_result.type), self._construct_field(map_type.value_field, value_result.type), @@ -2091,7 +2093,7 @@ def _check_pyarrow_schema_compatible( ) except 
ValueError as e: provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) - additional_names = provided_schema.field_names - requested_schema.field_names + additional_names = set(provided_schema._name_to_id.keys()) - set(requested_schema._name_to_id.keys()) raise ValueError( f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." ) from e diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index dd0817d682..cfe3fe3a7b 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -324,11 +324,6 @@ def field_ids(self) -> Set[int]: """Return the IDs of the current schema.""" return set(self._name_to_id.values()) - @property - def field_names(self) -> Set[str]: - """Return the Names of the current schema.""" - return set(self._name_to_id.keys()) - def _validate_identifier_field(self, field_id: int) -> None: """Validate that the field with the given ID is a valid identifier field. @@ -1637,59 +1632,10 @@ def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) Raises: ValueError: If the schemas are not compatible. """ - visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema)) - - # from rich.console import Console - # from rich.table import Table as RichTable - - # console = Console(record=True) - - # rich_table = RichTable(show_header=True, header_style="bold") - # rich_table.add_column("") - # rich_table.add_column("Table field") - # rich_table.add_column("Dataframe field") - - # is_compatible = True - - # for field_id in requested_schema.field_ids: - # lhs = requested_schema.find_field(field_id) - # try: - # rhs = provided_schema.find_field(field_id) - # except ValueError: - # if lhs.required: - # rich_table.add_row("❌", str(lhs), "Missing") - # is_compatible = False - # else: - # rich_table.add_row("✅", str(lhs), "Missing") - # continue - - # if lhs.required and not rhs.required: - # rich_table.add_row("❌", str(lhs), "Missing") - # is_compatible = False - - # if lhs.field_type == rhs.field_type: - # rich_table.add_row("✅", str(lhs), str(rhs)) - # continue - # elif any( - # (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)) - # for container_type in {StructType, MapType, ListType} - # ): - # rich_table.add_row("✅", str(lhs), str(rhs)) - # continue - # else: - # try: - # promote(rhs.field_type, lhs.field_type) - # rich_table.add_row("✅", str(lhs), str(rhs)) - # except ResolveError: - # rich_table.add_row("❌", str(lhs), str(rhs)) - # is_compatible = False - - # if not is_compatible: - # console.print(rich_table) - # raise ValueError(f"Mismatch in fields:\n{console.export_text()}") - - -class _SchemaCompatibilityVisitor(SchemaVisitor[bool]): + pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema)) + + +class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]): provided_schema: Schema def __init__(self, provided_schema: Schema): @@ -1704,7 +1650,9 @@ def __init__(self, provided_schema: Schema): self.console = Console(record=True) def _is_field_compatible(self, lhs: NestedField) -> bool: - # Check required field exists as required field first + # Validate nullability first. 
+ # An optional field can be missing in the provided schema + # But a required field must exist as a required field try: rhs = self.provided_schema.find_field(lhs.field_id) except ValueError: @@ -1716,13 +1664,15 @@ def _is_field_compatible(self, lhs: NestedField) -> bool: return True if lhs.required and not rhs.required: - self.rich_table.add_row("❌", str(lhs), "Missing") + self.rich_table.add_row("❌", str(lhs), str(rhs)) return False # Check type compatibility if lhs.field_type == rhs.field_type: self.rich_table.add_row("✅", str(lhs), str(rhs)) return True + # We only check that the parent node is also of the same type. + # We check the type of the child nodes when we traverse them later. elif any( (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)) for container_type in {StructType, MapType, ListType} @@ -1731,6 +1681,8 @@ def _is_field_compatible(self, lhs: NestedField) -> bool: return True else: try: + # If type can be promoted to the requested schema + # it is considered compatible promote(rhs.field_type, lhs.field_type) self.rich_table.add_row("✅", str(lhs), str(rhs)) return True @@ -1738,27 +1690,28 @@ def _is_field_compatible(self, lhs: NestedField) -> bool: self.rich_table.add_row("❌", str(lhs), str(rhs)) return False - def schema(self, schema: Schema, struct_result: bool) -> bool: - if not struct_result: + def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool: + if not (result := struct_result()): self.console.print(self.rich_table) raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}") - return struct_result + return result - def struct(self, struct: StructType, field_results: List[bool]) -> bool: - return all(field_results) + def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool: + results = [result() for result in field_results] + return all(results) - def field(self, field: NestedField, field_result: bool) -> bool: - return all([self._is_field_compatible(field), field_result]) + def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool: + return self._is_field_compatible(field) and field_result() - def list(self, list_type: ListType, element_result: bool) -> bool: - return element_result and self._is_field_compatible(list_type.element_field) + def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool: + return self._is_field_compatible(list_type.element_field) and element_result() - def map(self, map_type: MapType, key_result: bool, value_result: bool) -> bool: + def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool: return all([ self._is_field_compatible(map_type.key_field), self._is_field_compatible(map_type.value_field), - key_result, - value_result, + key_result(), + value_result(), ]) def primitive(self, primitive: PrimitiveType) -> bool: diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index f485ac3ebd..421cc11fec 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -501,11 +501,14 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog ) expected = """Mismatch in fields: -┏━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ -┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ -┡━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ -│ baz │ Type │ optional int │ optional string │ -└────────────┴──────────┴──────────────┴─────────────────┘ 
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │ +│ ✅ │ 2: bar: optional string │ 2: bar: optional string │ +│ ❌ │ 3: baz: optional int │ 3: baz: optional string │ +│ ✅ │ 4: qux: optional date │ 4: qux: optional date │ +└────┴──────────────────────────┴──────────────────────────┘ """ with pytest.raises(ValueError, match=expected): diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 3cd1876219..6ab7bfc24b 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -43,7 +43,7 @@ from pyiceberg.schema import Schema from pyiceberg.table import TableProperties from pyiceberg.transforms import IdentityTransform -from pyiceberg.types import BooleanType, IntegerType, LongType, NestedField +from pyiceberg.types import IntegerType, ListType, LongType, MapType, NestedField, StringType from utils import _create_table @@ -997,37 +997,96 @@ def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_w @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_table_write_schema_with_valid_nullability_diff( - session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int ) -> None: identifier = "default.test_table_write_with_valid_nullability_diff" table_schema = Schema( - NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=False), - NestedField(field_id=2, name="integer", field_type=IntegerType(), required=False), - NestedField(field_id=3, name="long", field_type=LongType(), required=False), + NestedField(field_id=1, name="long", field_type=LongType(), required=False), ) other_schema = pa.schema(( - pa.field("boolean", pa.bool_(), nullable=True), - pa.field("integer", pa.int32(), nullable=True), pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field )) arrow_table = pa.Table.from_pydict( { - "boolean": [False, True], - "integer": [1, 9], "long": [1, 9], }, schema=other_schema, ) tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema) - tbl.overwrite(arrow_table) + # table's long field should cast to be optional on read + written_arrow_table = tbl.scan().to_arrow() + assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),))) + lhs = spark.table(f"{identifier}").toPandas() + rhs = written_arrow_table.to_pandas() + + for column in written_arrow_table.column_names: + for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): + assert left == right + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_table_write_schema_with_valid_upcast( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_table_write_with_valid_upcast" + table_schema = Schema( + NestedField(field_id=1, name="long", field_type=LongType(), required=False), + NestedField( + field_id=2, + name="list", + field_type=ListType(element_id=4, element_type=LongType(), element_required=False), + required=True, + ), + NestedField( + field_id=3, + name="map", + field_type=MapType( + 
key_id=5, + key_type=StringType(), + value_id=6, + value_type=LongType(), + value_required=False, + ), + required=True, + ), + ) + other_schema = pa.schema(( + pa.field("long", pa.int32(), nullable=True), # can support upcasting integer to long + pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long + pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long + )) + arrow_table = pa.Table.from_pydict( + { + "long": [1, 9], + "list": [[1, 1], [2, 2]], + "map": [{"a": 1}, {"b": 2}], + }, + schema=other_schema, + ) + tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema) # table's long field should cast to long on read - assert tbl.scan().to_arrow() == arrow_table.cast( + written_arrow_table = tbl.scan().to_arrow() + assert written_arrow_table == arrow_table.cast( pa.schema(( - pa.field("boolean", pa.bool_(), nullable=True), - pa.field("integer", pa.int32(), nullable=True), pa.field("long", pa.int64(), nullable=True), + pa.field("list", pa.large_list(pa.int64()), nullable=False), + pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), )) ) + lhs = spark.table(f"{identifier}").toPandas() + rhs = written_arrow_table.to_pandas() + + for column in written_arrow_table.column_names: + for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): + print(f"{left=}, {right=}") + if column == "map": + # Arrow returns a list of tuples, instead of a dict + right = dict(right) + if column == "list": + # Arrow returns an array + right = list(right) + assert left == right @pytest.mark.integration diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 645804b4cd..d61a50bb0d 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -1757,7 +1757,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: ┃ ┃ Table field ┃ Dataframe field ┃ ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ Missing │ +│ ❌ │ 2: bar: required int │ 2: bar: optional int │ │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ └────┴──────────────────────────┴──────────────────────────┘ """ @@ -1836,29 +1836,29 @@ def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Sche │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ │ ✅ │ 2: bar: required int │ 2: bar: required int │ │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -│ ✅ │ 5: element: required string │ 5: element: required string │ │ ✅ │ 4: qux: required list │ 4: qux: required list │ -│ ✅ │ 9: key: required string │ 9: key: required string │ -│ ✅ │ 10: value: required int │ 10: value: required int │ +│ ✅ │ 5: element: required string │ 5: element: required string │ +│ ✅ │ 6: quux: required map> │ map> │ │ ✅ │ 7: key: required string │ 7: key: required string │ │ ✅ │ 8: value: required map │ int> │ -│ ✅ │ 6: quux: required map> │ map> │ -│ ✅ │ 13: latitude: optional float │ 13: latitude: optional float │ -│ ✅ │ 14: longitude: optional float │ 14: longitude: optional float │ -│ ✅ │ 12: element: required struct<13: │ 12: element: required struct<13: │ -│ │ latitude: optional float, 14: │ latitude: optional float, 14: │ -│ │ longitude: optional float> │ longitude: optional float> │ +│ ✅ │ 9: key: required string │ 9: key: required string │ +│ ✅ │ 10: value: required int │ 10: value: required int │ │ ✅ │ 11: 
location: required │ 11: location: required │ │ │ list> │ float>> │ -│ ✅ │ 16: name: optional string │ 16: name: optional string │ -│ ❌ │ 17: age: required int │ Missing │ +│ ✅ │ 12: element: required struct<13: │ 12: element: required struct<13: │ +│ │ latitude: optional float, 14: │ latitude: optional float, 14: │ +│ │ longitude: optional float> │ longitude: optional float> │ +│ ✅ │ 13: latitude: optional float │ 13: latitude: optional float │ +│ ✅ │ 14: longitude: optional float │ 14: longitude: optional float │ │ ✅ │ 15: person: optional struct<16: │ 15: person: optional struct<16: │ │ │ name: optional string, 17: age: │ name: optional string> │ │ │ required int> │ │ +│ ✅ │ 16: name: optional string │ 16: name: optional string │ +│ ❌ │ 17: age: required int │ Missing │ └────┴────────────────────────────────────┴────────────────────────────────────┘ """ From f0125e9b42f45607d6bea908ecbf68199bb5de07 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 15 Jul 2024 01:56:40 +0000 Subject: [PATCH 07/13] fix --- pyiceberg/io/pyarrow.py | 19 ++++++++----------- tests/integration/test_writes/test_writes.py | 5 ++--- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 7664da3fe9..6cd1ea3b68 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1453,17 +1453,14 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st except ValueError: return None - try: - if isinstance(partner_struct, pa.StructArray): - return partner_struct.field(name) - elif isinstance(partner_struct, pa.Table): - return partner_struct.column(name).combine_chunks() - elif isinstance(partner_struct, pa.RecordBatch): - return partner_struct.column(name) - else: - raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") - except KeyError: - return None + if isinstance(partner_struct, pa.StructArray): + return partner_struct.field(name) + elif isinstance(partner_struct, pa.Table): + return partner_struct.column(name).combine_chunks() + elif isinstance(partner_struct, pa.RecordBatch): + return partner_struct.column(name) + else: + raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") return None diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 6ab7bfc24b..6565b763d7 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -1079,13 +1079,12 @@ def test_table_write_schema_with_valid_upcast( for column in written_arrow_table.column_names: for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): - print(f"{left=}, {right=}") if column == "map": # Arrow returns a list of tuples, instead of a dict right = dict(right) if column == "list": - # Arrow returns an array - right = list(right) + # Arrow returns an array, convert to list for equality check + left, right = list(left), list(right) assert left == right From 29573d967c2a51ebdf8d77cf11bbd63d04329021 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 15 Jul 2024 11:31:59 -0400 Subject: [PATCH 08/13] Thank you @Fokko ! 
Co-authored-by: Fokko Driesprong 
---
 pyiceberg/schema.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index cfe3fe3a7b..ac6eab1a74 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1673,10 +1673,7 @@ def _is_field_compatible(self, lhs: NestedField) -> bool:
             return True
         # We only check that the parent node is also of the same type.
         # We check the type of the child nodes when we traverse them later.
-        elif any(
-            (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
-            for container_type in {StructType, MapType, ListType}
-        ):
+        elif not lhs.is_primitive and not rhs.is_primitive:
             self.rich_table.add_row("✅", str(lhs), str(rhs))
             return True
         else:
             try:

From d7ec362013f4648bb86fa62a6745f32ef8061eb9 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+syun64@users.noreply.github.com>
Date: Mon, 15 Jul 2024 15:57:00 +0000
Subject: [PATCH 09/13] revert

---
 pyiceberg/schema.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index ac6eab1a74..cfe3fe3a7b 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1673,7 +1673,10 @@ def _is_field_compatible(self, lhs: NestedField) -> bool:
             return True
         # We only check that the parent node is also of the same type.
         # We check the type of the child nodes when we traverse them later.
-        elif not lhs.is_primitive and not rhs.is_primitive:
+        elif any(
+            (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
+            for container_type in {StructType, MapType, ListType}
+        ):
             self.rich_table.add_row("✅", str(lhs), str(rhs))
             return True
         else:
             try:

From d4d80e323d1c43fb4694ad0397fb4dece7f145e4 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+syun64@users.noreply.github.com>
Date: Mon, 15 Jul 2024 21:13:41 +0000
Subject: [PATCH 10/13] add-files promotion test

---
 tests/conftest.py                            | 53 +++++++++++
 tests/integration/test_add_files.py          | 94 ++++++++++++++++++++
 tests/integration/test_writes/test_writes.py | 52 ++++-------
 3 files changed, 162 insertions(+), 37 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 91ab8f2e56..8c14be1f70 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2506,3 +2506,56 @@ def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
         NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
         NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
     )
+
+
+@pytest.fixture(scope="session")
+def table_schema_with_longs() -> Schema:
+    """Iceberg table Schema with longs in simple and nested types."""
+    return Schema(
+        NestedField(field_id=1, name="long", field_type=LongType(), required=False),
+        NestedField(
+            field_id=2,
+            name="list",
+            field_type=ListType(element_id=4, element_type=LongType(), element_required=False),
+            required=True,
+        ),
+        NestedField(
+            field_id=3,
+            name="map",
+            field_type=MapType(
+                key_id=5,
+                key_type=StringType(),
+                value_id=6,
+                value_type=LongType(),
+                value_required=False,
+            ),
+            required=True,
+        ),
+    )
+
+
+@pytest.fixture(scope="session")
+def pyarrow_schema_with_longs() -> "pa.Schema":
+    """Pyarrow Schema with longs in simple and nested types."""
+    import pyarrow as pa
+
+    return pa.schema((
+        pa.field("long", pa.int32(), nullable=True),  # can support upcasting integer to long
+        pa.field("list", pa.list_(pa.int32()), nullable=False),  # can support upcasting integer to long
+        pa.field("map", 
pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long + )) + + +@pytest.fixture(scope="session") +def pyarrow_table_with_longs(pyarrow_schema_with_longs: "pa.Schema") -> "pa.Table": + """Pyarrow table with longs in simple and nested types.""" + import pyarrow as pa + + return pa.Table.from_pydict( + { + "long": [1, 9], + "list": [[1, 1], [2, 2]], + "map": [{"a": 1}, {"b": 2}], + }, + schema=pyarrow_schema_with_longs, + ) diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 421cc11fec..be8eacff6e 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -38,6 +38,7 @@ BooleanType, DateType, IntegerType, + LongType, NestedField, StringType, TimestamptzType, @@ -617,3 +618,96 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v ), ): tbl.add_files(file_paths=[file_path]) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_table_write_schema_with_valid_nullability_diff( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.test_table_write_with_valid_nullability_diff{format_version}" + table_schema = Schema( + NestedField(field_id=1, name="long", field_type=LongType(), required=False), + ) + other_schema = pa.schema(( + pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field + )) + arrow_table = pa.Table.from_pydict( + { + "long": [1, 9], + }, + schema=other_schema, + ) + tbl = session_catalog.create_table( + identifier=identifier, + schema=table_schema, + properties={"format-version": str(format_version)}, + partition_spec=PartitionSpec(), + ) + + file_path = f"s3://warehouse/default/test_valid_nullability_diff/v{format_version}/test.parquet" + # write parquet files + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=other_schema) as writer: + writer.write_table(arrow_table) + + tbl.add_files(file_paths=[file_path]) + # table's long field should cast to be optional on read + written_arrow_table = tbl.scan().to_arrow() + assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),))) + lhs = spark.table(f"{identifier}").toPandas() + rhs = written_arrow_table.to_pandas() + + for column in written_arrow_table.column_names: + for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): + assert left == right + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_table_write_schema_with_valid_upcast( + spark: SparkSession, + session_catalog: Catalog, + format_version: int, + table_schema_with_longs: Schema, + pyarrow_schema_with_longs: pa.Schema, + pyarrow_table_with_longs: pa.Table, +) -> None: + identifier = f"default.test_table_write_with_valid_upcast{format_version}" + tbl = session_catalog.create_table( + identifier=identifier, + schema=table_schema_with_longs, + properties={"format-version": str(format_version)}, + partition_spec=PartitionSpec(), + ) + + file_path = f"s3://warehouse/default/test_valid_nullability_diff/v{format_version}/test.parquet" + # write parquet files + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=pyarrow_schema_with_longs) as writer: + writer.write_table(pyarrow_table_with_longs) + + tbl.add_files(file_paths=[file_path]) + # table's long field should cast to long on 
read + written_arrow_table = tbl.scan().to_arrow() + assert written_arrow_table == pyarrow_table_with_longs.cast( + pa.schema(( + pa.field("long", pa.int64(), nullable=True), + pa.field("list", pa.large_list(pa.int64()), nullable=False), + pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), + )) + ) + lhs = spark.table(f"{identifier}").toPandas() + rhs = written_arrow_table.to_pandas() + + for column in written_arrow_table.column_names: + for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): + if column == "map": + # Arrow returns a list of tuples, instead of a dict + right = dict(right) + if column == "list": + # Arrow returns an array, convert to list for equality check + left, right = list(left), list(right) + assert left == right diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 6565b763d7..ce6aaa350e 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -43,7 +43,7 @@ from pyiceberg.schema import Schema from pyiceberg.table import TableProperties from pyiceberg.transforms import IdentityTransform -from pyiceberg.types import IntegerType, ListType, LongType, MapType, NestedField, StringType +from pyiceberg.types import IntegerType, LongType, NestedField from utils import _create_table @@ -1027,47 +1027,25 @@ def test_table_write_schema_with_valid_nullability_diff( @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_table_write_schema_with_valid_upcast( - spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int + spark: SparkSession, + session_catalog: Catalog, + format_version: int, + table_schema_with_longs: Schema, + pyarrow_schema_with_longs: pa.Schema, + pyarrow_table_with_longs: pa.Table, ) -> None: identifier = "default.test_table_write_with_valid_upcast" - table_schema = Schema( - NestedField(field_id=1, name="long", field_type=LongType(), required=False), - NestedField( - field_id=2, - name="list", - field_type=ListType(element_id=4, element_type=LongType(), element_required=False), - required=True, - ), - NestedField( - field_id=3, - name="map", - field_type=MapType( - key_id=5, - key_type=StringType(), - value_id=6, - value_type=LongType(), - value_required=False, - ), - required=True, - ), - ) - other_schema = pa.schema(( - pa.field("long", pa.int32(), nullable=True), # can support upcasting integer to long - pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long - pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long - )) - arrow_table = pa.Table.from_pydict( - { - "long": [1, 9], - "list": [[1, 1], [2, 2]], - "map": [{"a": 1}, {"b": 2}], - }, - schema=other_schema, + + tbl = _create_table( + session_catalog, + identifier, + {"format-version": format_version}, + [pyarrow_table_with_longs], + schema=table_schema_with_longs, ) - tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema) # table's long field should cast to long on read written_arrow_table = tbl.scan().to_arrow() - assert written_arrow_table == arrow_table.cast( + assert written_arrow_table == pyarrow_table_with_longs.cast( pa.schema(( pa.field("long", pa.int64(), nullable=True), pa.field("list", pa.large_list(pa.int64()), nullable=False), From 865c4467424ffa1e2583da3cd1bbd49a78c5d2dd Mon Sep 17 00:00:00 2001 From: Sung Yun 
<107272191+syun64@users.noreply.github.com> Date: Tue, 16 Jul 2024 01:27:52 +0000 Subject: [PATCH 11/13] support promote for add_files --- pyiceberg/io/pyarrow.py | 13 ++++++--- tests/conftest.py | 3 +++ tests/integration/test_add_files.py | 28 +++----------------- tests/integration/test_writes/test_writes.py | 1 + 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 6cd1ea3b68..f0d124e7f7 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1552,9 +1552,16 @@ def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc expected_physical_type = _primitive_to_physical(iceberg_type) if expected_physical_type != physical_type_string: - raise ValueError( - f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}" - ) + # Allow promotable physical types + # INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts + if (physical_type_string == "INT32" and expected_physical_type == "INT64") or ( + physical_type_string == "FLOAT" and expected_physical_type == "DOUBLE" + ): + pass + else: + raise ValueError( + f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}" + ) self.primitive_type = iceberg_type diff --git a/tests/conftest.py b/tests/conftest.py index 8c14be1f70..6604573cb9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2531,6 +2531,7 @@ def table_schema_with_longs() -> Schema: ), required=True, ), + NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), ) @@ -2543,6 +2544,7 @@ def pyarrow_schema_with_longs() -> "pa.Schema": pa.field("long", pa.int32(), nullable=True), # can support upcasting integer to long pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long + pa.field("double", pa.float32(), nullable=True), # can support upcasting float to double )) @@ -2556,6 +2558,7 @@ def pyarrow_table_with_longs(pyarrow_schema_with_longs: "pa.Schema") -> "pa.Tabl "long": [1, 9], "list": [[1, 1], [2, 2]], "map": [{"a": 1}, {"b": 2}], + "double": [1.1, 9.2], }, schema=pyarrow_schema_with_longs, ) diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index be8eacff6e..46ec6793f9 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -590,18 +590,7 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"}) identifier = f"default.timestamptz_ns_added{format_version}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - - tbl = session_catalog.create_table( - identifier=identifier, - schema=nanoseconds_schema_iceberg, - properties={"format-version": str(format_version)}, - partition_spec=PartitionSpec(), - ) + tbl = _create_table(session_catalog, identifier, format_version, schema=nanoseconds_schema_iceberg) file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet" # write parquet files @@ -638,12 +627,7 @@ def test_table_write_schema_with_valid_nullability_diff( }, schema=other_schema, ) - tbl = session_catalog.create_table( - identifier=identifier, - schema=table_schema, - properties={"format-version": str(format_version)}, - partition_spec=PartitionSpec(), - ) + 
tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema) file_path = f"s3://warehouse/default/test_valid_nullability_diff/v{format_version}/test.parquet" # write parquet files @@ -675,12 +659,7 @@ def test_table_write_schema_with_valid_upcast( pyarrow_table_with_longs: pa.Table, ) -> None: identifier = f"default.test_table_write_with_valid_upcast{format_version}" - tbl = session_catalog.create_table( - identifier=identifier, - schema=table_schema_with_longs, - properties={"format-version": str(format_version)}, - partition_spec=PartitionSpec(), - ) + tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema_with_longs) file_path = f"s3://warehouse/default/test_valid_nullability_diff/v{format_version}/test.parquet" # write parquet files @@ -697,6 +676,7 @@ def test_table_write_schema_with_valid_upcast( pa.field("long", pa.int64(), nullable=True), pa.field("list", pa.large_list(pa.int64()), nullable=False), pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), + pa.field("double", pa.float64(), nullable=True), )) ) lhs = spark.table(f"{identifier}").toPandas() diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index ce6aaa350e..e343c709ba 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -1050,6 +1050,7 @@ def test_table_write_schema_with_valid_upcast( pa.field("long", pa.int64(), nullable=True), pa.field("list", pa.large_list(pa.int64()), nullable=False), pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), + pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double )) ) lhs = spark.table(f"{identifier}").toPandas() From 734047635ab89c12fdc615d00ebf8dcb32bda122 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Tue, 16 Jul 2024 02:07:23 +0000 Subject: [PATCH 12/13] add tests for uuid --- tests/conftest.py | 17 ++++++++++------- tests/integration/test_add_files.py | 19 ++++++++++++------- tests/integration/test_writes/test_writes.py | 17 +++++++++++------ 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6604573cb9..7f9a2bcfa8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2509,8 +2509,8 @@ def table_schema_with_all_microseconds_timestamp_precision() -> Schema: @pytest.fixture(scope="session") -def table_schema_with_longs() -> Schema: - """Iceberg table Schema with longs in simple and nested types.""" +def table_schema_with_promoted_types() -> Schema: + """Iceberg table Schema with longs, doubles and uuid in simple and nested types.""" return Schema( NestedField(field_id=1, name="long", field_type=LongType(), required=False), NestedField( @@ -2532,12 +2532,13 @@ def table_schema_with_longs() -> Schema: required=True, ), NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), + NestedField(field_id=8, name="uuid", field_type=UUIDType(), required=False), ) @pytest.fixture(scope="session") -def pyarrow_schema_with_longs() -> "pa.Schema": - """Pyarrow Schema with longs in simple and nested types.""" +def pyarrow_schema_with_promoted_types() -> "pa.Schema": + """Pyarrow Schema with longs, doubles and uuid in simple and nested types.""" import pyarrow as pa return pa.schema(( @@ -2545,12 +2546,13 @@ def pyarrow_schema_with_longs() -> "pa.Schema": pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting 
integer to long
         pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False),  # can support upcasting integer to long
         pa.field("double", pa.float32(), nullable=True),  # can support upcasting float to double
+        pa.field("uuid", pa.binary(length=16), nullable=True),  # can support writing fixed-length binary of 16 bytes as UUID
     ))


 @pytest.fixture(scope="session")
-def pyarrow_table_with_longs(pyarrow_schema_with_longs: "pa.Schema") -> "pa.Table":
-    """Pyarrow table with longs in simple and nested types."""
+def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types: "pa.Schema") -> "pa.Table":
+    """Pyarrow table with longs, doubles and uuid in simple and nested types."""
     import pyarrow as pa
 
     return pa.Table.from_pydict(
         {
             "long": [1, 9],
             "list": [[1, 1], [2, 2]],
             "map": [{"a": 1}, {"b": 2}],
             "double": [1.1, 9.2],
+            "uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
         },
-        schema=pyarrow_schema_with_longs,
+        schema=pyarrow_schema_with_promoted_types,
     )
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 46ec6793f9..43a1d455d0 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -654,29 +654,30 @@ def test_table_write_schema_with_valid_upcast(
     spark: SparkSession,
     session_catalog: Catalog,
     format_version: int,
-    table_schema_with_longs: Schema,
-    pyarrow_schema_with_longs: pa.Schema,
-    pyarrow_table_with_longs: pa.Table,
+    table_schema_with_promoted_types: Schema,
+    pyarrow_schema_with_promoted_types: pa.Schema,
+    pyarrow_table_with_promoted_types: pa.Table,
 ) -> None:
     identifier = f"default.test_table_write_with_valid_upcast{format_version}"
-    tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema_with_longs)
+    tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema_with_promoted_types)
 
     file_path = f"s3://warehouse/default/test_valid_nullability_diff/v{format_version}/test.parquet"
     # write parquet files
     fo = tbl.io.new_output(file_path)
     with fo.create(overwrite=True) as fos:
-        with pq.ParquetWriter(fos, schema=pyarrow_schema_with_longs) as writer:
-            writer.write_table(pyarrow_table_with_longs)
+        with pq.ParquetWriter(fos, schema=pyarrow_schema_with_promoted_types) as writer:
+            writer.write_table(pyarrow_table_with_promoted_types)
 
     tbl.add_files(file_paths=[file_path])
     # table's long field should cast to long on read
     written_arrow_table = tbl.scan().to_arrow()
-    assert written_arrow_table == pyarrow_table_with_longs.cast(
+    assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
         pa.schema((
             pa.field("long", pa.int64(), nullable=True),
             pa.field("list", pa.large_list(pa.int64()), nullable=False),
             pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False),
             pa.field("double", pa.float64(), nullable=True),
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # UUID is read as fixed-length binary of length 16
         ))
     )
     lhs = spark.table(f"{identifier}").toPandas()
@@ -690,4 +691,8 @@ def test_table_write_schema_with_valid_upcast(
             if column == "list":
                 # Arrow returns an array, convert to list for equality check
                 left, right = list(left), list(right)
+            if column == "uuid":
+                # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+                # whereas PyIceberg represents UUID as bytes on read
+                left, right = left.replace("-", ""), right.hex()
             assert left == right
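
The UUID equality check above hinges on the two readers surfacing different representations of the same value: Spark renders a dashed hex string, while a PyIceberg scan returns the raw 16 bytes of the fixed-length binary column. A minimal, self-contained sketch of the normalization the test uses (the sample value is the one quoted in the test comment):

    import uuid

    # Spark surfaces UUIDs as dashed hex strings; PyArrow scans return raw bytes.
    spark_value = "715a78ef-4e53-4089-9bf9-3ad0ee9bf545"
    arrow_value = uuid.UUID(spark_value).bytes  # the 16 bytes a scan yields

    # Strip the dashes on one side and hex-encode the other; both collapse to
    # the same 32-character lowercase hex string.
    assert spark_value.replace("-", "") == arrow_value.hex()
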
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index e343c709ba..09fe654d29 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -1030,9 +1030,9 @@ def test_table_write_schema_with_valid_upcast(
     spark: SparkSession,
     session_catalog: Catalog,
     format_version: int,
-    table_schema_with_longs: Schema,
-    pyarrow_schema_with_longs: pa.Schema,
-    pyarrow_table_with_longs: pa.Table,
+    table_schema_with_promoted_types: Schema,
+    pyarrow_schema_with_promoted_types: pa.Schema,
+    pyarrow_table_with_promoted_types: pa.Table,
 ) -> None:
     identifier = "default.test_table_write_with_valid_upcast"
 
@@ -1040,17 +1040,18 @@ def test_table_write_schema_with_valid_upcast(
         session_catalog,
         identifier,
         {"format-version": format_version},
-        [pyarrow_table_with_longs],
-        schema=table_schema_with_longs,
+        [pyarrow_table_with_promoted_types],
+        schema=table_schema_with_promoted_types,
     )
     # table's long field should cast to long on read
     written_arrow_table = tbl.scan().to_arrow()
-    assert written_arrow_table == pyarrow_table_with_longs.cast(
+    assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
         pa.schema((
             pa.field("long", pa.int64(), nullable=True),
             pa.field("list", pa.large_list(pa.int64()), nullable=False),
             pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False),
             pa.field("double", pa.float64(), nullable=True),  # can support upcasting float to double
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # UUID is read as fixed-length binary of length 16
         ))
     )
     lhs = spark.table(f"{identifier}").toPandas()
@@ -1064,6 +1065,10 @@ def test_table_write_schema_with_valid_upcast(
         if column == "list":
             # Arrow returns an array, convert to list for equality check
             left, right = list(left), list(right)
+        if column == "uuid":
+            # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+            # whereas PyIceberg represents UUID as bytes on read
+            left, right = left.replace("-", ""), right.hex()
         assert left == right


From 28e20d9efa28ea93f4143c0f1cf2206858ff9e87 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+syun64@users.noreply.github.com>
Date: Tue, 16 Jul 2024 15:55:21 +0000
Subject: [PATCH 13/13] add_files subset schema test

---
 pyiceberg/io/pyarrow.py             | 10 ------
 tests/integration/test_add_files.py | 52 ++++++++++++++++++++++++-----
 2 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index f0d124e7f7..cd6736fbba 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1906,16 +1906,6 @@ def data_file_statistics_from_parquet_metadata(
         set the mode for column metrics collection
         parquet_column_mapping (Dict[str, int]): The mapping of the parquet file name to the field ID
     """
-    if parquet_metadata.num_columns != len(stats_columns):
-        raise ValueError(
-            f"Number of columns in statistics configuration ({len(stats_columns)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
-        )
-
-    if parquet_metadata.num_columns != len(parquet_column_mapping):
-        raise ValueError(
-            f"Number of columns in column mapping ({len(parquet_column_mapping)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
-        )
-
     column_sizes: Dict[int, int] = {}
     value_counts: Dict[int, int] = {}
     split_offsets: List[int] = []
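
Dropping the column-count guards above is what enables the subset-schema test that follows: column statistics are keyed by field ID, so a Parquet file containing only some of the table's columns simply contributes no metrics for the absent fields instead of failing a length check. A rough sketch of that idea, with hypothetical names standing in for the real internals:

    # Hypothetical sketch: collect metrics only for the columns present in the
    # data file, rather than requiring file and table column counts to match.
    table_mapping = {"foo": 1, "bar": 2, "baz": 3, "qux": 4}  # column name -> field ID
    file_columns = {"bar", "baz", "qux"}  # columns physically present in this file

    metrics = {field_id: "stats..." for name, field_id in table_mapping.items() if name in file_columns}
    assert 1 not in metrics  # "foo" has no stats; its values are read back as null

diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 43a1d455d0..3703a9e0b6 100644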
--- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -30,6 +30,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError from pyiceberg.io import FileIO +from pyiceberg.io.pyarrow import _pyarrow_schema_ensure_large_types from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table @@ -611,10 +612,8 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) -def test_table_write_schema_with_valid_nullability_diff( - spark: SparkSession, session_catalog: Catalog, format_version: int -) -> None: - identifier = f"default.test_table_write_with_valid_nullability_diff{format_version}" +def test_add_file_with_valid_nullability_diff(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.test_table_with_valid_nullability_diff{format_version}" table_schema = Schema( NestedField(field_id=1, name="long", field_type=LongType(), required=False), ) @@ -629,7 +628,7 @@ def test_table_write_schema_with_valid_nullability_diff( ) tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema) - file_path = f"s3://warehouse/default/test_valid_nullability_diff/v{format_version}/test.parquet" + file_path = f"s3://warehouse/default/test_add_file_with_valid_nullability_diff/v{format_version}/test.parquet" # write parquet files fo = tbl.io.new_output(file_path) with fo.create(overwrite=True) as fos: @@ -650,7 +649,7 @@ def test_table_write_schema_with_valid_nullability_diff( @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) -def test_table_write_schema_with_valid_upcast( +def test_add_files_with_valid_upcast( spark: SparkSession, session_catalog: Catalog, format_version: int, @@ -658,10 +657,10 @@ def test_table_write_schema_with_valid_upcast( pyarrow_schema_with_promoted_types: pa.Schema, pyarrow_table_with_promoted_types: pa.Table, ) -> None: - identifier = f"default.test_table_write_with_valid_upcast{format_version}" + identifier = f"default.test_table_with_valid_upcast{format_version}" tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema_with_promoted_types) - file_path = f"s3://warehouse/default/test_valid_nullability_diff/v{format_version}/test.parquet" + file_path = f"s3://warehouse/default/test_add_files_with_valid_upcast/v{format_version}/test.parquet" # write parquet files fo = tbl.io.new_output(file_path) with fo.create(overwrite=True) as fos: @@ -696,3 +695,40 @@ def test_table_write_schema_with_valid_upcast( # whereas PyIceberg represents UUID as bytes on read left, right = left.replace("-", ""), right.hex() assert left == right + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_subset_of_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.test_table_subset_of_schema{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_path = f"s3://warehouse/default/test_add_files_subset_of_schema/v{format_version}/test.parquet" + arrow_table_without_some_columns = ARROW_TABLE.combine_chunks().drop(ARROW_TABLE.column_names[0]) + + # write parquet files + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, 
schema=arrow_table_without_some_columns.schema) as writer:
+            writer.write_table(arrow_table_without_some_columns)
+
+    tbl.add_files(file_paths=[file_path])
+    written_arrow_table = tbl.scan().to_arrow()
+    assert tbl.scan().to_arrow() == pa.Table.from_pylist(
+        [
+            {
+                "foo": None,  # Missing column is read back as None
+                "bar": "bar_string",
+                "baz": 123,
+                "qux": date(2024, 3, 7),
+            }
+        ],
+        schema=_pyarrow_schema_ensure_large_types(ARROW_SCHEMA),
+    )
+
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right
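
The final assertion encodes the new contract of add_files: a data file may cover only a subset of the table schema, and the column missing from the file is materialized as nulls at read time. A self-contained PyArrow sketch of that null backfill (illustrative names only, not the PyIceberg internals):

    import pyarrow as pa

    # A data file written without the "foo" column:
    file_table = pa.table({"bar": ["bar_string"], "baz": [123]})

    # The full table schema, using the large types PyIceberg prefers on read:
    full_schema = pa.schema([
        pa.field("foo", pa.large_string()),  # absent from the file
        pa.field("bar", pa.large_string()),
        pa.field("baz", pa.int32()),
    ])

    # Project the file onto the full schema: present columns are cast,
    # the absent column is backfilled with nulls.
    arrays = [
        file_table[name].cast(field.type)
        if name in file_table.column_names
        else pa.nulls(len(file_table), type=field.type)
        for name, field in zip(full_schema.names, full_schema)
    ]
    projected = pa.table(arrays, schema=full_schema)
    assert projected["foo"].null_count == len(file_table)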