diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index b8f0b975e6..4a652a2f4a 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -817,10 +817,11 @@ def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[Unbou if isinstance(pred.term, BoundTransform): return _project_transform_predicate(self, name, pred) + if isinstance(pred, BoundUnaryPredicate): + return pred.as_unbound(Reference(name)) + if isinstance(field_type, (IntegerType, LongType, DecimalType)): - if isinstance(pred, BoundUnaryPredicate): - return pred.as_unbound(Reference(name)) - elif isinstance(pred, BoundLiteralPredicate): + if isinstance(pred, BoundLiteralPredicate): return _truncate_number_strict(name, pred, self.transform(field_type)) elif isinstance(pred, BoundNotIn): return _set_apply_transform(name, pred, self.transform(field_type)) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 51e8e23953..3ad3ff5a84 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1281,6 +1281,9 @@ def test_negative_year_strict_upper_bound(bound_reference_date: BoundReference[i def test_strict_bucket_integer(bound_reference_int: BoundReference[int]) -> None: value = literal(100).to(IntegerType()) transform = BucketTransform(num_buckets=10) + + _assert_projection_strict(BoundIsNull(term=bound_reference_int), transform, AlwaysFalse) + _assert_projection_strict(BoundNotEqualTo(term=bound_reference_int, literal=value), transform, NotEqualTo, "6") for expr in [BoundEqualTo, BoundLessThan, BoundLessThanOrEqual, BoundGreaterThan, BoundGreaterThanOrEqual]: @@ -1294,6 +1297,9 @@ def test_strict_bucket_integer(bound_reference_int: BoundReference[int]) -> None def test_strict_bucket_long(bound_reference_long: BoundReference[int]) -> None: value = literal(100).to(LongType()) transform = BucketTransform(num_buckets=10) + + _assert_projection_strict(BoundIsNull(term=bound_reference_long), transform, AlwaysFalse) + _assert_projection_strict(BoundNotEqualTo(term=bound_reference_long, literal=value), transform, NotEqualTo, "6") for expr in [BoundEqualTo, BoundLessThan, BoundLessThanOrEqual, BoundGreaterThan, BoundGreaterThanOrEqual]: @@ -1308,6 +1314,9 @@ def test_strict_bucket_decimal(bound_reference_decimal: BoundReference[int]) -> dec = DecimalType(9, 2) value = literal("100.00").to(dec) transform = BucketTransform(num_buckets=10) + + _assert_projection_strict(BoundIsNull(term=bound_reference_decimal), transform, AlwaysFalse) + _assert_projection_strict(BoundNotEqualTo(term=bound_reference_decimal, literal=value), transform, NotEqualTo, "2") for expr in [BoundEqualTo, BoundLessThan, BoundLessThanOrEqual, BoundGreaterThan, BoundGreaterThanOrEqual]: @@ -1321,6 +1330,9 @@ def test_strict_bucket_decimal(bound_reference_decimal: BoundReference[int]) -> def test_strict_bucket_string(bound_reference_str: BoundReference[int]) -> None: value = literal("abcdefg").to(StringType()) transform = BucketTransform(num_buckets=10) + + _assert_projection_strict(BoundIsNull(term=bound_reference_str), transform, AlwaysFalse) + _assert_projection_strict(BoundNotEqualTo(term=bound_reference_str, literal=value), transform, NotEqualTo, "4") for expr in [BoundEqualTo, BoundLessThan, BoundLessThanOrEqual, BoundGreaterThan, BoundGreaterThanOrEqual]: @@ -1334,6 +1346,9 @@ def test_strict_bucket_string(bound_reference_str: BoundReference[int]) -> None: def test_strict_bucket_bytes(bound_reference_binary: BoundReference[int]) -> None: value = literal(str.encode("abcdefg")).to(BinaryType()) transform = BucketTransform(num_buckets=10) + + _assert_projection_strict(BoundIsNull(term=bound_reference_binary), transform, AlwaysFalse) + _assert_projection_strict(BoundNotEqualTo(term=bound_reference_binary, literal=value), transform, NotEqualTo, "4") for expr in [BoundEqualTo, BoundLessThan, BoundLessThanOrEqual, BoundGreaterThan, BoundGreaterThanOrEqual]: @@ -1347,6 +1362,9 @@ def test_strict_bucket_bytes(bound_reference_binary: BoundReference[int]) -> Non def test_strict_bucket_uuid(bound_reference_uuid: BoundReference[int]) -> None: value = literal("00000000-0000-007b-0000-0000000001c8").to(UUIDType()) transform = BucketTransform(num_buckets=10) + + _assert_projection_strict(BoundIsNull(term=bound_reference_uuid), transform, AlwaysFalse) + _assert_projection_strict(BoundNotEqualTo(term=bound_reference_uuid, literal=value), transform, NotEqualTo, "4") for expr in [BoundEqualTo, BoundLessThan, BoundLessThanOrEqual, BoundGreaterThan, BoundGreaterThanOrEqual]: