diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py
index c3b5ae74d6..3bf646ef69 100644
--- a/pyiceberg/expressions/__init__.py
+++ b/pyiceberg/expressions/__init__.py
@@ -31,6 +31,7 @@
     TypeVar,
     Union,
 )
+from typing import Literal as TypingLiteral
 
 from pydantic import Field
 
@@ -41,10 +42,15 @@
     literal,
 )
 from pyiceberg.schema import Accessor, Schema
-from pyiceberg.typedef import IcebergRootModel, L, StructProtocol
+from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, StructProtocol
 from pyiceberg.types import DoubleType, FloatType, NestedField
 from pyiceberg.utils.singleton import Singleton
 
+try:
+    from pydantic import ConfigDict
+except ImportError:
+    ConfigDict = dict
+
 
 def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]:
     return Reference(term) if isinstance(term, str) else term
@@ -571,12 +577,14 @@ def as_bound(self) -> Type[BoundNotNaN[L]]:
         return BoundNotNaN[L]
 
 
-class SetPredicate(UnboundPredicate[L], ABC):
-    literals: Set[Literal[L]]
+class SetPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type: TypingLiteral["in", "not-in"] = Field(default="in")
+    literals: Set[Literal[L]] = Field(alias="items")
 
     def __init__(self, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]):
-        super().__init__(term)
-        self.literals = _to_literal_set(literals)
+        super().__init__(term=_to_unbound_term(term), items=_to_literal_set(literals))  # type: ignore
 
     def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate[L]:
         bound_term = self.term.bind(schema, case_sensitive)
@@ -688,6 +696,8 @@ def as_unbound(self) -> Type[NotIn[L]]:
 
 
 class In(SetPredicate[L]):
+    type: TypingLiteral["in"] = Field(default="in", alias="type")
+
     def __new__(  # type: ignore  # pylint: disable=W0221
         cls, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]
     ) -> BooleanExpression:
@@ -710,6 +720,8 @@ def as_bound(self) -> Type[BoundIn[L]]:
 
 
 class NotIn(SetPredicate[L], ABC):
+    type: TypingLiteral["not-in"] = Field(default="not-in", alias="type")
+
     def __new__(  # type: ignore  # pylint: disable=W0221
         cls, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]
     ) -> BooleanExpression:
diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py
index bcbf25a12d..5a0c8c9241 100644
--- a/tests/expressions/test_expressions.py
+++ b/tests/expressions/test_expressions.py
@@ -873,6 +873,16 @@ def test_not_in() -> None:
     assert not_in == pickle.loads(pickle.dumps(not_in))
 
 
+def test_serialize_in() -> None:
+    pred = In(term="foo", literals=[1, 2, 3])
+    assert pred.model_dump_json() == '{"term":"foo","type":"in","items":[1,2,3]}'
+
+
+def test_serialize_not_in() -> None:
+    pred = NotIn(term="foo", literals=[1, 2, 3])
+    assert pred.model_dump_json() == '{"term":"foo","type":"not-in","items":[1,2,3]}'
+
+
 def test_bound_equal_to(term: BoundReference[Any]) -> None:
     bound_equal_to = BoundEqualTo(term, literal("a"))
     assert str(bound_equal_to) == f"BoundEqualTo(term={str(term)}, literal=literal('a'))"