From 521596fcfccf64fdc935937ed8f4aa88bb294b65 Mon Sep 17 00:00:00 2001 From: guptaakashdeep Date: Sat, 19 Apr 2025 18:12:21 +0530 Subject: [PATCH] Closes #1882 Changes to support string transform in add_field. --- pyiceberg/table/update/spec.py | 16 ++----- pyiceberg/transforms.py | 46 +++++++++---------- tests/integration/test_partition_evolution.py | 8 ++++ 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/pyiceberg/table/update/spec.py b/pyiceberg/table/update/spec.py index b732b2116e..1f91aa5d17 100644 --- a/pyiceberg/table/update/spec.py +++ b/pyiceberg/table/update/spec.py @@ -16,15 +16,7 @@ # under the License. from __future__ import annotations -from typing import ( - TYPE_CHECKING, - Any, - Dict, - List, - Optional, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union from pyiceberg.expressions import ( Reference, @@ -47,7 +39,7 @@ UpdatesAndRequirements, UpdateTableMetadata, ) -from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform +from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform, parse_transform if TYPE_CHECKING: from pyiceberg.table import Transaction @@ -85,11 +77,13 @@ def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> Non def add_field( self, source_column_name: str, - transform: Transform[Any, Any], + transform: Union[str, Transform[Any, Any]], partition_field_name: Optional[str] = None, ) -> UpdateSpec: ref = Reference(source_column_name) bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive) + if isinstance(transform, str): + transform = parse_transform(transform) # verify transform can actually bind it output_type = bound_ref.field.field_type if not transform.can_transform(output_type): diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 7833215d09..19889a98e8 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -111,29 +111,6 @@ def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]: return literal(func(lit.value)) -def parse_transform(v: Any) -> Any: - if isinstance(v, str): - if v == IDENTITY: - return IdentityTransform() - elif v == VOID: - return VoidTransform() - elif v.startswith(BUCKET): - return BucketTransform(num_buckets=BUCKET_PARSER.match(v)) - elif v.startswith(TRUNCATE): - return TruncateTransform(width=TRUNCATE_PARSER.match(v)) - elif v == YEAR: - return YearTransform() - elif v == MONTH: - return MonthTransform() - elif v == DAY: - return DayTransform() - elif v == HOUR: - return HourTransform() - else: - return UnknownTransform(transform=v) - return v - - class Transform(IcebergRootModel[str], ABC, Generic[S, T]): """Transform base class for concrete transforms. @@ -220,6 +197,29 @@ def _transform(array: "ArrayLike") -> "ArrayLike": return _transform +def parse_transform(v: Any) -> Transform[Any, Any]: + if isinstance(v, str): + if v == IDENTITY: + return IdentityTransform() + elif v == VOID: + return VoidTransform() + elif v.startswith(BUCKET): + return BucketTransform(num_buckets=BUCKET_PARSER.match(v)) + elif v.startswith(TRUNCATE): + return TruncateTransform(width=TRUNCATE_PARSER.match(v)) + elif v == YEAR: + return YearTransform() + elif v == MONTH: + return MonthTransform() + elif v == DAY: + return DayTransform() + elif v == HOUR: + return HourTransform() + else: + return UnknownTransform(transform=v) + return v + + class BucketTransform(Transform[S, int]): """Base Transform class to transform a value into a bucket partition value. diff --git a/tests/integration/test_partition_evolution.py b/tests/integration/test_partition_evolution.py index 0e607a46f0..d489d6a5d0 100644 --- a/tests/integration/test_partition_evolution.py +++ b/tests/integration/test_partition_evolution.py @@ -140,6 +140,14 @@ def test_add_hour(catalog: Catalog) -> None: _validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, HourTransform(), "hour_transform")) +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_add_hour_string_transform(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_field("event_ts", "hour", "str_hour_transform").commit() + _validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, HourTransform(), "str_hour_transform")) + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_add_hour_generates_default_name(catalog: Catalog) -> None: