Skip to content
36 changes: 31 additions & 5 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,12 +743,18 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
return BoundNotIn[L]


class LiteralPredicate(UnboundPredicate[L], ABC):
literal: Literal[L]
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you include this in the PR description so its easily referenced in the future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

term: UnboundTerm[Any]
value: Literal[L] = Field()
model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True)

def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621
super().__init__(term)
self.literal = _to_literal(literal) # pylint: disable=W0621
def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]):
super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg]

@property
def literal(self) -> Literal[L]:
return self.value

def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
bound_term = self.term.bind(schema, case_sensitive)
Expand All @@ -773,6 +779,10 @@ def __eq__(self, other: Any) -> bool:
return self.term == other.term and self.literal == other.literal
return False

def __str__(self) -> str:
"""Return the string representation of the LiteralPredicate class."""
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"

def __repr__(self) -> str:
"""Return the string representation of the LiteralPredicate class."""
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
Expand Down Expand Up @@ -886,6 +896,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]:


class EqualTo(LiteralPredicate[L]):
type: TypingLiteral["eq"] = Field(default="eq", alias="type")

def __invert__(self) -> NotEqualTo[L]:
"""Transform the Expression into its negated version."""
return NotEqualTo[L](self.term, self.literal)
Expand All @@ -896,6 +908,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]:


class NotEqualTo(LiteralPredicate[L]):
type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type")

def __invert__(self) -> EqualTo[L]:
"""Transform the Expression into its negated version."""
return EqualTo[L](self.term, self.literal)
Expand All @@ -906,6 +920,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]:


class LessThan(LiteralPredicate[L]):
type: TypingLiteral["lt"] = Field(default="lt", alias="type")

def __invert__(self) -> GreaterThanOrEqual[L]:
"""Transform the Expression into its negated version."""
return GreaterThanOrEqual[L](self.term, self.literal)
Expand All @@ -916,6 +932,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]:


class GreaterThanOrEqual(LiteralPredicate[L]):
type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type")

def __invert__(self) -> LessThan[L]:
"""Transform the Expression into its negated version."""
return LessThan[L](self.term, self.literal)
Expand All @@ -926,6 +944,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]:


class GreaterThan(LiteralPredicate[L]):
type: TypingLiteral["gt"] = Field(default="gt", alias="type")

def __invert__(self) -> LessThanOrEqual[L]:
"""Transform the Expression into its negated version."""
return LessThanOrEqual[L](self.term, self.literal)
Expand All @@ -936,6 +956,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]:


class LessThanOrEqual(LiteralPredicate[L]):
type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type")

def __invert__(self) -> GreaterThan[L]:
"""Transform the Expression into its negated version."""
return GreaterThan[L](self.term, self.literal)
Expand All @@ -946,6 +968,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]:


class StartsWith(LiteralPredicate[L]):
type: TypingLiteral["starts-with"] = Field(default="starts-with", alias="type")

def __invert__(self) -> NotStartsWith[L]:
"""Transform the Expression into its negated version."""
return NotStartsWith[L](self.term, self.literal)
Expand All @@ -956,6 +980,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]:


class NotStartsWith(LiteralPredicate[L]):
type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type")

def __invert__(self) -> StartsWith[L]:
"""Transform the Expression into its negated version."""
return StartsWith[L](self.term, self.literal)
Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.Mo
raise NotInstalledError(msg) from None


def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
def _transform_literal(func: Callable[[Any], Any], lit: Literal[L]) -> Literal[L]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is this change relevant? i dont see _transform_literal used anywhere

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to silence this mypy errors:

- hook id: mypy
- exit code: 1
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[str | None], str | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[bool | None], bool | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[int | None], int | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[float | None], float | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[bytes | None], bytes | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[UUID | None], UUID | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[Decimal | None], Decimal | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[datetime | None], datetime | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[date | None], date | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1049: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[time | None], time | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1051: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[str | None], str | None]"; expected "Callable[[str], str]"  [arg-type]```

"""Small helper to upwrap the value from the literal, and wrap it again."""
return literal(func(lit.value))

