Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
name_mapping=self.table_metadata.name_mapping(),
)

def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
def update_snapshot(
self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
Returns:
Expand Down Expand Up @@ -470,12 +472,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if unsupported_partitions := [
field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
]:
raise ValueError(
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.table_metadata.schema(),
Expand Down Expand Up @@ -592,12 +588,6 @@ def overwrite(
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if unsupported_partitions := [
field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
]:
raise ValueError(
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.table_metadata.schema(),
Expand Down
27 changes: 6 additions & 21 deletions pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,6 @@ def __eq__(self, other: Any) -> bool:
return self.root == other.root
return False

@property
def supports_pyarrow_transform(self) -> bool:
return False

@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...

Expand Down Expand Up @@ -399,10 +395,6 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

@property
def supports_pyarrow_transform(self) -> bool:
return True


class TimeResolution(IntEnum):
YEAR = 6
Expand Down Expand Up @@ -462,10 +454,6 @@ def dedup_name(self) -> str:
def preserves_order(self) -> bool:
return True

@property
def supports_pyarrow_transform(self) -> bool:
return True


class YearTransform(TimeTransform[S]):
"""Transforms a datetime value into a year value.
Expand Down Expand Up @@ -781,10 +769,6 @@ def __repr__(self) -> str:
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
return lambda v: v

@property
def supports_pyarrow_transform(self) -> bool:
return True


class TruncateTransform(Transform[S, S]):
"""A transform for truncating a value to a specified width.
Expand Down Expand Up @@ -931,10 +915,6 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr

return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

@property
def supports_pyarrow_transform(self) -> bool:
return True


@singledispatch
def _human_string(value: Any, _type: IcebergType) -> str:
Expand Down Expand Up @@ -1049,7 +1029,12 @@ def __repr__(self) -> str:
return "VoidTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a PyArrow-level implementation of the void transform.

    The void transform discards its input: the returned callable maps any
    array to an all-null array of the same length and type, so partition
    values produced by this transform are always null.

    Args:
        source: The Iceberg source type (ignored — the result is null
            regardless of the input values).

    Returns:
        A callable taking a ``pa.Array`` (or ``pa.ChunkedArray``) and
        returning an all-null ``pa.Array`` of matching length and type.

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
    """
    # Imported lazily so that pyiceberg stays usable without the
    # optional pyarrow dependency; only transform *use* requires it.
    try:
        import pyarrow as pa
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For partition transforms, PyArrow needs to be installed") from e

    # len(arr) and arr.type work for both Array and ChunkedArray inputs.
    return lambda arr: pa.nulls(len(arr), type=arr.type)


def _truncate_number(
Expand Down
6 changes: 6 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,12 @@ def test_bucket_pyarrow_transforms(
assert expected == transform.pyarrow_transform(source_type)(input_arr)


def test_bucket_pyarrow_void_transform() -> None:
    """VoidTransform must map every element of a chunked array to null, preserving length and type."""
    chunks = [pa.array([1, 2], type=pa.int32()), pa.array([3, 4], type=pa.int32())]
    source_arr = pa.chunked_array(chunks)
    expected = pa.array([None] * 4, type=pa.int32())
    transform_fn = VoidTransform().pyarrow_transform(IntegerType())
    assert expected == transform_fn(source_arr)


@pytest.mark.parametrize(
"source_type, input_arr, expected, width",
[
Expand Down