diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index dd707cea14..408126d3b3 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -21,7 +21,7 @@
 from dataclasses import dataclass
 from datetime import date, datetime, time
 from functools import cached_property, singledispatch
-from typing import Annotated, Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
+from typing import Annotated, Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar, Union
 from urllib.parse import quote_plus
 
 from pydantic import (
@@ -249,6 +249,31 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
 UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)
 
 
+def validate_partition_name(
+    field_name: str,
+    partition_transform: Transform[Any, Any],
+    source_id: int,
+    schema: Schema,
+    partition_names: Set[str],
+) -> None:
+    """Validate that a partition field name doesn't conflict with schema field names."""
+    if not field_name:
+        raise ValueError("Undefined name")
+    if field_name in partition_names:
+        raise ValueError(f"Partition name has to be unique: {field_name}")
+    try:
+        schema_field = schema.find_field(field_name)
+    except ValueError:
+        return  # No conflict if field doesn't exist in schema
+
+    if isinstance(partition_transform, (IdentityTransform, VoidTransform)):
+        # For identity and void transforms, allow conflict only if sourced from the same schema field
+        if schema_field.field_id != source_id:
+            raise ValueError(f"Cannot create identity partition sourced from different field in schema: {field_name}")
+    else:
+        raise ValueError(f"Cannot create partition with a name that exists in schema: {field_name}")
+
+
 def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec:
     partition_fields = []
     for pos, field in enumerate(spec.fields):
@@ -258,6 +283,9 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
         fresh_field = 