Expand Down
46 changes: 25 additions & 21 deletions tests/expressions/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pyiceberg.conversions import to_bytes
from pyiceberg.expressions import (
And,
BooleanExpression,
EqualTo,
GreaterThan,
GreaterThanOrEqual,
Expand All @@ -30,6 +31,7 @@
IsNull,
LessThan,
LessThanOrEqual,
LiteralPredicate,
Not,
NotEqualTo,
NotIn,
Expand Down Expand Up @@ -301,7 +303,7 @@ def test_missing_stats() -> None:
upper_bounds=None,
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand All @@ -324,7 +326,7 @@ def test_zero_record_file_stats(schema_data_file: Schema) -> None:
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand Down Expand Up @@ -683,26 +685,27 @@ def data_file_nan() -> DataFile:


def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None:
for operator in [LessThan, LessThanOrEqual]: # type: ignore
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
operators: tuple[type[LiteralPredicate[Any]], ...] = (LessThan, LessThanOrEqual)
for operator in operators:
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
assert not should_read, "Should not match: 1 is smaller than lower bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
assert should_read, "Should match: 10 is larger than lower bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
assert should_read, "Should match: no visibility"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
assert not should_read, "Should not match: 1 is smaller than lower bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
data_file_nan
)
assert should_read, "Should match: 10 larger than lower bound"
Expand All @@ -711,31 +714,32 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f
def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal(
schema_data_file_nan: Schema, data_file_nan: DataFile
) -> None:
for operator in [GreaterThan, GreaterThanOrEqual]: # type: ignore
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
operators: tuple[type[LiteralPredicate[Any]], ...] = (GreaterThan, GreaterThanOrEqual)
for operator in operators:
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
assert should_read, "Should match: upper bound is larger than 1"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
assert should_read, "Should match: upper bound is larger than 10"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
assert should_read, "Should match: no visibility"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
assert should_read, "Should match: 1 is smaller than upper bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
data_file_nan
)
assert should_read, "Should match: 10 is smaller than upper bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan)
assert not should_read, "Should not match: 30 is greater than upper bound"


Expand Down Expand Up @@ -1162,7 +1166,7 @@ def test_strict_missing_stats(strict_data_file_schema: Schema, strict_data_file_
upper_bounds=None,
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand All @@ -1185,7 +1189,7 @@ def test_strict_zero_record_file_stats(strict_data_file_schema: Schema) -> None:
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand Down
31 changes: 29 additions & 2 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,22 @@
IsNull,
LessThan,
LessThanOrEqual,
LiteralPredicate,
Not,
NotEqualTo,
NotIn,
NotNaN,
NotNull,
NotStartsWith,
Or,
Reference,
StartsWith,
UnboundPredicate,
)
from pyiceberg.expressions.literals import Literal, literal
from pyiceberg.expressions.visitors import _from_byte_buffer
from pyiceberg.schema import Accessor, Schema
from pyiceberg.typedef import Record
from pyiceberg.typedef import L, Record
from pyiceberg.types import (
DecimalType,
DoubleType,
Expand Down Expand Up @@ -915,6 +918,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:

def test_equal_to() -> None:
equal_to = EqualTo(Reference("a"), literal("a"))
assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"a"}'
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
assert equal_to == eval(repr(equal_to))
Expand All @@ -923,6 +927,7 @@ def test_equal_to() -> None:

def test_not_equal_to() -> None:
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"a"}'
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
assert not_equal_to == eval(repr(not_equal_to))
Expand All @@ -931,6 +936,7 @@ def test_not_equal_to() -> None:

def test_greater_than_or_equal_to() -> None:
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"a"}'
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
Expand All @@ -939,6 +945,7 @@ def test_greater_than_or_equal_to() -> None:

def test_greater_than() -> None:
greater_than = GreaterThan(Reference("a"), literal("a"))
assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"a"}'
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
assert greater_than == eval(repr(greater_than))
Expand All @@ -947,6 +954,7 @@ def test_greater_than() -> None:

def test_less_than() -> None:
less_than = LessThan(Reference("a"), literal("a"))
assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"a"}'
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
assert less_than == eval(repr(less_than))
Expand All @@ -955,12 +963,23 @@ def test_less_than() -> None:

def test_less_than_or_equal() -> None:
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"a"}'
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert less_than_or_equal == eval(repr(less_than_or_equal))
assert less_than_or_equal == pickle.loads(pickle.dumps(less_than_or_equal))


def test_starts_with() -> None:
starts_with = StartsWith(Reference("a"), literal("a"))
assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"a"}'


def test_not_starts_with() -> None:
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"a"}'


def test_bound_reference_eval(table_schema_simple: Schema) -> None:
"""Test creating a BoundReference and evaluating it on a StructProtocol"""
struct = Record("foovalue", 123, True)
Expand Down Expand Up @@ -1199,7 +1218,15 @@ def test_bind_ambiguous_name() -> None:
# |_| |_|\_, |_| \_, |
# |__/ |__/

assert_type(EqualTo("a", "b"), EqualTo[str])

def _assert_literal_predicate_type(expr: LiteralPredicate[L]) -> None:
assert_type(expr, LiteralPredicate[L])


_assert_literal_predicate_type(EqualTo("a", "b"))
_assert_literal_predicate_type(In("a", ("a", "b", "c")))
_assert_literal_predicate_type(In("a", (1, 2, 3)))
_assert_literal_predicate_type(NotIn("a", ("a", "b", "c")))
assert_type(In("a", ("a", "b", "c")), In[str])
assert_type(In("a", (1, 2, 3)), In[int])
assert_type(NotIn("a", ("a", "b", "c")), NotIn[str])
Comment on lines 1230 to 1232
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should we use _assert_literal_predicate_type for these too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing