
Commit 5a920cd

Arrow: Remove check for supported Arrow transforms (#2340)
All the transforms support Arrow, so we can remove the check; less code is more!
1 parent cf987c6 commit 5a920cd
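
Below is a minimal sketch (not part of the diff) of the reasoning in the commit message: after this change every Transform returns a usable Arrow callable from pyarrow_transform, including VoidTransform, which previously raised NotImplementedError, so there is nothing left for a capability check to catch. The array values are made up; IdentityTransform, VoidTransform, and IntegerType are the library's own names.

import pyarrow as pa

from pyiceberg.transforms import IdentityTransform, VoidTransform
from pyiceberg.types import IntegerType

arr = pa.array([1, 2, 3], type=pa.int32())

# IdentityTransform passes values through unchanged (return lambda v: v in the diff).
print(IdentityTransform().pyarrow_transform(IntegerType())(arr))  # [1, 2, 3]

# VoidTransform now produces an all-null array of the same length and type
# instead of raising NotImplementedError.
print(VoidTransform().pyarrow_transform(IntegerType())(arr))  # [null, null, null]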

3 files changed: +15 -34

pyiceberg/table/__init__.py

Lines changed: 3 additions & 13 deletions

@@ -431,7 +431,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
             name_mapping=self.table_metadata.name_mapping(),
         )

-    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
+    def update_snapshot(
+        self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
+    ) -> UpdateSnapshot:
         """Create a new UpdateSnapshot to produce a new snapshot for the table.

         Returns:
@@ -470,12 +472,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")

-        if unsupported_partitions := [
-            field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
-        ]:
-            raise ValueError(
-                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
-            )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         _check_pyarrow_schema_compatible(
             self.table_metadata.schema(),
@@ -592,12 +588,6 @@ def overwrite(
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")

-        if unsupported_partitions := [
-            field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
-        ]:
-            raise ValueError(
-                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
-            )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         _check_pyarrow_schema_compatible(
             self.table_metadata.schema(),

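With the per-field check removed from append and overwrite above, both methods go straight from the PyArrow-table type check to _check_pyarrow_schema_compatible, whatever transforms the partition spec uses. A hedged usage sketch; the catalog name and table identifier are placeholders for a table that is assumed to exist with a matching schema:

import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Placeholder catalog and table; the spec may include any transform, e.g. VoidTransform.
catalog = load_catalog("default")
tbl = catalog.load_table("examples.partitioned_events")

df = pa.table({"id": pa.array([1, 2, 3], type=pa.int32())})

# Previously this raised ValueError when a spec field's transform lacked
# supports_pyarrow_transform; now it proceeds to the schema compatibility check.
tbl.append(df)
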
pyiceberg/transforms.py

Lines changed: 6 additions & 21 deletions

@@ -212,10 +212,6 @@ def __eq__(self, other: Any) -> bool:
             return self.root == other.root
         return False

-    @property
-    def supports_pyarrow_transform(self) -> bool:
-        return False
-
     @abstractmethod
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...

@@ -399,10 +395,6 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr
         pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
         return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

-    @property
-    def supports_pyarrow_transform(self) -> bool:
-        return True
-

 class TimeResolution(IntEnum):
     YEAR = 6
@@ -462,10 +454,6 @@ def dedup_name(self) -> str:
     def preserves_order(self) -> bool:
         return True

-    @property
-    def supports_pyarrow_transform(self) -> bool:
-        return True
-

 class YearTransform(TimeTransform[S]):
     """Transforms a datetime value into a year value.
@@ -781,10 +769,6 @@ def __repr__(self) -> str:
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
         return lambda v: v

-    @property
-    def supports_pyarrow_transform(self) -> bool:
-        return True
-

 class TruncateTransform(Transform[S, S]):
     """A transform for truncating a value to a specified width.
@@ -931,10 +915,6 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr

         return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

-    @property
-    def supports_pyarrow_transform(self) -> bool:
-        return True
-

 @singledispatch
 def _human_string(value: Any, _type: IcebergType) -> str:
@@ -1049,7 +1029,12 @@ def __repr__(self) -> str:
         return "VoidTransform()"

     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For partition transforms, PyArrow needs to be installed") from e
+
+        return lambda arr: pa.nulls(len(arr), type=arr.type)


 def _truncate_number(

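The bucket and truncate transforms keep their pyiceberg-core-backed pyarrow_transform implementations; only the supports_pyarrow_transform flag is gone. A small sketch, assuming the optional pyiceberg-core extra is installed; the sample arrays are illustrative:

import pyarrow as pa

from pyiceberg.transforms import BucketTransform, TruncateTransform
from pyiceberg.types import IntegerType, StringType

ints = pa.array([1, 2, 3], type=pa.int32())
words = pa.array(["iceberg", "arrow"], type=pa.string())

# Hash-bucket ids in the range [0, 8), computed via the pyiceberg_core bindings.
print(BucketTransform(num_buckets=8).pyarrow_transform(IntegerType())(ints))

# Strings truncated to a width of 3 characters: ["ice", "arr"].
print(TruncateTransform(width=3).pyarrow_transform(StringType())(words))
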
tests/test_transforms.py

Lines changed: 6 additions & 0 deletions

@@ -1655,6 +1655,12 @@ def test_bucket_pyarrow_transforms(
     assert expected == transform.pyarrow_transform(source_type)(input_arr)


+def test_bucket_pyarrow_void_transform() -> None:
+    input_arr = pa.chunked_array([pa.array([1, 2], type=pa.int32()), pa.array([3, 4], type=pa.int32())])
+    output_arr = pa.array([None, None, None, None], type=pa.int32())
+    assert output_arr == VoidTransform().pyarrow_transform(IntegerType())(input_arr)
+
+
 @pytest.mark.parametrize(
     "source_type, input_arr, expected, width",
     [