fresh_schema.find_field(original_column_name) if fresh_field is None: raise ValueError(f"Could not find field in fresh schema: {original_column_name}") + + validate_partition_name(field.name, field.transform, fresh_field.field_id, fresh_schema, set()) + partition_fields.append( PartitionField( name=field.name, diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py index 6ad01e97f2..44feaa112c 100644 --- a/pyiceberg/table/update/schema.py +++ b/pyiceberg/table/update/schema.py @@ -658,6 +658,13 @@ def _apply(self) -> Schema: # Check the field-ids new_schema = Schema(*struct.fields) + from pyiceberg.partitioning import validate_partition_name + + for spec in self._transaction.table_metadata.partition_specs: + for partition_field in spec.fields: + validate_partition_name( + partition_field.name, partition_field.transform, partition_field.source_id, new_schema, set() + ) field_ids = set() for name in self._identifier_field_names: try: diff --git a/pyiceberg/table/update/spec.py b/pyiceberg/table/update/spec.py index 1f91aa5d17..2a3d54969d 100644 --- a/pyiceberg/table/update/spec.py +++ b/pyiceberg/table/update/spec.py @@ -174,26 +174,18 @@ def _commit(self) -> UpdatesAndRequirements: return updates, requirements def _apply(self) -> PartitionSpec: - def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None: - try: - field = schema.find_field(name) - except ValueError: - field = None - - if source_id is not None and field is not None and field.field_id != source_id: - raise ValueError(f"Cannot create identity partition from a different field in the schema {name}") - elif field is not None and source_id != field.field_id: - raise ValueError(f"Cannot create partition from name that exists in schema {name}") - if not name: - raise ValueError("Undefined name") - if name in partition_names: - raise ValueError(f"Partition name has to be unique: {name}") + def _check_and_add_partition_name( + 
schema: Schema, name: str, source_id: int, transform: Transform[Any, Any], partition_names: Set[str] + ) -> None: + from pyiceberg.partitioning import validate_partition_name + + validate_partition_name(name, transform, source_id, schema, partition_names) partition_names.add(name) def _add_new_field( schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str] ) -> PartitionField: - _check_and_add_partition_name(schema, name, source_id, partition_names) + _check_and_add_partition_name(schema, name, source_id, transform, partition_names) return PartitionField(source_id, field_id, transform, name) partition_fields = [] @@ -244,6 +236,13 @@ def _add_new_field( partition_fields.append(new_field) for added_field in self._adds: + _check_and_add_partition_name( + self._transaction.table_metadata.schema(), + added_field.name, + added_field.source_id, + added_field.transform, + partition_names, + ) new_field = PartitionField( source_id=added_field.source_id, field_id=added_field.field_id, diff --git a/tests/integration/test_partition_evolution.py b/tests/integration/test_partition_evolution.py index d489d6a5d0..2444e18737 100644 --- a/tests/integration/test_partition_evolution.py +++ b/tests/integration/test_partition_evolution.py @@ -20,7 +20,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table from pyiceberg.transforms import ( @@ -63,13 +63,18 @@ def _table_v2(catalog: Catalog) -> Table: return _create_table_with_schema(catalog, schema_with_timestamp, "2") -def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table: +def _create_table_with_schema( + catalog: Catalog, schema: Schema, format_version: str, 
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC +) -> Table: tbl_name = "default.test_schema_evolution" try: catalog.drop_table(tbl_name) except NoSuchTableError: pass - return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version}) + + return catalog.create_table( + identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version} + ) @pytest.mark.integration @@ -564,3 +569,80 @@ def _validate_new_partition_fields( assert len(spec.fields) == len(expected_partition_fields) for i in range(len(spec.fields)): assert spec.fields[i] == expected_partition_fields[i] + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_partition_schema_field_name_conflict(catalog: Catalog) -> None: + schema = Schema( + NestedField(1, "id", LongType(), required=False), + NestedField(2, "event_ts", TimestampType(), required=False), + NestedField(3, "another_ts", TimestampType(), required=False), + NestedField(4, "str", StringType(), required=False), + ) + table = _create_table_with_schema(catalog, schema, "2") + + with pytest.raises(ValueError, match="Cannot create partition with a name that exists in schema: another_ts"): + table.update_spec().add_field("event_ts", YearTransform(), "another_ts").commit() + with pytest.raises(ValueError, match="Cannot create partition with a name that exists in schema: id"): + table.update_spec().add_field("event_ts", DayTransform(), "id").commit() + + with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: another_ts"): + table.update_spec().add_field("event_ts", IdentityTransform(), "another_ts").commit() + with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: str"): + table.update_spec().add_field("id", IdentityTransform(), 
"str").commit() + + table.update_spec().add_field("id", IdentityTransform(), "id").commit() + table.update_spec().add_field("event_ts", YearTransform(), "event_year").commit() + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_partition_validation_during_table_creation(catalog: Catalog) -> None: + schema = Schema( + NestedField(1, "id", LongType(), required=False), + NestedField(2, "event_ts", TimestampType(), required=False), + NestedField(3, "another_ts", TimestampType(), required=False), + NestedField(4, "str", StringType(), required=False), + ) + + partition_spec = PartitionSpec( + PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="another_ts"), spec_id=1 + ) + with pytest.raises(ValueError, match="Cannot create partition with a name that exists in schema: another_ts"): + _create_table_with_schema(catalog, schema, "2", partition_spec) + + partition_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=1 + ) + _create_table_with_schema(catalog, schema, "2", partition_spec) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_schema_evolution_partition_conflict(catalog: Catalog) -> None: + schema = Schema( + NestedField(1, "id", LongType(), required=False), + NestedField(2, "event_ts", TimestampType(), required=False), + ) + partition_spec = PartitionSpec( + PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="event_year"), + PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="first_name"), + PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id"), + spec_id=1, + ) + table = _create_table_with_schema(catalog, schema, "2", partition_spec) + + with pytest.raises(ValueError, 
match="Cannot create partition with a name that exists in schema: event_year"): + table.update_schema().add_column("event_year", StringType()).commit() + with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: first_name"): + table.update_schema().add_column("first_name", StringType()).commit() + + table.update_schema().add_column("other_field", StringType()).commit() + + with pytest.raises(ValueError, match="Cannot create partition with a name that exists in schema: event_year"): + table.update_schema().rename_column("other_field", "event_year").commit() + with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: first_name"): + table.update_schema().rename_column("other_field", "first_name").commit() + + table.update_schema().rename_column("other_field", "valid_name").commit() diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index e9698067c1..4b6c6a4d7b 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -980,8 +980,16 @@ def test_append_ymd_transform_partitioned( # Given identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}" nested_field = TABLE_SCHEMA.find_field(part_col) + + if isinstance(transform, YearTransform): + partition_name = f"{part_col}_year" + elif isinstance(transform, MonthTransform): + partition_name = f"{part_col}_month" + elif isinstance(transform, DayTransform): + partition_name = f"{part_col}_day" + partition_spec = PartitionSpec( - PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col) + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name) ) # When @@ -1037,8 +1045,18 @@ def test_append_transform_partition_verify_partitions_count( 
part_col = "timestamptz" identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}" nested_field = table_date_timestamps_schema.find_field(part_col) + + if isinstance(transform, YearTransform): + partition_name = f"{part_col}_year" + elif isinstance(transform, MonthTransform): + partition_name = f"{part_col}_month" + elif isinstance(transform, DayTransform): + partition_name = f"{part_col}_day" + elif isinstance(transform, HourTransform): + partition_name = f"{part_col}_hour" + partition_spec = PartitionSpec( - PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col), + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name), ) # When @@ -1061,7 +1079,7 @@ def test_append_transform_partition_verify_partitions_count( partitions_table = tbl.inspect.partitions() assert partitions_table.num_rows == len(expected_partitions) - assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions + assert {part[partition_name] for part in partitions_table["partition"].to_pylist()} == expected_partitions files_df = spark.sql( f""" SELECT * diff --git a/tests/test_schema.py b/tests/test_schema.py index 3ca74c4027..5e1052f8be 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -32,6 +32,7 @@ prune_columns, sanitize_column_names, ) +from pyiceberg.table import Table, Transaction from pyiceberg.table.update.schema import UpdateSchema from pyiceberg.typedef import EMPTY_DICT, StructProtocol from pyiceberg.types import ( @@ -927,14 +928,14 @@ def primitive_fields() -> List[NestedField]: ] -def test_add_top_level_primitives(primitive_fields: List[NestedField]) -> None: +def test_add_top_level_primitives(primitive_fields: List[NestedField], table_v2: Table) -> None: for primitive_field in primitive_fields: new_schema = Schema(primitive_field) - applied = UpdateSchema(transaction=None, 
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(new_schema)._apply() assert applied == new_schema -def test_add_top_level_list_of_primitives(primitive_fields: NestedField) -> None: +def test_add_top_level_list_of_primitives(primitive_fields: NestedField, table_v2: Table) -> None: for primitive_type in TEST_PRIMITIVE_TYPES: new_schema = Schema( NestedField( @@ -944,11 +945,11 @@ def test_add_top_level_list_of_primitives(primitive_fields: NestedField) -> None required=False, ) ) - applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() -def test_add_top_level_map_of_primitives(primitive_fields: NestedField) -> None: +def test_add_top_level_map_of_primitives(primitive_fields: NestedField, table_v2: Table) -> None: for primitive_type in TEST_PRIMITIVE_TYPES: new_schema = Schema( NestedField( @@ -960,11 +961,11 @@ def test_add_top_level_map_of_primitives(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() -def test_add_top_struct_of_primitives(primitive_fields: NestedField) -> None: +def test_add_top_struct_of_primitives(primitive_fields: NestedField, table_v2: Table) -> None: for primitive_type in TEST_PRIMITIVE_TYPES: new_schema = Schema( NestedField( @@ -974,11 +975,11 @@ def test_add_top_struct_of_primitives(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: 
ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() -def test_add_nested_primitive(primitive_fields: NestedField) -> None: +def test_add_nested_primitive(primitive_fields: NestedField, table_v2: Table) -> None: for primitive_type in TEST_PRIMITIVE_TYPES: current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False)) new_schema = Schema( @@ -989,7 +990,7 @@ def test_add_nested_primitive(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() @@ -1002,18 +1003,18 @@ def _primitive_fields(types: List[PrimitiveType], start_id: int = 0) -> List[Nes return fields -def test_add_nested_primitives(primitive_fields: NestedField) -> None: +def test_add_nested_primitives(primitive_fields: NestedField, table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False)) new_schema = Schema( NestedField( field_id=1, name="aStruct", field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)), required=False ) ) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() -def test_add_nested_lists(primitive_fields: NestedField) -> None: +def test_add_nested_lists(primitive_fields: NestedField, table_v2: Table) -> None: new_schema = Schema( NestedField( field_id=1, @@ -1050,11 +1051,11 @@ def test_add_nested_lists(primitive_fields: NestedField) -> 
None: required=False, ) ) - applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() -def test_add_nested_struct(primitive_fields: NestedField) -> None: +def test_add_nested_struct(primitive_fields: NestedField, table_v2: Table) -> None: new_schema = Schema( NestedField( field_id=1, @@ -1100,11 +1101,11 @@ def test_add_nested_struct(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() -def test_add_nested_maps(primitive_fields: NestedField) -> None: +def test_add_nested_maps(primitive_fields: NestedField, table_v2: Table) -> None: new_schema = Schema( NestedField( field_id=1, @@ -1143,11 +1144,11 @@ def test_add_nested_maps(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() -def test_detect_invalid_top_level_list() -> None: +def test_detect_invalid_top_level_list(table_v2: Table) -> None: current_schema = Schema( NestedField( field_id=1, @@ -1166,10 +1167,10 @@ def test_detect_invalid_top_level_list() -> None: ) with pytest.raises(ValidationError, match="Cannot change column type: aList.element: string -> double"): - _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + _ = UpdateSchema(transaction=Transaction(table_v2), 
schema=current_schema).union_by_name(new_schema)._apply() -def test_detect_invalid_top_level_maps() -> None: +def test_detect_invalid_top_level_maps(table_v2: Table) -> None: current_schema = Schema( NestedField( field_id=1, @@ -1188,68 +1189,68 @@ def test_detect_invalid_top_level_maps() -> None: ) with pytest.raises(ValidationError, match="Cannot change column type: aMap.key: string -> uuid"): - _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + _ = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() -def test_allow_double_to_float() -> None: +def test_allow_double_to_float(table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False)) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() assert applied.as_struct() == current_schema.as_struct() assert len(applied.fields) == 1 assert isinstance(applied.fields[0].field_type, DoubleType) -def test_promote_float_to_double() -> None: +def test_promote_float_to_double(table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False)) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() assert len(applied.fields) == 1 assert isinstance(applied.fields[0].field_type, DoubleType) -def 
test_allow_long_to_int() -> None: +def test_allow_long_to_int(table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=LongType(), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=IntegerType(), required=False)) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() assert applied.as_struct() == current_schema.as_struct() assert len(applied.fields) == 1 assert isinstance(applied.fields[0].field_type, LongType) -def test_promote_int_to_long() -> None: +def test_promote_int_to_long(table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=IntegerType(), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=LongType(), required=False)) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() assert len(applied.fields) == 1 assert isinstance(applied.fields[0].field_type, LongType) -def test_detect_invalid_promotion_string_to_float() -> None: +def test_detect_invalid_promotion_string_to_float(table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=StringType(), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False)) with pytest.raises(ValidationError, match="Cannot change column type: aCol: string -> float"): - _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + _ = UpdateSchema(transaction=Transaction(table_v2), 
schema=current_schema).union_by_name(new_schema)._apply() # decimal(P,S) Fixed-point decimal; precision P, scale S -> Scale is fixed [1], # precision must be 38 or less -def test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None: +def test_type_promote_decimal_to_fixed_scale_with_wider_precision(table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=20, scale=1), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=22, scale=1), required=False)) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() assert applied.as_struct() == new_schema.as_struct() assert len(applied.fields) == 1 @@ -1260,7 +1261,7 @@ def test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None: assert decimal_type.scale == 1 -def test_add_nested_structs(primitive_fields: NestedField) -> None: +def test_add_nested_structs(primitive_fields: NestedField, table_v2: Table) -> None: schema = Schema( NestedField( field_id=1, @@ -1317,7 +1318,7 @@ def test_add_nested_structs(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(transaction=None, schema=schema).union_by_name(new_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=schema).union_by_name(new_schema)._apply() expected = Schema( NestedField( @@ -1352,15 +1353,15 @@ def test_add_nested_structs(primitive_fields: NestedField) -> None: assert applied.as_struct() == expected.as_struct() -def test_replace_list_with_primitive() -> None: +def test_replace_list_with_primitive(table_v2: Table) -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=ListType(element_id=2, element_type=StringType()))) new_schema = 
Schema(NestedField(field_id=1, name="aCol", field_type=StringType())) with pytest.raises(ValidationError, match="Cannot change column type: list is not a primitive"): - _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore + _ = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(new_schema)._apply() -def test_mirrored_schemas() -> None: +def test_mirrored_schemas(table_v2: Table) -> None: current_schema = Schema( NestedField(9, "struct1", StructType(NestedField(8, "string1", StringType(), required=False)), required=False), NestedField(6, "list1", ListType(element_id=7, element_type=StringType(), element_required=False), required=False), @@ -1380,12 +1381,12 @@ def test_mirrored_schemas() -> None: NestedField(9, "string6", StringType(), required=False), ) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(mirrored_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(mirrored_schema)._apply() assert applied.as_struct() == current_schema.as_struct() -def test_add_new_top_level_struct() -> None: +def test_add_new_top_level_struct(table_v2: Table) -> None: current_schema = Schema( NestedField( 1, @@ -1432,12 +1433,12 @@ def test_add_new_top_level_struct() -> None: ), ) - applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(observed_schema)._apply() assert applied.as_struct() == observed_schema.as_struct() -def test_append_nested_struct() -> None: +def test_append_nested_struct(table_v2: Table) -> None: current_schema = Schema( NestedField( field_id=1, @@ -1511,12 +1512,12 @@ def test_append_nested_struct() -> None: ) ) - applied = UpdateSchema(transaction=None, 
schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore + applied = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(observed_schema)._apply() assert applied.as_struct() == observed_schema.as_struct() -def test_append_nested_lists() -> None: +def test_append_nested_lists(table_v2: Table) -> None: current_schema = Schema( NestedField( field_id=1, @@ -1576,7 +1577,7 @@ def test_append_nested_lists() -> None: required=False, ) ) - union = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore + union = UpdateSchema(transaction=Transaction(table_v2), schema=current_schema).union_by_name(observed_schema)._apply() expected = Schema( NestedField( @@ -1617,7 +1618,7 @@ def test_append_nested_lists() -> None: assert union.as_struct() == expected.as_struct() -def test_union_with_pa_schema(primitive_fields: NestedField) -> None: +def test_union_with_pa_schema(primitive_fields: NestedField, table_v2: Table) -> None: base_schema = Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True)) pa_schema = pa.schema( @@ -1628,7 +1629,7 @@ def test_union_with_pa_schema(primitive_fields: NestedField) -> None: ] ) - new_schema = UpdateSchema(transaction=None, schema=base_schema).union_by_name(pa_schema)._apply() # type: ignore + new_schema = UpdateSchema(transaction=Transaction(table_v2), schema=base_schema).union_by_name(pa_schema)._apply() expected_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=True),