From 80e379da02de2557026083dd5bc12c5692a8438f Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 18 Aug 2025 10:15:09 +0200 Subject: [PATCH] Arrow: Remove check for supported Arrow transforms --- pyiceberg/table/__init__.py | 16 +++------------- pyiceberg/transforms.py | 27 ++++++--------------------- tests/test_transforms.py | 6 ++++++ 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 92bbd60358..5c82daa525 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -431,7 +431,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive name_mapping=self.table_metadata.name_mapping(), ) - def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot: + def update_snapshot( + self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH + ) -> UpdateSnapshot: """Create a new UpdateSnapshot to produce a new snapshot for the table. Returns: @@ -470,12 +472,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - if unsupported_partitions := [ - field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform - ]: - raise ValueError( - f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." - ) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( self.table_metadata.schema(), @@ -592,12 +588,6 @@ def overwrite( if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - if unsupported_partitions := [ - field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform - ]: - raise ValueError( - f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." - ) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( self.table_metadata.schema(), diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 3f5a8d8998..30b3929329 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -212,10 +212,6 @@ def __eq__(self, other: Any) -> bool: return self.root == other.root return False - @property - def supports_pyarrow_transform(self) -> bool: - return False - @abstractmethod def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... @@ -399,10 +395,6 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets) - @property - def supports_pyarrow_transform(self) -> bool: - return True - class TimeResolution(IntEnum): YEAR = 6 @@ -462,10 +454,6 @@ def dedup_name(self) -> str: def preserves_order(self) -> bool: return True - @property - def supports_pyarrow_transform(self) -> bool: - return True - class YearTransform(TimeTransform[S]): """Transforms a datetime value into a year value. @@ -781,10 +769,6 @@ def __repr__(self) -> str: def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": return lambda v: v - @property - def supports_pyarrow_transform(self) -> bool: - return True - class TruncateTransform(Transform[S, S]): """A transform for truncating a value to a specified width. @@ -931,10 +915,6 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width) - @property - def supports_pyarrow_transform(self) -> bool: - return True - @singledispatch def _human_string(value: Any, _type: IcebergType) -> str: @@ -1049,7 +1029,12 @@ def __repr__(self) -> str: return "VoidTransform()" def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": - raise NotImplementedError() + try: + import pyarrow as pa + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For partition transforms, PyArrow needs to be installed") from e + + return lambda arr: pa.nulls(len(arr), type=arr.type) def _truncate_number( diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 7a7d4a6d8e..deaf5d52b6 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1655,6 +1655,12 @@ def test_bucket_pyarrow_transforms( assert expected == transform.pyarrow_transform(source_type)(input_arr) +def test_bucket_pyarrow_void_transform() -> None: + input_arr = pa.chunked_array([pa.array([1, 2], type=pa.int32()), pa.array([3, 4], type=pa.int32())]) + output_arr = pa.array([None, None, None, None], type=pa.int32()) + assert output_arr == VoidTransform().pyarrow_transform(IntegerType())(input_arr) + + @pytest.mark.parametrize( "source_type, input_arr, expected, width", [