From 01491e755129a9f5c4a67b4b1db5e70442766874 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 17 Nov 2025 16:56:58 +0100 Subject: [PATCH 1/3] Remove `Generic` from expressions From the beginning we had these `Generic` type parameters in the expression system, but they never really worked as we hoped. They came from Java, where generics are much stronger, but the static typing of Python/mypy doesn't really follow the types. --- pyiceberg/expressions/__init__.py | 364 +++++++++++++------------- pyiceberg/expressions/visitors.py | 276 +++++++++---------- pyiceberg/io/pyarrow.py | 78 +++--- pyiceberg/transforms.py | 60 ++--- tests/conftest.py | 6 +- tests/expressions/test_expressions.py | 39 ++- tests/expressions/test_visitors.py | 54 ++-- tests/io/test_pyarrow.py | 42 +-- tests/io/test_pyarrow_visitor.py | 39 ++- tests/test_transforms.py | 164 ++++++------ 10 files changed, 554 insertions(+), 568 deletions(-) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 20df6e548c..d866b80f22 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -20,24 +20,12 @@ import builtins from abc import ABC, abstractmethod from functools import cached_property -from typing import ( - Any, - Callable, - Generic, - Iterable, - Sequence, - TypeVar, -) +from typing import Any, Callable, Generic, Iterable, Sequence, cast from typing import Literal as TypingLiteral from pydantic import ConfigDict, Field -from pyiceberg.expressions.literals import ( - AboveMax, - BelowMin, - Literal, - literal, -) +from pyiceberg.expressions.literals import AboveMax, BelowMin, Literal, literal from pyiceberg.schema import Accessor, Schema from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, StructProtocol from pyiceberg.types import DoubleType, FloatType, NestedField @@ -48,8 +36,10 @@ except ImportError: ConfigDict = dict +LiteralValue = Literal[Any] -def _to_unbound_term(term: str | UnboundTerm[Any]) -> UnboundTerm[Any]: + +def
_to_unbound_term(term: str | UnboundTerm) -> UnboundTerm: return Reference(term) if isinstance(term, str) else term @@ -125,7 +115,7 @@ def _build_balanced_tree( return operator_(left, right) -class Term(Generic[L], ABC): +class Term: """A simple expression that evaluates to a value.""" @@ -133,33 +123,30 @@ class Bound: """Represents a bound value expression.""" -B = TypeVar("B") - - -class Unbound(Generic[B], ABC): +class Unbound(ABC): """Represents an unbound value expression.""" @abstractmethod - def bind(self, schema: Schema, case_sensitive: bool = True) -> B: ... + def bind(self, schema: Schema, case_sensitive: bool = True) -> Bound | BooleanExpression: ... @property @abstractmethod def as_bound(self) -> type[Bound]: ... -class BoundTerm(Term[L], Bound, ABC): +class BoundTerm(Term, Bound, ABC): """Represents a bound term.""" @abstractmethod - def ref(self) -> BoundReference[L]: + def ref(self) -> BoundReference: """Return the bound reference.""" @abstractmethod - def eval(self, struct: StructProtocol) -> L: # pylint: disable=W0613 + def eval(self, struct: StructProtocol) -> Any: # pylint: disable=W0613 """Return the value at the referenced field's position in an object that abides by the StructProtocol.""" -class BoundReference(BoundTerm[L]): +class BoundReference(BoundTerm): """A reference bound to a field in a schema. Args: @@ -174,7 +161,7 @@ def __init__(self, field: NestedField, accessor: Accessor): self.field = field self.accessor = accessor - def eval(self, struct: StructProtocol) -> L: + def eval(self, struct: StructProtocol) -> Any: """Return the value at the referenced field's position in an object that abides by the StructProtocol. 
Args: @@ -192,7 +179,7 @@ def __repr__(self) -> str: """Return the string representation of the BoundReference class.""" return f"BoundReference(field={repr(self.field)}, accessor={repr(self.accessor)})" - def ref(self) -> BoundReference[L]: + def ref(self) -> BoundReference: return self def __hash__(self) -> int: @@ -200,14 +187,14 @@ def __hash__(self) -> int: return hash(str(self)) -class UnboundTerm(Term[Any], Unbound[BoundTerm[L]], ABC): +class UnboundTerm(Term, Unbound, ABC): """Represents an unbound term.""" @abstractmethod - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundTerm[L]: ... + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundTerm: ... -class Reference(UnboundTerm[Any], IcebergRootModel[str]): +class Reference(UnboundTerm, IcebergRootModel[str]): """A reference not yet bound to a field in a schema. Args: @@ -230,7 +217,7 @@ def __str__(self) -> str: """Return the string representation of the Reference class.""" return f"Reference(name={repr(self.root)})" - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[L]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference: """Bind the reference to an Iceberg schema. 
Args: @@ -245,15 +232,15 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[L] """ field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive) accessor = schema.accessor_for_field(field.field_id) - return self.as_bound(field=field, accessor=accessor) # type: ignore + return self.as_bound(field=field, accessor=accessor) @property def name(self) -> str: return self.root @property - def as_bound(self) -> type[BoundReference[L]]: - return BoundReference[L] + def as_bound(self) -> type[BoundReference]: + return BoundReference class And(BooleanExpression): @@ -425,10 +412,10 @@ def __repr__(self) -> str: return "AlwaysFalse()" -class BoundPredicate(Generic[L], Bound, BooleanExpression, ABC): - term: BoundTerm[L] +class BoundPredicate(Bound, BooleanExpression, ABC): + term: BoundTerm - def __init__(self, term: BoundTerm[L]): + def __init__(self, term: BoundTerm): self.term = term def __eq__(self, other: Any) -> bool: @@ -439,13 +426,13 @@ def __eq__(self, other: Any) -> bool: @property @abstractmethod - def as_unbound(self) -> type[UnboundPredicate[Any]]: ... + def as_unbound(self) -> type[UnboundPredicate]: ... -class UnboundPredicate(Generic[L], Unbound[BooleanExpression], BooleanExpression, ABC): - term: UnboundTerm[Any] +class UnboundPredicate(Unbound, BooleanExpression, ABC): + term: UnboundTerm - def __init__(self, term: str | UnboundTerm[Any]): + def __init__(self, term: str | UnboundTerm): self.term = _to_unbound_term(term) def __eq__(self, other: Any) -> bool: @@ -457,15 +444,15 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression @property @abstractmethod - def as_bound(self) -> type[BoundPredicate[L]]: ... + def as_bound(self) -> type[BoundPredicate]: ... 
-class UnaryPredicate(IcebergBaseModel, UnboundPredicate[Any], ABC): +class UnaryPredicate(IcebergBaseModel, UnboundPredicate, ABC): type: str model_config = {"arbitrary_types_allowed": True} - def __init__(self, term: str | UnboundTerm[Any]): + def __init__(self, term: str | UnboundTerm): unbound = _to_unbound_term(term) super().__init__(term=unbound) @@ -474,7 +461,7 @@ def __str__(self) -> str: # Sort to make it deterministic return f"{str(self.__class__.__name__)}(term={str(self.term)})" - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate[Any]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate: bound_term = self.term.bind(schema, case_sensitive) return self.as_bound(bound_term) # type: ignore @@ -484,10 +471,10 @@ def __repr__(self) -> str: @property @abstractmethod - def as_bound(self) -> type[BoundUnaryPredicate[Any]]: ... # type: ignore + def as_bound(self) -> type[BoundUnaryPredicate]: ... # type: ignore -class BoundUnaryPredicate(BoundPredicate[L], ABC): +class BoundUnaryPredicate(BoundPredicate, ABC): def __repr__(self) -> str: """Return the string representation of the BoundUnaryPredicate class.""" return f"{str(self.__class__.__name__)}(term={repr(self.term)})" @@ -496,18 +483,18 @@ def __repr__(self) -> str: @abstractmethod def as_unbound(self) -> type[UnaryPredicate]: ... 
- def __getnewargs__(self) -> tuple[BoundTerm[L]]: + def __getnewargs__(self) -> tuple[BoundTerm]: """Pickle the BoundUnaryPredicate class.""" return (self.term,) -class BoundIsNull(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundIsNull(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore # pylint: disable=W0221 if term.ref().field.required: return AlwaysFalse() return super().__new__(cls) - def __invert__(self) -> BoundNotNull[L]: + def __invert__(self) -> BoundNotNull: """Transform the Expression into its negated version.""" return BoundNotNull(self.term) @@ -516,13 +503,13 @@ def as_unbound(self) -> type[IsNull]: return IsNull -class BoundNotNull(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]): # type: ignore # pylint: disable=W0221 +class BoundNotNull(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm): # type: ignore # pylint: disable=W0221 if term.ref().field.required: return AlwaysTrue() return super().__new__(cls) - def __invert__(self) -> BoundIsNull[L]: + def __invert__(self) -> BoundIsNull: """Transform the Expression into its negated version.""" return BoundIsNull(self.term) @@ -539,8 +526,8 @@ def __invert__(self) -> NotNull: return NotNull(self.term) @property - def as_bound(self) -> builtins.type[BoundIsNull[L]]: - return BoundIsNull[L] + def as_bound(self) -> builtins.type[BoundIsNull]: + return BoundIsNull class NotNull(UnaryPredicate): @@ -551,18 +538,18 @@ def __invert__(self) -> IsNull: return IsNull(self.term) @property - def as_bound(self) -> builtins.type[BoundNotNull[L]]: - return BoundNotNull[L] + def as_bound(self) -> builtins.type[BoundNotNull]: + return BoundNotNull -class BoundIsNaN(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundIsNaN(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm) -> 
BooleanExpression: # type: ignore # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) return AlwaysFalse() - def __invert__(self) -> BoundNotNaN[L]: + def __invert__(self) -> BoundNotNaN: """Transform the Expression into its negated version.""" return BoundNotNaN(self.term) @@ -571,14 +558,14 @@ def as_unbound(self) -> type[IsNaN]: return IsNaN -class BoundNotNaN(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundNotNaN(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) return AlwaysTrue() - def __invert__(self) -> BoundIsNaN[L]: + def __invert__(self) -> BoundIsNaN: """Transform the Expression into its negated version.""" return BoundIsNaN(self.term) @@ -595,8 +582,8 @@ def __invert__(self) -> NotNaN: return NotNaN(self.term) @property - def as_bound(self) -> builtins.type[BoundIsNaN[L]]: - return BoundIsNaN[L] + def as_bound(self) -> builtins.type[BoundIsNaN]: + return BoundIsNaN class NotNaN(UnaryPredicate): @@ -607,22 +594,25 @@ def __invert__(self) -> IsNaN: return IsNaN(self.term) @property - def as_bound(self) -> builtins.type[BoundNotNaN[L]]: - return BoundNotNaN[L] + def as_bound(self) -> builtins.type[BoundNotNaN]: + return BoundNotNaN -class SetPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): +class SetPredicate(IcebergBaseModel, UnboundPredicate, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) type: TypingLiteral["in", "not-in"] = Field(default="in") - literals: set[Literal[L]] = Field(alias="items") + literals: set[Any] = Field(alias="items") - def __init__(self, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]]): - 
super().__init__(term=_to_unbound_term(term), items=_to_literal_set(literals)) # type: ignore + def __init__(self, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue]): + literal_set = _to_literal_set(literals) + super().__init__(term=_to_unbound_term(term), items=literal_set) # type: ignore + object.__setattr__(self, "literals", literal_set) - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate[L]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate: bound_term = self.term.bind(schema, case_sensitive) - return self.as_bound(bound_term, {lit.to(bound_term.ref().field.field_type) for lit in self.literals}) + literal_set = cast(set[LiteralValue], self.literals) + return self.as_bound(bound_term, {lit.to(bound_term.ref().field.field_type) for lit in literal_set}) def __str__(self) -> str: """Return the string representation of the SetPredicate class.""" @@ -638,26 +628,25 @@ def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the SetPredicate class.""" return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False - def __getnewargs__(self) -> tuple[UnboundTerm[L], set[Literal[L]]]: + def __getnewargs__(self) -> tuple[UnboundTerm, set[Any]]: """Pickle the SetPredicate class.""" return (self.term, self.literals) @property @abstractmethod - def as_bound(self) -> builtins.type[BoundSetPredicate[L]]: - return BoundSetPredicate[L] + def as_bound(self) -> builtins.type[BoundSetPredicate]: + return BoundSetPredicate -class BoundSetPredicate(BoundPredicate[L], ABC): - literals: set[Literal[L]] +class BoundSetPredicate(BoundPredicate, ABC): + literals: set[LiteralValue] - def __init__(self, term: BoundTerm[L], literals: set[Literal[L]]): - # Since we don't know the type of BoundPredicate[L], we have to ignore this one - super().__init__(term) # type: ignore + def __init__(self, term: BoundTerm, literals: 
set[LiteralValue]): + super().__init__(term) self.literals = _to_literal_set(literals) # pylint: disable=W0621 @cached_property - def value_set(self) -> set[L]: + def value_set(self) -> set[Any]: return {lit.value for lit in self.literals} def __str__(self) -> str: @@ -674,17 +663,17 @@ def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundSetPredicate class.""" return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False - def __getnewargs__(self) -> tuple[BoundTerm[L], set[Literal[L]]]: + def __getnewargs__(self) -> tuple[BoundTerm, set[LiteralValue]]: """Pickle the BoundSetPredicate class.""" return (self.term, self.literals) @property @abstractmethod - def as_unbound(self) -> type[SetPredicate[L]]: ... + def as_unbound(self) -> type[SetPredicate]: ... -class BoundIn(BoundSetPredicate[L]): - def __new__(cls, term: BoundTerm[L], literals: set[Literal[L]]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundIn(BoundSetPredicate): + def __new__(cls, term: BoundTerm, literals: set[LiteralValue]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 count = len(literals) if count == 0: return AlwaysFalse() @@ -693,7 +682,7 @@ def __new__(cls, term: BoundTerm[L], literals: set[Literal[L]]) -> BooleanExpres else: return super().__new__(cls) - def __invert__(self) -> BoundNotIn[L]: + def __invert__(self) -> BoundNotIn: """Transform the Expression into its negated version.""" return BoundNotIn(self.term, self.literals) @@ -702,15 +691,15 @@ def __eq__(self, other: Any) -> bool: return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False @property - def as_unbound(self) -> type[In[L]]: + def as_unbound(self) -> type[In]: return In -class BoundNotIn(BoundSetPredicate[L]): +class BoundNotIn(BoundSetPredicate): def __new__( # type: ignore # pylint: disable=W0221 cls, - term: BoundTerm[L], - literals: 
set[Literal[L]], + term: BoundTerm, + literals: set[LiteralValue], ) -> BooleanExpression: count = len(literals) if count == 0: @@ -720,22 +709,22 @@ def __new__( # type: ignore # pylint: disable=W0221 else: return super().__new__(cls) - def __invert__(self) -> BoundIn[L]: + def __invert__(self) -> BoundIn: """Transform the Expression into its negated version.""" return BoundIn(self.term, self.literals) @property - def as_unbound(self) -> type[NotIn[L]]: + def as_unbound(self) -> type[NotIn]: return NotIn -class In(SetPredicate[L]): +class In(SetPredicate): type: TypingLiteral["in"] = Field(default="in", alias="type") def __new__( # type: ignore # pylint: disable=W0221 - cls, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]] + cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] ) -> BooleanExpression: - literals_set: set[Literal[L]] = _to_literal_set(literals) + literals_set: set[LiteralValue] = _to_literal_set(literals) count = len(literals_set) if count == 0: return AlwaysFalse() @@ -744,22 +733,22 @@ def __new__( # type: ignore # pylint: disable=W0221 else: return super().__new__(cls) - def __invert__(self) -> NotIn[L]: + def __invert__(self) -> NotIn: """Transform the Expression into its negated version.""" - return NotIn[L](self.term, self.literals) + return NotIn(self.term, self.literals) @property - def as_bound(self) -> builtins.type[BoundIn[L]]: - return BoundIn[L] + def as_bound(self) -> builtins.type[BoundIn]: + return BoundIn -class NotIn(SetPredicate[L], ABC): +class NotIn(SetPredicate, ABC): type: TypingLiteral["not-in"] = Field(default="not-in", alias="type") def __new__( # type: ignore # pylint: disable=W0221 - cls, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]] + cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] ) -> BooleanExpression: - literals_set: set[Literal[L]] = _to_literal_set(literals) + literals_set: set[LiteralValue] = 
_to_literal_set(literals) count = len(literals_set) if count == 0: return AlwaysTrue() @@ -768,29 +757,29 @@ def __new__( # type: ignore # pylint: disable=W0221 else: return super().__new__(cls) - def __invert__(self) -> In[L]: + def __invert__(self) -> In: """Transform the Expression into its negated version.""" - return In[L](self.term, self.literals) + return In(self.term, self.literals) @property - def as_bound(self) -> builtins.type[BoundNotIn[L]]: - return BoundNotIn[L] + def as_bound(self) -> builtins.type[BoundNotIn]: + return BoundNotIn -class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): +class LiteralPredicate(IcebergBaseModel, UnboundPredicate, Generic[L], ABC): type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") - term: UnboundTerm[Any] + term: UnboundTerm value: Literal[L] = Field() model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) - def __init__(self, term: str | UnboundTerm[Any], literal: L | Literal[L]): + def __init__(self, term: str | UnboundTerm, literal: L | Literal[L]): super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg] @property def literal(self) -> Literal[L]: return self.value - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate: bound_term = self.term.bind(schema, case_sensitive) lit = self.literal.to(bound_term.ref().field.field_type) @@ -823,15 +812,14 @@ def __repr__(self) -> str: @property @abstractmethod - def as_bound(self) -> builtins.type[BoundLiteralPredicate[L]]: ... + def as_bound(self) -> builtins.type[BoundLiteralPredicate]: ... 
-class BoundLiteralPredicate(BoundPredicate[L], ABC): - literal: Literal[L] +class BoundLiteralPredicate(BoundPredicate, ABC): + literal: Literal[Any] - def __init__(self, term: BoundTerm[L], literal: Literal[L]): # pylint: disable=W0621 - # Since we don't know the type of BoundPredicate[L], we have to ignore this one - super().__init__(term) # type: ignore + def __init__(self, term: BoundTerm, literal: Literal[Any]): # pylint: disable=W0621 + super().__init__(term) self.literal = literal # pylint: disable=W0621 def __eq__(self, other: Any) -> bool: @@ -846,180 +834,180 @@ def __repr__(self) -> str: @property @abstractmethod - def as_unbound(self) -> type[LiteralPredicate[L]]: ... + def as_unbound(self) -> type[LiteralPredicate[Any]]: ... -class BoundEqualTo(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundNotEqualTo[L]: +class BoundEqualTo(BoundLiteralPredicate): + def __invert__(self) -> BoundNotEqualTo: """Transform the Expression into its negated version.""" - return BoundNotEqualTo[L](self.term, self.literal) + return BoundNotEqualTo(self.term, self.literal) @property - def as_unbound(self) -> type[EqualTo[L]]: + def as_unbound(self) -> type[EqualTo[Any]]: return EqualTo -class BoundNotEqualTo(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundEqualTo[L]: +class BoundNotEqualTo(BoundLiteralPredicate): + def __invert__(self) -> BoundEqualTo: """Transform the Expression into its negated version.""" - return BoundEqualTo[L](self.term, self.literal) + return BoundEqualTo(self.term, self.literal) @property - def as_unbound(self) -> type[NotEqualTo[L]]: + def as_unbound(self) -> type[NotEqualTo[Any]]: return NotEqualTo -class BoundGreaterThanOrEqual(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundLessThan[L]: +class BoundGreaterThanOrEqual(BoundLiteralPredicate): + def __invert__(self) -> BoundLessThan: """Transform the Expression into its negated version.""" - return BoundLessThan[L](self.term, self.literal) + return 
BoundLessThan(self.term, self.literal) @property - def as_unbound(self) -> type[GreaterThanOrEqual[L]]: - return GreaterThanOrEqual[L] + def as_unbound(self) -> type[GreaterThanOrEqual[Any]]: + return GreaterThanOrEqual -class BoundGreaterThan(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundLessThanOrEqual[L]: +class BoundGreaterThan(BoundLiteralPredicate): + def __invert__(self) -> BoundLessThanOrEqual: """Transform the Expression into its negated version.""" return BoundLessThanOrEqual(self.term, self.literal) @property - def as_unbound(self) -> type[GreaterThan[L]]: - return GreaterThan[L] + def as_unbound(self) -> type[GreaterThan[Any]]: + return GreaterThan -class BoundLessThan(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundGreaterThanOrEqual[L]: +class BoundLessThan(BoundLiteralPredicate): + def __invert__(self) -> BoundGreaterThanOrEqual: """Transform the Expression into its negated version.""" - return BoundGreaterThanOrEqual[L](self.term, self.literal) + return BoundGreaterThanOrEqual(self.term, self.literal) @property - def as_unbound(self) -> type[LessThan[L]]: - return LessThan[L] + def as_unbound(self) -> type[LessThan[Any]]: + return LessThan -class BoundLessThanOrEqual(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundGreaterThan[L]: +class BoundLessThanOrEqual(BoundLiteralPredicate): + def __invert__(self) -> BoundGreaterThan: """Transform the Expression into its negated version.""" - return BoundGreaterThan[L](self.term, self.literal) + return BoundGreaterThan(self.term, self.literal) @property - def as_unbound(self) -> type[LessThanOrEqual[L]]: - return LessThanOrEqual[L] + def as_unbound(self) -> type[LessThanOrEqual[Any]]: + return LessThanOrEqual -class BoundStartsWith(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundNotStartsWith[L]: +class BoundStartsWith(BoundLiteralPredicate): + def __invert__(self) -> BoundNotStartsWith: """Transform the Expression into its negated version.""" - return 
BoundNotStartsWith[L](self.term, self.literal) + return BoundNotStartsWith(self.term, self.literal) @property - def as_unbound(self) -> type[StartsWith[L]]: - return StartsWith[L] + def as_unbound(self) -> type[StartsWith[Any]]: + return StartsWith -class BoundNotStartsWith(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundStartsWith[L]: +class BoundNotStartsWith(BoundLiteralPredicate): + def __invert__(self) -> BoundStartsWith: """Transform the Expression into its negated version.""" - return BoundStartsWith[L](self.term, self.literal) + return BoundStartsWith(self.term, self.literal) @property - def as_unbound(self) -> type[NotStartsWith[L]]: - return NotStartsWith[L] + def as_unbound(self) -> type[NotStartsWith[Any]]: + return NotStartsWith class EqualTo(LiteralPredicate[L]): type: TypingLiteral["eq"] = Field(default="eq", alias="type") - def __invert__(self) -> NotEqualTo[L]: + def __invert__(self) -> NotEqualTo[Any]: """Transform the Expression into its negated version.""" - return NotEqualTo[L](self.term, self.literal) + return NotEqualTo(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundEqualTo[L]]: - return BoundEqualTo[L] + def as_bound(self) -> builtins.type[BoundEqualTo]: + return BoundEqualTo class NotEqualTo(LiteralPredicate[L]): type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type") - def __invert__(self) -> EqualTo[L]: + def __invert__(self) -> EqualTo[Any]: """Transform the Expression into its negated version.""" - return EqualTo[L](self.term, self.literal) + return EqualTo(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundNotEqualTo[L]]: - return BoundNotEqualTo[L] + def as_bound(self) -> builtins.type[BoundNotEqualTo]: + return BoundNotEqualTo class LessThan(LiteralPredicate[L]): type: TypingLiteral["lt"] = Field(default="lt", alias="type") - def __invert__(self) -> GreaterThanOrEqual[L]: + def __invert__(self) -> GreaterThanOrEqual[Any]: """Transform the Expression 
into its negated version.""" - return GreaterThanOrEqual[L](self.term, self.literal) + return GreaterThanOrEqual(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundLessThan[L]]: - return BoundLessThan[L] + def as_bound(self) -> builtins.type[BoundLessThan]: + return BoundLessThan class GreaterThanOrEqual(LiteralPredicate[L]): type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type") - def __invert__(self) -> LessThan[L]: + def __invert__(self) -> LessThan[Any]: """Transform the Expression into its negated version.""" - return LessThan[L](self.term, self.literal) + return LessThan(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundGreaterThanOrEqual[L]]: - return BoundGreaterThanOrEqual[L] + def as_bound(self) -> builtins.type[BoundGreaterThanOrEqual]: + return BoundGreaterThanOrEqual class GreaterThan(LiteralPredicate[L]): type: TypingLiteral["gt"] = Field(default="gt", alias="type") - def __invert__(self) -> LessThanOrEqual[L]: + def __invert__(self) -> LessThanOrEqual[Any]: """Transform the Expression into its negated version.""" - return LessThanOrEqual[L](self.term, self.literal) + return LessThanOrEqual(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundGreaterThan[L]]: - return BoundGreaterThan[L] + def as_bound(self) -> builtins.type[BoundGreaterThan]: + return BoundGreaterThan class LessThanOrEqual(LiteralPredicate[L]): type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type") - def __invert__(self) -> GreaterThan[L]: + def __invert__(self) -> GreaterThan[Any]: """Transform the Expression into its negated version.""" - return GreaterThan[L](self.term, self.literal) + return GreaterThan(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundLessThanOrEqual[L]]: - return BoundLessThanOrEqual[L] + def as_bound(self) -> builtins.type[BoundLessThanOrEqual]: + return BoundLessThanOrEqual class StartsWith(LiteralPredicate[L]): type: 
TypingLiteral["starts-with"] = Field(default="starts-with", alias="type") - def __invert__(self) -> NotStartsWith[L]: + def __invert__(self) -> NotStartsWith[Any]: """Transform the Expression into its negated version.""" - return NotStartsWith[L](self.term, self.literal) + return NotStartsWith(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundStartsWith[L]]: - return BoundStartsWith[L] + def as_bound(self) -> builtins.type[BoundStartsWith]: + return BoundStartsWith class NotStartsWith(LiteralPredicate[L]): type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type") - def __invert__(self) -> StartsWith[L]: + def __invert__(self) -> StartsWith[Any]: """Transform the Expression into its negated version.""" - return StartsWith[L](self.term, self.literal) + return StartsWith(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundNotStartsWith[L]]: - return BoundNotStartsWith[L] + def as_bound(self) -> builtins.type[BoundNotStartsWith]: + return BoundNotStartsWith diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 4c096f1215..1362945a00 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -116,19 +116,19 @@ def visit_or(self, left_result: T, right_result: T) -> T: """ @abstractmethod - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> T: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> T: """Visit method for an unbound predicate in an expression tree. Args: - predicate (UnboundPredicate[L): An instance of an UnboundPredicate. + predicate (UnboundPredicate): An instance of an UnboundPredicate. """ @abstractmethod - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> T: + def visit_bound_predicate(self, predicate: BoundPredicate) -> T: """Visit method for a bound predicate in an expression tree. Args: - predicate (BoundPredicate[L]): An instance of a BoundPredicate. 
+ predicate (BoundPredicate): An instance of a BoundPredicate. """ @@ -176,13 +176,13 @@ def _(obj: And, visitor: BooleanExpressionVisitor[T]) -> T: @visit.register(UnboundPredicate) -def _(obj: UnboundPredicate[L], visitor: BooleanExpressionVisitor[T]) -> T: +def _(obj: UnboundPredicate, visitor: BooleanExpressionVisitor[T]) -> T: """Visit an unbound boolean expression with a concrete BooleanExpressionVisitor.""" return visitor.visit_unbound_predicate(predicate=obj) @visit.register(BoundPredicate) -def _(obj: BoundPredicate[L], visitor: BooleanExpressionVisitor[T]) -> T: +def _(obj: BoundPredicate, visitor: BooleanExpressionVisitor[T]) -> T: """Visit a bound boolean expression with a concrete BooleanExpressionVisitor.""" return visitor.visit_bound_predicate(predicate=obj) @@ -242,60 +242,60 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left=left_result, right=right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: return predicate.bind(self.schema, case_sensitive=self.case_sensitive) - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: raise TypeError(f"Found already bound predicate: {predicate}") class BoundBooleanExpressionVisitor(BooleanExpressionVisitor[T], ABC): @abstractmethod - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> T: + def visit_in(self, term: BoundTerm, literals: set[L]) -> T: """Visit a bound In predicate.""" @abstractmethod - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> T: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> T: """Visit a bound NotIn predicate.""" @abstractmethod - def visit_is_nan(self, 
term: BoundTerm[L]) -> T: + def visit_is_nan(self, term: BoundTerm) -> T: """Visit a bound IsNan predicate.""" @abstractmethod - def visit_not_nan(self, term: BoundTerm[L]) -> T: + def visit_not_nan(self, term: BoundTerm) -> T: """Visit a bound NotNan predicate.""" @abstractmethod - def visit_is_null(self, term: BoundTerm[L]) -> T: + def visit_is_null(self, term: BoundTerm) -> T: """Visit a bound IsNull predicate.""" @abstractmethod - def visit_not_null(self, term: BoundTerm[L]) -> T: + def visit_not_null(self, term: BoundTerm) -> T: """Visit a bound NotNull predicate.""" @abstractmethod - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound Equal predicate.""" @abstractmethod - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound NotEqual predicate.""" @abstractmethod - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound GreaterThanOrEqual predicate.""" @abstractmethod - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound GreaterThan predicate.""" @abstractmethod - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound LessThan predicate.""" @abstractmethod - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound LessThanOrEqual predicate.""" @abstractmethod @@ -319,105 +319,105 @@ def visit_or(self, left_result: T, right_result: T) -> T: """Visit a bound Or predicate.""" 
@abstractmethod - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit bound StartsWith predicate.""" @abstractmethod - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit bound NotStartsWith predicate.""" - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> T: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> T: """Visit an unbound predicate. Args: - predicate (UnboundPredicate[L]): An unbound predicate. + predicate (UnboundPredicate): An unbound predicate. Raises: TypeError: This always raises since an unbound predicate is not expected in a bound boolean expression. """ raise TypeError(f"Not a bound predicate: {predicate}") - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> T: + def visit_bound_predicate(self, predicate: BoundPredicate) -> T: """Visit a bound predicate. Args: - predicate (BoundPredicate[L]): A bound predicate. + predicate (BoundPredicate): A bound predicate. 
""" return visit_bound_predicate(predicate, self) @singledispatch -def visit_bound_predicate(expr: BoundPredicate[L], _: BooleanExpressionVisitor[T]) -> T: +def visit_bound_predicate(expr: BoundPredicate, _: BooleanExpressionVisitor[T]) -> T: raise TypeError(f"Unknown predicate: {expr}") @visit_bound_predicate.register(BoundIn) -def _(expr: BoundIn[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundIn, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_in(term=expr.term, literals=expr.value_set) @visit_bound_predicate.register(BoundNotIn) -def _(expr: BoundNotIn[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotIn, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_in(term=expr.term, literals=expr.value_set) @visit_bound_predicate.register(BoundIsNaN) -def _(expr: BoundIsNaN[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundIsNaN, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_is_nan(term=expr.term) @visit_bound_predicate.register(BoundNotNaN) -def _(expr: BoundNotNaN[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotNaN, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_nan(term=expr.term) @visit_bound_predicate.register(BoundIsNull) -def _(expr: BoundIsNull[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundIsNull, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_is_null(term=expr.term) @visit_bound_predicate.register(BoundNotNull) -def _(expr: BoundNotNull[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotNull, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_null(term=expr.term) @visit_bound_predicate.register(BoundEqualTo) -def _(expr: BoundEqualTo[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundEqualTo, visitor: BoundBooleanExpressionVisitor[T]) -> T: return 
visitor.visit_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundNotEqualTo) -def _(expr: BoundNotEqualTo[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotEqualTo, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundGreaterThanOrEqual) -def _(expr: BoundGreaterThanOrEqual[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundGreaterThanOrEqual, visitor: BoundBooleanExpressionVisitor[T]) -> T: """Visit a bound GreaterThanOrEqual predicate.""" return visitor.visit_greater_than_or_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundGreaterThan) -def _(expr: BoundGreaterThan[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundGreaterThan, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_greater_than(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundLessThan) -def _(expr: BoundLessThan[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundLessThan, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_less_than(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundLessThanOrEqual) -def _(expr: BoundLessThanOrEqual[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundLessThanOrEqual, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_less_than_or_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundStartsWith) -def _(expr: BoundStartsWith[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundStartsWith, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_starts_with(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundNotStartsWith) -def _(expr: BoundNotStartsWith[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: 
BoundNotStartsWith, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_starts_with(term=expr.term, literal=expr.literal) @@ -443,10 +443,10 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left=left_result, right=right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: return predicate - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: return predicate @@ -465,53 +465,53 @@ def eval(self, struct: StructProtocol) -> bool: self.struct = struct return visit(self.bound, self) - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, literals: set[L]) -> bool: return term.eval(self.struct) in literals - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: return term.eval(self.struct) not in literals - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: val = term.eval(self.struct) return val != val - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: val = term.eval(self.struct) return val == val - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: return term.eval(self.struct) is None - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: return term.eval(self.struct) is not None - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: 
Literal[L]) -> bool: return term.eval(self.struct) == literal.value - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: return term.eval(self.struct) != literal.value - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value >= literal.value - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value > literal.value - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value < literal.value - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value <= literal.value - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: eval_res = term.eval(self.struct) return eval_res is not None and str(eval_res).startswith(str(literal.value)) - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: return not self.visit_starts_with(term, literal) def visit_true(self) -> bool: @@ -558,7 +558,7 @@ def eval(self, manifest: ManifestFile) -> bool: # No partition information return ROWS_MIGHT_MATCH - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, 
literals: set[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -580,12 +580,12 @@ def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: # because the bounds are not necessarily a min or max value, this cannot be answered using # them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col. return ROWS_MIGHT_MATCH - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -594,7 +594,7 @@ def visit_is_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -603,7 +603,7 @@ def visit_not_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position if self.partition_fields[pos].contains_null is False: @@ -611,7 +611,7 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position # contains_null encodes whether at least one partition value is null, @@ -628,7 +628,7 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -648,12 +648,12 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) 
-> bool: return ROWS_MIGHT_MATCH - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # because the bounds are not necessarily a min or max value, this cannot be answered using # them. notEq(col, X) with (X, Y) doesn't guarantee that X is a value in col. return ROWS_MIGHT_MATCH - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -667,7 +667,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) - return ROWS_MIGHT_MATCH - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -681,7 +681,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -695,7 +695,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -709,7 +709,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b return ROWS_MIGHT_MATCH - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = 
term.ref().accessor.position field = self.partition_fields[pos] prefix = str(literal.value) @@ -733,7 +733,7 @@ def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] prefix = str(literal.value) @@ -820,12 +820,12 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: raise ValueError(f"Cannot project unbound predicate: {predicate}") class InclusiveProjection(ProjectionEvaluator): - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) result: BooleanExpression = AlwaysTrue() @@ -887,10 +887,10 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left=left_result, right=right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: raise TypeError(f"Expected Bound Predicate, got: {predicate.term}") - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: field = 
predicate.term.ref().field field_id = field.field_id file_column_name = self.file_schema.find_column_name(field_id) @@ -954,10 +954,10 @@ def visit_and(self, left_result: set[int], right_result: set[int]) -> set[int]: def visit_or(self, left_result: set[int], right_result: set[int]) -> set[int]: return left_result.union(right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> set[int]: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> set[int]: raise ValueError("Only works on bound records") - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> set[int]: + def visit_bound_predicate(self, predicate: BoundPredicate) -> set[int]: return {predicate.term.ref().field.field_id} @@ -989,10 +989,10 @@ def visit_or( ) -> tuple[BooleanExpression, ...]: return left_result + right_result - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> tuple[BooleanExpression, ...]: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> tuple[BooleanExpression, ...]: return (predicate,) - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> tuple[BooleanExpression, ...]: + def visit_bound_predicate(self, predicate: BoundPredicate) -> tuple[BooleanExpression, ...]: return (predicate,) @@ -1021,48 +1021,48 @@ def _cast_if_necessary(self, iceberg_type: IcebergType, literal: L | set[L]) -> return conversion_function(literal) # type: ignore return literal - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> list[tuple[str, str, Any]]: + def visit_in(self, term: BoundTerm, literals: set[L]) -> list[tuple[str, str, Any]]: field = term.ref().field return [(term.ref().field.name, "in", self._cast_if_necessary(field.field_type, literals))] - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> list[tuple[str, str, Any]]: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> list[tuple[str, str, Any]]: field = term.ref().field return [(field.name, "not in", 
self._cast_if_necessary(field.field_type, literals))] - def visit_is_nan(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_is_nan(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "==", float("nan"))] - def visit_not_nan(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_not_nan(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "!=", float("nan"))] - def visit_is_null(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_is_null(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "==", None)] - def visit_not_null(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_not_null(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "!=", None)] - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "==", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "!=", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, ">=", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return 
[(term.ref().field.name, ">", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "<", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "<=", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [] - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [] def visit_true(self) -> list[tuple[str, str, Any]]: @@ -1192,7 +1192,7 @@ def _contains_nans_only(self, field_id: int) -> bool: return nan_count == value_count return False - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if self.null_counts.get(field_id) == 0: @@ -1200,7 +1200,7 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has no non-null values, the expression cannot match field_id = term.ref().field.field_id @@ -1210,7 +1210,7 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool: return 
ROWS_MIGHT_MATCH - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if self.nan_counts.get(field_id) == 0: @@ -1223,7 +1223,7 @@ def visit_is_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if self._contains_nans_only(field_id): @@ -1231,7 +1231,7 @@ def visit_not_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1253,7 +1253,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1274,7 +1274,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b return ROWS_MIGHT_MATCH - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1295,7 +1295,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1316,7 +1316,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) - return ROWS_MIGHT_MATCH - def visit_equal(self, term: BoundTerm[L], literal: 
Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1346,10 +1346,10 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, literals: set[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1385,12 +1385,12 @@ def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: # because the bounds are not necessarily a min or max value, this cannot be answered using # them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col. 
return ROWS_MIGHT_MATCH - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id: int = field.field_id @@ -1419,7 +1419,7 @@ def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id: int = field.field_id @@ -1460,7 +1460,7 @@ def strict_projection( class StrictProjection(ProjectionEvaluator): - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) result: BooleanExpression = AlwaysFalse() @@ -1511,7 +1511,7 @@ def eval(self, file: DataFile) -> bool: return visit(self.expr, self) - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has any non-null values, the expression does not match field_id = term.ref().field.field_id @@ -1521,7 +1521,7 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool: else: return ROWS_MIGHT_NOT_MATCH - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has any non-null values, the expression does not match field_id = term.ref().field.field_id @@ -1531,7 +1531,7 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool: else: return ROWS_MIGHT_NOT_MATCH - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: field_id = 
term.ref().field.field_id if self._contains_nans_only(field_id): @@ -1539,7 +1539,7 @@ def visit_is_nan(self, term: BoundTerm[L]) -> bool: else: return ROWS_MIGHT_NOT_MATCH - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if (nan_count := self.nan_counts.get(field_id)) is not None and nan_count == 0: @@ -1550,7 +1550,7 @@ def visit_not_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <----------Min----Max---X-------> field_id = term.ref().field.field_id @@ -1567,7 +1567,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <----------Min----Max---X-------> field_id = term.ref().field.field_id @@ -1584,7 +1584,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b return ROWS_MIGHT_NOT_MATCH - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <-------X---Min----Max----------> field_id = term.ref().field.field_id @@ -1606,7 +1606,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <-------X---Min----Max----------> field_id = term.ref().field.field_id @@ -1627,7 +1627,7 @@ def visit_greater_than_or_equal(self, term: 
BoundTerm[L], literal: Literal[L]) - return ROWS_MIGHT_NOT_MATCH - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when Min == X == Max field_id = term.ref().field.field_id @@ -1646,7 +1646,7 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when X < Min or Max < X because it is not in the range field_id = term.ref().field.field_id @@ -1674,7 +1674,7 @@ def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, literals: set[L]) -> bool: field_id = term.ref().field.field_id if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): @@ -1703,7 +1703,7 @@ def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: field_id = term.ref().field.field_id if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): @@ -1733,10 +1733,10 @@ def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH def _get_field(self, field_id: int) -> NestedField: @@ -1799,94 +1799,94 @@ def 
visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_is_null(self, term: BoundTerm) -> BooleanExpression: if term.eval(self.struct) is None: return AlwaysTrue() else: return AlwaysFalse() - def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_not_null(self, term: BoundTerm) -> BooleanExpression: if term.eval(self.struct) is not None: return AlwaysTrue() else: return AlwaysFalse() - def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_is_nan(self, term: BoundTerm) -> BooleanExpression: val = term.eval(self.struct) if isinstance(val, SupportsFloat) and math.isnan(val): return self.visit_true() else: return self.visit_false() - def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_not_nan(self, term: BoundTerm) -> BooleanExpression: val = term.eval(self.struct) if isinstance(val, SupportsFloat) and not math.isnan(val): return self.visit_true() else: return self.visit_false() - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) < literal.value: return self.visit_true() else: return self.visit_false() - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) <= literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) > 
literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) >= literal.value: return self.visit_true() else: return self.visit_false() - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) == literal.value: return self.visit_true() else: return self.visit_false() - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) != literal.value: return self.visit_true() else: return self.visit_false() - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> BooleanExpression: + def visit_in(self, term: BoundTerm, literals: set[L]) -> BooleanExpression: if term.eval(self.struct) in literals: return self.visit_true() else: return self.visit_false() - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> BooleanExpression: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> BooleanExpression: if term.eval(self.struct) not in literals: return self.visit_true() else: return self.visit_false() - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: eval_res = term.eval(self.struct) if eval_res is not None and str(eval_res).startswith(str(literal.value)): return AlwaysTrue() else: return AlwaysFalse() - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if not 
self.visit_starts_with(term, literal): return AlwaysTrue() else: return AlwaysFalse() - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: """ If there is no strict projection or if it evaluates to false, then return the predicate. @@ -1940,7 +1940,7 @@ def struct_to_schema(struct: StructType) -> Schema: return predicate - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: bound = predicate.bind(self.schema, case_sensitive=self.case_sensitive) if isinstance(bound, BoundPredicate): diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index cd19d43906..0210839df5 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -817,7 +817,7 @@ class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]): def __init__(self, schema: Schema | None = None): self._schema = schema - def _get_field_name(self, term: BoundTerm[Any]) -> str | tuple[str, ...]: + def _get_field_name(self, term: BoundTerm) -> str | tuple[str, ...]: """Get the field name or nested field path for a bound term. For nested struct fields, returns a tuple of field names (e.g., ("mazeMetadata", "run_id")). 
@@ -837,50 +837,50 @@ def _get_field_name(self, term: BoundTerm[Any]) -> str | tuple[str, ...]: # Fallback to just the field name if schema is not available return term.ref().field.name - def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> pc.Expression: + def visit_in(self, term: BoundTerm, literals: set[Any]) -> pc.Expression: pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) return pc.field(self._get_field_name(term)).isin(pyarrow_literals) - def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> pc.Expression: + def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> pc.Expression: pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) return ~pc.field(self._get_field_name(term)).isin(pyarrow_literals) - def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_is_nan(self, term: BoundTerm) -> pc.Expression: ref = pc.field(self._get_field_name(term)) return pc.is_nan(ref) - def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_not_nan(self, term: BoundTerm) -> pc.Expression: ref = pc.field(self._get_field_name(term)) return ~pc.is_nan(ref) - def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_is_null(self, term: BoundTerm) -> pc.Expression: return pc.field(self._get_field_name(term)).is_null(nan_is_null=False) - def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_not_null(self, term: BoundTerm) -> pc.Expression: return pc.field(self._get_field_name(term)).is_valid() - def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) == _convert_scalar(literal.value, term.ref().field.field_type) - def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_not_equal(self, term: 
BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) != _convert_scalar(literal.value, term.ref().field.field_type) - def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) >= _convert_scalar(literal.value, term.ref().field.field_type) - def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) > _convert_scalar(literal.value, term.ref().field.field_type) - def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) < _convert_scalar(literal.value, term.ref().field.field_type) - def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) <= _convert_scalar(literal.value, term.ref().field.field_type) - def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.starts_with(pc.field(self._get_field_name(term)), literal.value) - def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return ~pc.starts_with(pc.field(self._get_field_name(term)), literal.value) def visit_true(self) -> pc.Expression: @@ -901,13 +901,13 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) 
-> p class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]): # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr. - is_null_or_not_bound_terms: set[BoundTerm[Any]] + is_null_or_not_bound_terms: set[BoundTerm] # The remaining BoundTerms appearing in the boolean expr. - null_unmentioned_bound_terms: set[BoundTerm[Any]] + null_unmentioned_bound_terms: set[BoundTerm] # BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr. - is_nan_or_not_bound_terms: set[BoundTerm[Any]] + is_nan_or_not_bound_terms: set[BoundTerm] # The remaining BoundTerms appearing in the boolean expr. - nan_unmentioned_bound_terms: set[BoundTerm[Any]] + nan_unmentioned_bound_terms: set[BoundTerm] def __init__(self) -> None: super().__init__() @@ -916,81 +916,81 @@ def __init__(self) -> None: self.is_nan_or_not_bound_terms = set() self.nan_unmentioned_bound_terms = set() - def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None: + def _handle_explicit_is_null_or_not(self, term: BoundTerm) -> None: """Handle the predicate case where either is_null or is_not_null is included.""" if term in self.null_unmentioned_bound_terms: self.null_unmentioned_bound_terms.remove(term) self.is_null_or_not_bound_terms.add(term) - def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None: + def _handle_null_unmentioned(self, term: BoundTerm) -> None: """Handle the predicate case where neither is_null or is_not_null is included.""" if term not in self.is_null_or_not_bound_terms: self.null_unmentioned_bound_terms.add(term) - def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None: + def _handle_explicit_is_nan_or_not(self, term: BoundTerm) -> None: """Handle the predicate case where either is_nan or is_not_nan is included.""" if term in self.nan_unmentioned_bound_terms: self.nan_unmentioned_bound_terms.remove(term) self.is_nan_or_not_bound_terms.add(term) - def 
_handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None: + def _handle_nan_unmentioned(self, term: BoundTerm) -> None: """Handle the predicate case where neither is_nan or is_not_nan is included.""" if term not in self.is_nan_or_not_bound_terms: self.nan_unmentioned_bound_terms.add(term) - def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> None: + def visit_in(self, term: BoundTerm, literals: set[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> None: + def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_is_nan(self, term: BoundTerm[Any]) -> None: + def visit_is_nan(self, term: BoundTerm) -> None: self._handle_null_unmentioned(term) self._handle_explicit_is_nan_or_not(term) - def visit_not_nan(self, term: BoundTerm[Any]) -> None: + def visit_not_nan(self, term: BoundTerm) -> None: self._handle_null_unmentioned(term) self._handle_explicit_is_nan_or_not(term) - def visit_is_null(self, term: BoundTerm[Any]) -> None: + def visit_is_null(self, term: BoundTerm) -> None: self._handle_explicit_is_null_or_not(term) self._handle_nan_unmentioned(term) - def visit_not_null(self, term: BoundTerm[Any]) -> None: + def visit_not_null(self, term: BoundTerm) -> None: self._handle_explicit_is_null_or_not(term) self._handle_nan_unmentioned(term) - def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_not_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: 
Literal[Any]) -> None: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) @@ -1040,10 +1040,10 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression, schema: Schema collector.collect(expr) # Convert the set of terms to a sorted list so that layout of the expression to build is deterministic. 
- null_unmentioned_bound_terms: list[BoundTerm[Any]] = sorted( + null_unmentioned_bound_terms: list[BoundTerm] = sorted( collector.null_unmentioned_bound_terms, key=lambda term: term.ref().field.name ) - nan_unmentioned_bound_terms: list[BoundTerm[Any]] = sorted( + nan_unmentioned_bound_terms: list[BoundTerm] = sorted( collector.nan_unmentioned_bound_terms, key=lambda term: term.ref().field.name ) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 98cfac1146..896a04527d 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -183,10 +183,10 @@ def result_type(self, source: IcebergType) -> IcebergType: ... @abstractmethod - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: ... + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: ... @abstractmethod - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: ... + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: ... @property def preserves_order(self) -> bool: @@ -269,7 +269,7 @@ def apply(self, value: S | None) -> int | None: def result_type(self, source: IcebergType) -> IcebergType: return IntegerType() - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): @@ -286,7 +286,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | # For example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected. 
return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): @@ -420,7 +420,7 @@ def result_type(self, source: IcebergType) -> IntegerType: @abstractmethod def transform(self, source: IcebergType) -> Callable[[Any | None], int | None]: ... - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): return _project_transform_predicate(self, name, pred) @@ -433,7 +433,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | else: return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): return _project_transform_predicate(self, name, pred) @@ -725,7 +725,7 @@ def can_transform(self, source: IcebergType) -> bool: def result_type(self, source: IcebergType) -> IcebergType: return source - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: if isinstance(pred.term, BoundTransform): return _project_transform_predicate(self, name, pred) elif isinstance(pred, BoundUnaryPredicate): @@ -737,7 +737,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | else: return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: 
BoundPredicate) -> UnboundPredicate | None: if isinstance(pred, BoundUnaryPredicate): return pred.as_unbound(Reference(name)) elif isinstance(pred, BoundLiteralPredicate): @@ -801,7 +801,7 @@ def preserves_order(self) -> bool: def source_type(self) -> IcebergType: return self._source_type - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: field_type = pred.term.ref().field.field_type if isinstance(pred.term, BoundTransform): @@ -819,7 +819,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | return _truncate_array(name, pred, self.transform(field_type)) return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: field_type = pred.term.ref().field.field_type if isinstance(pred.term, BoundTransform): @@ -987,10 +987,10 @@ def can_transform(self, source: IcebergType) -> bool: def result_type(self, source: IcebergType) -> StringType: return StringType() - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None def __repr__(self) -> str: @@ -1015,10 +1015,10 @@ def can_transform(self, _: IcebergType) -> bool: def result_type(self, source: IcebergType) -> IcebergType: return source - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None - def strict_project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def 
strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None def to_human_string(self, _: IcebergType, value: S | None) -> str: @@ -1038,8 +1038,8 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr def _truncate_number( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)): @@ -1060,8 +1060,8 @@ def _truncate_number( def _truncate_number_strict( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)): @@ -1086,8 +1086,8 @@ def _truncate_number_strict( def _truncate_array_strict( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)): @@ -1101,8 +1101,8 @@ def _truncate_array_strict( def _truncate_array( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)): @@ -1120,15 +1120,15 @@ def _truncate_array( def _project_transform_predicate( - transform: Transform[Any, Any], 
partition_name: str, pred: BoundPredicate[L] -) -> UnboundPredicate[Any] | None: + transform: Transform[Any, Any], partition_name: str, pred: BoundPredicate +) -> UnboundPredicate | None: term = pred.term if isinstance(term, BoundTransform) and transform == term.transform: return _remove_transform(partition_name, pred) return None -def _remove_transform(partition_name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any]: +def _remove_transform(partition_name: str, pred: BoundPredicate) -> UnboundPredicate: if isinstance(pred, BoundUnaryPredicate): return pred.as_unbound(Reference(partition_name)) elif isinstance(pred, BoundLiteralPredicate): @@ -1139,7 +1139,7 @@ def _remove_transform(partition_name: str, pred: BoundPredicate[L]) -> UnboundPr raise ValueError(f"Cannot replace transform in unknown predicate: {pred}") -def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform: Callable[[L], L]) -> UnboundPredicate[Any]: +def _set_apply_transform(name: str, pred: BoundSetPredicate, transform: Callable[[Any], Any]) -> UnboundPredicate: literals = pred.literals if isinstance(pred, BoundSetPredicate): transformed_literals = {_transform_literal(transform, literal) for literal in literals} @@ -1148,11 +1148,11 @@ def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform: Calla raise ValueError(f"Unknown BoundSetPredicate: {pred}") -class BoundTransform(BoundTerm[L]): +class BoundTransform(BoundTerm): """A transform expression.""" - transform: Transform[L, Any] + transform: Transform[Any, Any] - def __init__(self, term: BoundTerm[L], transform: Transform[L, Any]): - self.term: BoundTerm[L] = term + def __init__(self, term: BoundTerm, transform: Transform[Any, Any]): + self.term: BoundTerm = term self.transform = transform diff --git a/tests/conftest.py b/tests/conftest.py index 947fc00a83..fcd188f6a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2558,17 +2558,17 @@ def 
table_v2_with_statistics(table_metadata_v2_with_statistics: dict[str, Any]) @pytest.fixture -def bound_reference_str() -> BoundReference[str]: +def bound_reference_str() -> BoundReference: return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_binary() -> BoundReference[str]: +def bound_reference_binary() -> BoundReference: return BoundReference(field=NestedField(1, "field", BinaryType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_uuid() -> BoundReference[str]: +def bound_reference_uuid() -> BoundReference: return BoundReference(field=NestedField(1, "field", UUIDType(), required=False), accessor=Accessor(position=0, inner=None)) diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index dbee2ca045..8a62369775 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -19,7 +19,6 @@ import pickle import uuid from decimal import Decimal -from typing import Any import pytest from typing_extensions import assert_type @@ -502,7 +501,7 @@ def test_less_than_or_equal_invert() -> None: LessThanOrEqual(Reference("foo"), "hello"), ], ) -def test_bind(pred: UnboundPredicate[Any], table_schema_simple: Schema) -> None: +def test_bind(pred: UnboundPredicate, table_schema_simple: Schema) -> None: assert pred.bind(table_schema_simple, case_sensitive=True).term.field == table_schema_simple.find_field( # type: ignore pred.term.name, # type: ignore case_sensitive=True, @@ -522,7 +521,7 @@ def test_bind(pred: UnboundPredicate[Any], table_schema_simple: Schema) -> None: LessThanOrEqual(Reference("Bar"), 5), ], ) -def test_bind_case_insensitive(pred: UnboundPredicate[Any], table_schema_simple: Schema) -> None: +def test_bind_case_insensitive(pred: UnboundPredicate, table_schema_simple: Schema) -> None: assert pred.bind(table_schema_simple, 
case_sensitive=False).term.field == table_schema_simple.find_field( # type: ignore pred.term.name, # type: ignore case_sensitive=False, @@ -683,7 +682,7 @@ def accessor() -> Accessor: @pytest.fixture -def term(field: NestedField, accessor: Accessor) -> BoundReference[Any]: +def term(field: NestedField, accessor: Accessor) -> BoundReference: return BoundReference( field=field, accessor=accessor, @@ -794,14 +793,14 @@ def test_bound_reference_field_property() -> None: assert bound_ref.field == NestedField(field_id=1, name="foo", field_type=StringType(), required=False) -def test_bound_is_null(term: BoundReference[Any]) -> None: +def test_bound_is_null(term: BoundReference) -> None: bound_is_null = BoundIsNull(term) assert str(bound_is_null) == f"BoundIsNull(term={str(term)})" assert repr(bound_is_null) == f"BoundIsNull(term={repr(term)})" assert bound_is_null == eval(repr(bound_is_null)) -def test_bound_is_not_null(term: BoundReference[Any]) -> None: +def test_bound_is_not_null(term: BoundReference) -> None: bound_not_null = BoundNotNull(term) assert str(bound_not_null) == f"BoundNotNull(term={str(term)})" assert repr(bound_not_null) == f"BoundNotNull(term={repr(term)})" @@ -838,7 +837,7 @@ def test_serialize_not_null() -> None: def test_bound_is_nan(accessor: Accessor) -> None: # We need a FloatType here - term = BoundReference[float]( + term = BoundReference( field=NestedField(field_id=1, name="foo", field_type=FloatType(), required=False), accessor=accessor, ) @@ -851,7 +850,7 @@ def test_bound_is_nan(accessor: Accessor) -> None: def test_bound_is_not_nan(accessor: Accessor) -> None: # We need a FloatType here - term = BoundReference[float]( + term = BoundReference( field=NestedField(field_id=1, name="foo", field_type=FloatType(), required=False), accessor=accessor, ) @@ -880,7 +879,7 @@ def test_not_nan() -> None: assert not_nan == pickle.loads(pickle.dumps(not_nan)) -def test_bound_in(term: BoundReference[Any]) -> None: +def test_bound_in(term: BoundReference) 
-> None: bound_in = BoundIn(term, {literal("a"), literal("b"), literal("c")}) assert str(bound_in) == f"BoundIn({str(term)}, {{a, b, c}})" assert repr(bound_in) == f"BoundIn({repr(term)}, {{literal('a'), literal('b'), literal('c')}})" @@ -888,7 +887,7 @@ def test_bound_in(term: BoundReference[Any]) -> None: assert bound_in == pickle.loads(pickle.dumps(bound_in)) -def test_bound_not_in(term: BoundReference[Any]) -> None: +def test_bound_not_in(term: BoundReference) -> None: bound_not_in = BoundNotIn(term, {literal("a"), literal("b"), literal("c")}) assert str(bound_not_in) == f"BoundNotIn({str(term)}, {{a, b, c}})" assert repr(bound_not_in) == f"BoundNotIn({repr(term)}, {{literal('a'), literal('b'), literal('c')}})" @@ -924,7 +923,7 @@ def test_serialize_not_in() -> None: assert pred.model_dump_json() == '{"term":"foo","type":"not-in","items":[1,2,3]}' -def test_bound_equal_to(term: BoundReference[Any]) -> None: +def test_bound_equal_to(term: BoundReference) -> None: bound_equal_to = BoundEqualTo(term, literal("a")) assert str(bound_equal_to) == f"BoundEqualTo(term={str(term)}, literal=literal('a'))" assert repr(bound_equal_to) == f"BoundEqualTo(term={repr(term)}, literal=literal('a'))" @@ -932,7 +931,7 @@ def test_bound_equal_to(term: BoundReference[Any]) -> None: assert bound_equal_to == pickle.loads(pickle.dumps(bound_equal_to)) -def test_bound_not_equal_to(term: BoundReference[Any]) -> None: +def test_bound_not_equal_to(term: BoundReference) -> None: bound_not_equal_to = BoundNotEqualTo(term, literal("a")) assert str(bound_not_equal_to) == f"BoundNotEqualTo(term={str(term)}, literal=literal('a'))" assert repr(bound_not_equal_to) == f"BoundNotEqualTo(term={repr(term)}, literal=literal('a'))" @@ -940,7 +939,7 @@ def test_bound_not_equal_to(term: BoundReference[Any]) -> None: assert bound_not_equal_to == pickle.loads(pickle.dumps(bound_not_equal_to)) -def test_bound_greater_than_or_equal_to(term: BoundReference[Any]) -> None: +def 
test_bound_greater_than_or_equal_to(term: BoundReference) -> None: bound_greater_than_or_equal_to = BoundGreaterThanOrEqual(term, literal("a")) assert str(bound_greater_than_or_equal_to) == f"BoundGreaterThanOrEqual(term={str(term)}, literal=literal('a'))" assert repr(bound_greater_than_or_equal_to) == f"BoundGreaterThanOrEqual(term={repr(term)}, literal=literal('a'))" @@ -948,7 +947,7 @@ def test_bound_greater_than_or_equal_to(term: BoundReference[Any]) -> None: assert bound_greater_than_or_equal_to == pickle.loads(pickle.dumps(bound_greater_than_or_equal_to)) -def test_bound_greater_than(term: BoundReference[Any]) -> None: +def test_bound_greater_than(term: BoundReference) -> None: bound_greater_than = BoundGreaterThan(term, literal("a")) assert str(bound_greater_than) == f"BoundGreaterThan(term={str(term)}, literal=literal('a'))" assert repr(bound_greater_than) == f"BoundGreaterThan(term={repr(term)}, literal=literal('a'))" @@ -956,7 +955,7 @@ def test_bound_greater_than(term: BoundReference[Any]) -> None: assert bound_greater_than == pickle.loads(pickle.dumps(bound_greater_than)) -def test_bound_less_than(term: BoundReference[Any]) -> None: +def test_bound_less_than(term: BoundReference) -> None: bound_less_than = BoundLessThan(term, literal("a")) assert str(bound_less_than) == f"BoundLessThan(term={str(term)}, literal=literal('a'))" assert repr(bound_less_than) == f"BoundLessThan(term={repr(term)}, literal=literal('a'))" @@ -964,7 +963,7 @@ def test_bound_less_than(term: BoundReference[Any]) -> None: assert bound_less_than == pickle.loads(pickle.dumps(bound_less_than)) -def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None: +def test_bound_less_than_or_equal(term: BoundReference) -> None: bound_less_than_or_equal = BoundLessThanOrEqual(term, literal("a")) assert str(bound_less_than_or_equal) == f"BoundLessThanOrEqual(term={str(term)}, literal=literal('a'))" assert repr(bound_less_than_or_equal) == f"BoundLessThanOrEqual(term={repr(term)}, 
literal=literal('a'))" @@ -1231,7 +1230,7 @@ def test_above_long_bounds_greater_than_or_equal( assert GreaterThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() -def test_eq_bound_expression(bound_reference_str: BoundReference[str]) -> None: +def test_eq_bound_expression(bound_reference_str: BoundReference) -> None: assert BoundEqualTo(term=bound_reference_str, literal=literal("a")) != BoundGreaterThanOrEqual( term=bound_reference_str, literal=literal("a") ) @@ -1283,6 +1282,6 @@ def _assert_literal_predicate_type(expr: LiteralPredicate[L]) -> None: _assert_literal_predicate_type(In("a", ("a", "b", "c"))) _assert_literal_predicate_type(In("a", (1, 2, 3))) _assert_literal_predicate_type(NotIn("a", ("a", "b", "c"))) -assert_type(In("a", ("a", "b", "c")), In[str]) -assert_type(In("a", (1, 2, 3)), In[int]) -assert_type(NotIn("a", ("a", "b", "c")), NotIn[str]) +assert_type(In("a", ("a", "b", "c")), In) +assert_type(In("a", (1, 2, 3)), In) +assert_type(NotIn("a", ("a", "b", "c")), NotIn) diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py index 2847859db5..798e9f641e 100644 --- a/tests/expressions/test_visitors.py +++ b/tests/expressions/test_visitors.py @@ -121,11 +121,11 @@ def visit_or(self, left_result: list[str], right_result: list[str]) -> list[str] self.visit_history.append("OR") return self.visit_history - def visit_unbound_predicate(self, predicate: UnboundPredicate[Any]) -> list[str]: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> list[str]: self.visit_history.append(str(predicate.__class__.__name__).upper()) return self.visit_history - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> list[str]: + def visit_bound_predicate(self, predicate: BoundPredicate) -> list[str]: self.visit_history.append(str(predicate.__class__.__name__).upper()) return self.visit_history @@ -139,51 +139,51 @@ class FooBoundBooleanExpressionVisitor(BoundBooleanExpressionVisitor[list[str]]) def 
__init__(self) -> None: self.visit_history: list[str] = [] - def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> list[str]: + def visit_in(self, term: BoundTerm, literals: set[Any]) -> list[str]: self.visit_history.append("IN") return self.visit_history - def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> list[str]: + def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> list[str]: self.visit_history.append("NOT_IN") return self.visit_history - def visit_is_nan(self, term: BoundTerm[Any]) -> list[str]: + def visit_is_nan(self, term: BoundTerm) -> list[str]: self.visit_history.append("IS_NAN") return self.visit_history - def visit_not_nan(self, term: BoundTerm[Any]) -> list[str]: + def visit_not_nan(self, term: BoundTerm) -> list[str]: self.visit_history.append("NOT_NAN") return self.visit_history - def visit_is_null(self, term: BoundTerm[Any]) -> list[str]: + def visit_is_null(self, term: BoundTerm) -> list[str]: self.visit_history.append("IS_NULL") return self.visit_history - def visit_not_null(self, term: BoundTerm[Any]) -> list[str]: + def visit_not_null(self, term: BoundTerm) -> list[str]: self.visit_history.append("NOT_NULL") return self.visit_history - def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("EQUAL") return self.visit_history - def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_not_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("NOT_EQUAL") return self.visit_history - def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def 
visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("GREATER_THAN_OR_EQUAL") return self.visit_history - def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("GREATER_THAN") return self.visit_history - def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("LESS_THAN") return self.visit_history - def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("LESS_THAN_OR_EQUAL") return self.visit_history @@ -207,11 +207,11 @@ def visit_or(self, left_result: list[str], right_result: list[str]) -> list[str] self.visit_history.append("OR") return self.visit_history - def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: + def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: self.visit_history.append("STARTS_WITH") return self.visit_history - def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: self.visit_history.append("NOT_STARTS_WITH") return self.visit_history @@ -253,7 +253,7 @@ def test_boolean_expression_visit_raise_not_implemented_error() -> None: def test_bind_visitor_already_bound(table_schema_simple: Schema) -> None: - 
bound = BoundEqualTo[str]( + bound = BoundEqualTo( term=BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)), literal=literal("hello"), ) @@ -315,7 +315,7 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch ), {literal("foo"), literal("bar")}, ), - BoundIn[int]( + BoundIn( BoundReference( field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), accessor=Accessor(position=1, inner=None), @@ -345,7 +345,7 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch {literal("bar"), literal("baz")}, ), And( - BoundEqualTo[int]( + BoundEqualTo( BoundReference( field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), accessor=Accessor(position=1, inner=None), @@ -365,7 +365,7 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch ], ) def test_and_expression_binding( - unbound_and_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_and_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, table_schema_simple: Schema ) -> None: """Test that visiting an unbound AND expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_and_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -388,7 +388,7 @@ def test_and_expression_binding( ), {literal("foo"), literal("bar")}, ), - BoundIn[int]( + BoundIn( BoundReference( field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), accessor=Accessor(position=1, inner=None), @@ -459,7 +459,7 @@ def test_and_expression_binding( ], ) def test_or_expression_binding( - unbound_or_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_or_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, 
table_schema_simple: Schema ) -> None: """Test that visiting an unbound OR expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_or_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -505,7 +505,7 @@ def test_or_expression_binding( ], ) def test_in_expression_binding( - unbound_in_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_in_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, table_schema_simple: Schema ) -> None: """Test that visiting an unbound IN expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_in_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -556,7 +556,7 @@ def test_in_expression_binding( ], ) def test_not_expression_binding( - unbound_not_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_not_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, table_schema_simple: Schema ) -> None: """Test that visiting an unbound NOT expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_not_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -1590,16 +1590,16 @@ def test_to_dnf_not_and() -> None: def test_dnf_to_dask(table_schema_simple: Schema) -> None: expr = ( - BoundGreaterThan[str]( + BoundGreaterThan( term=BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)), literal=literal("hello"), ), And( - BoundIn[int]( + BoundIn( term=BoundReference(table_schema_simple.find_field(2), table_schema_simple.accessor_for_field(2)), literals={literal(1), literal(2), literal(3)}, ), - BoundEqualTo[bool]( + BoundEqualTo( term=BoundReference(table_schema_simple.find_field(3), 
table_schema_simple.accessor_for_field(3)), literal=literal(True), ), diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 5758dbe4e5..3977dc2143 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -633,12 +633,12 @@ def test_list_type_to_pyarrow() -> None: @pytest.fixture -def bound_reference(table_schema_simple: Schema) -> BoundReference[str]: +def bound_reference(table_schema_simple: Schema) -> BoundReference: return BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)) @pytest.fixture -def bound_double_reference() -> BoundReference[float]: +def bound_double_reference() -> BoundReference: schema = Schema( NestedField(field_id=1, name="foo", field_type=DoubleType(), required=False), schema_id=1, @@ -647,68 +647,68 @@ def bound_double_reference() -> BoundReference[float]: return BoundReference(schema.find_field(1), schema.accessor_for_field(1)) -def test_expr_is_null_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_is_null_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundIsNull(bound_reference))) == "" ) -def test_expr_not_null_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_null_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundNotNull(bound_reference))) == "" -def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference[str]) -> None: +def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundIsNaN(bound_double_reference))) == "" -def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference[str]) -> None: +def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundNotNaN(bound_double_reference))) == "" -def test_expr_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def 
test_expr_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundEqualTo(bound_reference, literal("hello")))) == '' ) -def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundNotEqualTo(bound_reference, literal("hello")))) == '' ) -def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundGreaterThanOrEqual(bound_reference, literal("hello")))) == '= "hello")>' ) -def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundGreaterThan(bound_reference, literal("hello")))) == ' "hello")>' ) -def test_expr_less_than_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_less_than_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundLessThan(bound_reference, literal("hello")))) == '' ) -def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundLessThanOrEqual(bound_reference, literal("hello")))) == '' ) -def test_expr_in_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_in_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundIn(bound_reference, {literal("hello"), literal("world")}))) in ( """ None: ) -def test_expr_not_in_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_in_to_pyarrow(bound_reference: BoundReference) -> None: assert 
repr(expression_to_pyarrow(BoundNotIn(bound_reference, {literal("hello"), literal("world")}))) in ( """ None: ) -def test_expr_starts_with_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_starts_with_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundStartsWith(bound_reference, literal("he")))) == '' ) -def test_expr_not_starts_with_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_starts_with_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundNotStartsWith(bound_reference, literal("he")))) == '' ) -def test_and_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_and_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(And(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) == '' ) -def test_or_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_or_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(Or(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) == '' ) -def test_not_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_not_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(Not(BoundEqualTo(bound_reference, literal("hello"))))) == '' ) -def test_always_true_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_always_true_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(AlwaysTrue())) == "" -def test_always_false_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_always_false_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(AlwaysFalse())) == "" diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 9d5772d01c..59a4857699 100644 --- a/tests/io/test_pyarrow_visitor.py +++ 
b/tests/io/test_pyarrow_visitor.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=protected-access,unused-argument,redefined-outer-name import re -from typing import Any import pyarrow as pa import pytest @@ -717,21 +716,21 @@ def test_pyarrow_schema_round_trip_ensure_large_types_and_then_small_types(pyarr @pytest.fixture -def bound_reference_str() -> BoundReference[Any]: +def bound_reference_str() -> BoundReference: return BoundReference( field=NestedField(1, "string_field", StringType(), required=False), accessor=Accessor(position=0, inner=None) ) @pytest.fixture -def bound_reference_float() -> BoundReference[Any]: +def bound_reference_float() -> BoundReference: return BoundReference( field=NestedField(2, "float_field", FloatType(), required=False), accessor=Accessor(position=1, inner=None) ) @pytest.fixture -def bound_reference_double() -> BoundReference[Any]: +def bound_reference_double() -> BoundReference: return BoundReference( field=NestedField(3, "double_field", DoubleType(), required=False), accessor=Accessor(position=2, inner=None), @@ -739,32 +738,32 @@ def bound_reference_double() -> BoundReference[Any]: @pytest.fixture -def bound_eq_str_field(bound_reference_str: BoundReference[Any]) -> BoundEqualTo[Any]: +def bound_eq_str_field(bound_reference_str: BoundReference) -> BoundEqualTo: return BoundEqualTo(term=bound_reference_str, literal=literal("hello")) @pytest.fixture -def bound_greater_than_float_field(bound_reference_float: BoundReference[Any]) -> BoundGreaterThan[Any]: +def bound_greater_than_float_field(bound_reference_float: BoundReference) -> BoundGreaterThan: return BoundGreaterThan(term=bound_reference_float, literal=literal(100)) @pytest.fixture -def bound_is_nan_float_field(bound_reference_float: BoundReference[Any]) -> BoundIsNaN[Any]: +def bound_is_nan_float_field(bound_reference_float: BoundReference) -> BoundIsNaN: return BoundIsNaN(bound_reference_float) @pytest.fixture -def bound_eq_double_field(bound_reference_double: 
BoundReference[Any]) -> BoundEqualTo[Any]: +def bound_eq_double_field(bound_reference_double: BoundReference) -> BoundEqualTo: return BoundEqualTo(term=bound_reference_double, literal=literal(False)) @pytest.fixture -def bound_is_null_double_field(bound_reference_double: BoundReference[Any]) -> BoundIsNull[Any]: +def bound_is_null_double_field(bound_reference_double: BoundReference) -> BoundIsNull: return BoundIsNull(bound_reference_double) def test_collect_null_nan_unmentioned_terms( - bound_eq_str_field: BoundEqualTo[Any], bound_is_nan_float_field: BoundIsNaN[Any], bound_is_null_double_field: BoundIsNull[Any] + bound_eq_str_field: BoundEqualTo, bound_is_nan_float_field: BoundIsNaN, bound_is_null_double_field: BoundIsNull ) -> None: bound_expr = And( Or(And(bound_eq_str_field, bound_is_nan_float_field), bound_is_null_double_field), Not(bound_is_nan_float_field) @@ -786,11 +785,11 @@ def test_collect_null_nan_unmentioned_terms( def test_collect_null_nan_unmentioned_terms_with_multiple_predicates_on_the_same_term( - bound_eq_str_field: BoundEqualTo[Any], - bound_greater_than_float_field: BoundGreaterThan[Any], - bound_is_nan_float_field: BoundIsNaN[Any], - bound_eq_double_field: BoundEqualTo[Any], - bound_is_null_double_field: BoundIsNull[Any], + bound_eq_str_field: BoundEqualTo, + bound_greater_than_float_field: BoundGreaterThan, + bound_is_nan_float_field: BoundIsNaN, + bound_eq_double_field: BoundEqualTo, + bound_is_null_double_field: BoundIsNull, ) -> None: """Test a single term appears multiple places in the expression tree""" bound_expr = And( @@ -818,11 +817,11 @@ def test_collect_null_nan_unmentioned_terms_with_multiple_predicates_on_the_same def test_expression_to_complementary_pyarrow( - bound_eq_str_field: BoundEqualTo[Any], - bound_greater_than_float_field: BoundGreaterThan[Any], - bound_is_nan_float_field: BoundIsNaN[Any], - bound_eq_double_field: BoundEqualTo[Any], - bound_is_null_double_field: BoundIsNull[Any], + bound_eq_str_field: BoundEqualTo, + 
bound_greater_than_float_field: BoundGreaterThan, + bound_is_nan_float_field: BoundIsNaN, + bound_eq_double_field: BoundEqualTo, + bound_is_null_double_field: BoundIsNull, ) -> None: bound_expr = And( Or( diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3d9bfcb555..88f53c51b0 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -91,7 +91,7 @@ YearTransform, parse_transform, ) -from pyiceberg.typedef import UTF8, L +from pyiceberg.typedef import UTF8 from pyiceberg.types import ( BinaryType, BooleanType, @@ -650,146 +650,146 @@ def test_datetime_transform_repr(transform: TimeTransform[Any], transform_repr: @pytest.fixture -def bound_reference_date() -> BoundReference[int]: +def bound_reference_date() -> BoundReference: return BoundReference(field=NestedField(1, "field", DateType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_timestamp() -> BoundReference[int]: +def bound_reference_timestamp() -> BoundReference: return BoundReference( field=NestedField(1, "field", TimestampType(), required=False), accessor=Accessor(position=0, inner=None) ) @pytest.fixture -def bound_reference_decimal() -> BoundReference[Decimal]: +def bound_reference_decimal() -> BoundReference: return BoundReference( field=NestedField(1, "field", DecimalType(8, 2), required=False), accessor=Accessor(position=0, inner=None) ) @pytest.fixture -def bound_reference_int() -> BoundReference[int]: +def bound_reference_int() -> BoundReference: return BoundReference(field=NestedField(1, "field", IntegerType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_long() -> BoundReference[int]: +def bound_reference_long() -> BoundReference: return BoundReference(field=NestedField(1, "field", LongType(), required=False), accessor=Accessor(position=0, inner=None)) -def test_projection_bucket_unary(bound_reference_str: BoundReference[str]) -> None: +def 
test_projection_bucket_unary(bound_reference_str: BoundReference) -> None: assert BucketTransform(2).project("name", BoundNotNull(term=bound_reference_str)) == NotNull(term=Reference(name="name")) -def test_projection_bucket_literal(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_literal(bound_reference_str: BoundReference) -> None: assert BucketTransform(2).project("name", BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo( term="name", literal=1 ) -def test_projection_bucket_set_same_bucket(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_set_same_bucket(bound_reference_str: BoundReference) -> None: assert BucketTransform(2).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("world")}) ) == EqualTo(term="name", literal=1) -def test_projection_bucket_set_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_set_in(bound_reference_str: BoundReference) -> None: assert BucketTransform(3).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("world")}) ) == In(term="name", literals={1, 2}) -def test_projection_bucket_set_not_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_set_not_in(bound_reference_str: BoundReference) -> None: assert ( BucketTransform(3).project("name", BoundNotIn(term=bound_reference_str, literals={literal("hello"), literal("world")})) is None ) -def test_projection_year_unary(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_unary(bound_reference_date: BoundReference) -> None: assert YearTransform().project("name", BoundNotNull(term=bound_reference_date)) == NotNull(term="name") -def test_projection_year_literal(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_literal(bound_reference_date: BoundReference) -> None: assert YearTransform().project("name", 
BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo( term="name", literal=5 ) -def test_projection_year_set_same_year(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_set_same_year(bound_reference_date: BoundReference) -> None: assert YearTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(1926)}) ) == EqualTo(term="name", literal=5) -def test_projection_year_set_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_set_in(bound_reference_date: BoundReference) -> None: assert YearTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)}) ) == In(term="name", literals={8, 5}) -def test_projection_year_set_not_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_set_not_in(bound_reference_date: BoundReference) -> None: assert ( YearTransform().project("name", BoundNotIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)})) is None ) -def test_projection_month_unary(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_unary(bound_reference_date: BoundReference) -> None: assert MonthTransform().project("name", BoundNotNull(term=bound_reference_date)) == NotNull(term="name") -def test_projection_month_literal(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_literal(bound_reference_date: BoundReference) -> None: assert MonthTransform().project("name", BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo( term="name", literal=63 ) -def test_projection_month_set_same_month(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_set_same_month(bound_reference_date: BoundReference) -> None: assert MonthTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(1926)}) ) == 
EqualTo(term="name", literal=63) -def test_projection_month_set_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_set_in(bound_reference_date: BoundReference) -> None: assert MonthTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)}) ) == In(term="name", literals={96, 63}) -def test_projection_day_month_not_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_day_month_not_in(bound_reference_date: BoundReference) -> None: assert ( MonthTransform().project("name", BoundNotIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)})) is None ) -def test_projection_day_unary(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_unary(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project("name", BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name") -def test_projection_day_literal(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_literal(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project( "name", BoundEqualTo(term=bound_reference_timestamp, literal=TimestampLiteral(1667696874000)) ) == EqualTo(term="name", literal=19) -def test_projection_day_set_same_day(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_set_same_day(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project( "name", BoundIn(term=bound_reference_timestamp, literals={TimestampLiteral(1667696874001), TimestampLiteral(1667696874000)}), ) == EqualTo(term="name", literal=19) -def test_projection_day_set_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_set_in(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project( "name", BoundIn(term=bound_reference_timestamp, literals={TimestampLiteral(1667696874001), 
TimestampLiteral(1567696874000)}), ) == In(term="name", literals={18, 19}) -def test_projection_day_set_not_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_set_not_in(bound_reference_timestamp: BoundReference) -> None: assert ( DayTransform().project( "name", @@ -799,7 +799,7 @@ def test_projection_day_set_not_in(bound_reference_timestamp: BoundReference[int ) -def test_projection_day_human(bound_reference_date: BoundReference[int]) -> None: +def test_projection_day_human(bound_reference_date: BoundReference) -> None: date_literal = DateLiteral(17532) assert DayTransform().project("dt", BoundEqualTo(term=bound_reference_date, literal=date_literal)) == EqualTo( term="dt", literal=17532 @@ -822,7 +822,7 @@ def test_projection_day_human(bound_reference_date: BoundReference[int]) -> None ) # >= 2018, 1, 2 -def test_projection_hour_unary(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_unary(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project("name", BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name") @@ -830,13 +830,13 @@ def test_projection_hour_unary(bound_reference_timestamp: BoundReference[int]) - HOUR_IN_MICROSECONDS = 60 * 60 * 1000 * 1000 -def test_projection_hour_literal(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_literal(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project( "name", BoundEqualTo(term=bound_reference_timestamp, literal=TimestampLiteral(TIMESTAMP_EXAMPLE)) ) == EqualTo(term="name", literal=463249) -def test_projection_hour_set_same_hour(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_set_same_hour(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project( "name", BoundIn( @@ -846,7 +846,7 @@ def test_projection_hour_set_same_hour(bound_reference_timestamp: BoundReference ) == EqualTo(term="name", 
literal=463249) -def test_projection_hour_set_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_set_in(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project( "name", BoundIn( @@ -856,7 +856,7 @@ def test_projection_hour_set_in(bound_reference_timestamp: BoundReference[int]) ) == In(term="name", literals={463249, 463250}) -def test_projection_hour_set_not_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_set_not_in(bound_reference_timestamp: BoundReference) -> None: assert ( HourTransform().project( "name", @@ -869,17 +869,17 @@ def test_projection_hour_set_not_in(bound_reference_timestamp: BoundReference[in ) -def test_projection_identity_unary(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_unary(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project("name", BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name") -def test_projection_identity_literal(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_literal(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project( "name", BoundEqualTo(term=bound_reference_timestamp, literal=TimestampLiteral(TIMESTAMP_EXAMPLE)) ) == EqualTo(term="name", literal=TimestampLiteral(TIMESTAMP_EXAMPLE)) -def test_projection_identity_set_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_set_in(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project( "name", BoundIn( @@ -892,7 +892,7 @@ def test_projection_identity_set_in(bound_reference_timestamp: BoundReference[in ) -def test_projection_identity_set_not_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_set_not_in(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project( "name", BoundNotIn( @@ -905,78 
+905,78 @@ def test_projection_identity_set_not_in(bound_reference_timestamp: BoundReferenc ) -def test_projection_truncate_string_unary(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_unary(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project("name", BoundNotNull(term=bound_reference_str)) == NotNull(term="name") -def test_projection_truncate_string_literal_eq(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_eq(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project("name", BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo( term="name", literal=literal("da") ) -def test_projection_truncate_string_literal_gt(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_gt(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThan(term=bound_reference_str, literal=literal("data")) ) == GreaterThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_literal_gte(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_gte(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThanOrEqual(term=bound_reference_str, literal=literal("data")) ) == GreaterThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_literal_lt(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_lt(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundLessThan(term=bound_reference_str, literal=literal("data")) ) == LessThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_literal_lte(bound_reference_str: BoundReference[str]) -> None: +def 
test_projection_truncate_string_literal_lte(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundLessThanOrEqual(term=bound_reference_str, literal=literal("data")) ) == LessThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_set_same_result(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_set_same_result(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("helloworld")}) ) == EqualTo(term="name", literal=literal("he")) -def test_projection_truncate_string_set_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_set_in(bound_reference_str: BoundReference) -> None: assert TruncateTransform(3).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("world")}) ) == In(term="name", literals={literal("hel"), literal("wor")}) # codespell:ignore hel -def test_projection_truncate_string_set_not_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_set_not_in(bound_reference_str: BoundReference) -> None: assert ( TruncateTransform(3).project("name", BoundNotIn(term=bound_reference_str, literals={literal("hello"), literal("world")})) is None ) -def test_projection_truncate_decimal_literal_eq(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_literal_eq(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundEqualTo(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == EqualTo(term="name", literal=Decimal("19.24")) -def test_projection_truncate_decimal_literal_gt(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_literal_gt(bound_reference_decimal: BoundReference) -> None: assert 
TruncateTransform(2).project( "name", BoundGreaterThan(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26")) -def test_projection_truncate_decimal_literal_gte(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_literal_gte(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThanOrEqual(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24")) -def test_projection_truncate_decimal_in(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_in(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundIn(term=bound_reference_decimal, literals={literal(Decimal(19.25)), literal(Decimal(18.15))}) ) == In( @@ -988,25 +988,25 @@ def test_projection_truncate_decimal_in(bound_reference_decimal: BoundReference[ ) -def test_projection_truncate_long_literal_eq(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_long_literal_eq(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundEqualTo(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == EqualTo(term="name", literal=Decimal("19.24")) -def test_projection_truncate_long_literal_gt(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_long_literal_gt(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThan(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26")) -def test_projection_truncate_long_literal_gte(bound_reference_decimal: BoundReference[Decimal]) -> None: +def 
test_projection_truncate_long_literal_gte(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThanOrEqual(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24")) -def test_projection_truncate_long_in(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_long_in(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundIn(term=bound_reference_decimal, literals={DecimalLiteral(Decimal(19.25)), DecimalLiteral(Decimal(18.15))}) ) == In( @@ -1018,19 +1018,19 @@ def test_projection_truncate_long_in(bound_reference_decimal: BoundReference[Dec ) -def test_projection_truncate_string_starts_with(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_starts_with(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundStartsWith(term=bound_reference_str, literal=literal("hello")) ) == StartsWith(term="name", literal=literal("he")) -def test_projection_truncate_string_not_starts_with(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_not_starts_with(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundNotStartsWith(term=bound_reference_str, literal=literal("hello")) ) == NotStartsWith(term="name", literal=literal("he")) -def _test_projection(lhs: UnboundPredicate[L] | None, rhs: UnboundPredicate[L] | None) -> None: +def _test_projection(lhs: UnboundPredicate | None, rhs: UnboundPredicate | None) -> None: assert type(lhs) is type(lhs), f"Different classes: {type(lhs)} != {type(rhs)}" if lhs is None and rhs is None: # Both null @@ -1068,7 +1068,7 @@ def _assert_projection_strict( assert actual_human_str == expected_human_str -def test_month_projection_strict_epoch(bound_reference_date: BoundReference[int]) -> None: +def 
test_month_projection_strict_epoch(bound_reference_date: BoundReference) -> None: date = literal("1970-01-01").to(DateType()) transform = MonthTransform() _assert_projection_strict(BoundLessThan(term=bound_reference_date, literal=date), transform, LessThan, "1970-01") @@ -1091,7 +1091,7 @@ def test_month_projection_strict_epoch(bound_reference_date: BoundReference[int] ) -def test_month_projection_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_month_projection_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-01-01").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1115,7 +1115,7 @@ def test_month_projection_strict_lower_bound(bound_reference_date: BoundReferenc ) -def test_negative_month_projection_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_month_projection_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("1969-01-01").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1140,7 +1140,7 @@ def test_negative_month_projection_strict_lower_bound(bound_reference_date: Boun ) -def test_month_projection_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_month_projection_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-12-31").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1164,7 +1164,7 @@ def test_month_projection_strict_upper_bound(bound_reference_date: BoundReferenc ) -def test_negative_month_projection_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_month_projection_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("1969-12-31").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1188,7 +1188,7 @@ def test_negative_month_projection_strict_upper_bound(bound_reference_date: Boun ) 
-def test_day_strict(bound_reference_date: BoundReference[int]) -> None: +def test_day_strict(bound_reference_date: BoundReference) -> None: date = literal("2017-01-01").to(DateType()) transform = DayTransform() @@ -1216,7 +1216,7 @@ def test_day_strict(bound_reference_date: BoundReference[int]) -> None: ) -def test_day_negative_strict(bound_reference_date: BoundReference[int]) -> None: +def test_day_negative_strict(bound_reference_date: BoundReference) -> None: date = literal("1969-12-30").to(DateType()) transform = DayTransform() @@ -1244,7 +1244,7 @@ def test_day_negative_strict(bound_reference_date: BoundReference[int]) -> None: ) -def test_year_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_year_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-01-01").to(DateType()) transform = YearTransform() @@ -1265,7 +1265,7 @@ def test_year_strict_lower_bound(bound_reference_date: BoundReference[int]) -> N ) -def test_negative_year_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_year_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("1970-01-01").to(DateType()) transform = YearTransform() @@ -1289,7 +1289,7 @@ def test_negative_year_strict_lower_bound(bound_reference_date: BoundReference[i ) -def test_year_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_year_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-12-31").to(DateType()) transform = YearTransform() @@ -1313,7 +1313,7 @@ def test_year_strict_upper_bound(bound_reference_date: BoundReference[int]) -> N ) -def test_negative_year_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_year_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-12-31").to(DateType()) transform = YearTransform() @@ -1330,7 +1330,7 @@ def 
test_negative_year_strict_upper_bound(bound_reference_date: BoundReference[i _assert_projection_strict(BoundIn(term=bound_reference_date, literals={date, another_date}), transform, NotIn) -def test_strict_bucket_integer(bound_reference_int: BoundReference[int]) -> None: +def test_strict_bucket_integer(bound_reference_int: BoundReference) -> None: value = literal(100).to(IntegerType()) transform = BucketTransform(num_buckets=10) @@ -1346,7 +1346,7 @@ def test_strict_bucket_integer(bound_reference_int: BoundReference[int]) -> None _assert_projection_strict(BoundIn(term=bound_reference_int, literals=literals), transform, AlwaysFalse) -def test_strict_bucket_long(bound_reference_long: BoundReference[int]) -> None: +def test_strict_bucket_long(bound_reference_long: BoundReference) -> None: value = literal(100).to(LongType()) transform = BucketTransform(num_buckets=10) @@ -1362,7 +1362,7 @@ def test_strict_bucket_long(bound_reference_long: BoundReference[int]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_long, literals=literals), transform, AlwaysFalse) -def test_strict_bucket_decimal(bound_reference_decimal: BoundReference[int]) -> None: +def test_strict_bucket_decimal(bound_reference_decimal: BoundReference) -> None: dec = DecimalType(9, 2) value = literal("100.00").to(dec) transform = BucketTransform(num_buckets=10) @@ -1379,7 +1379,7 @@ def test_strict_bucket_decimal(bound_reference_decimal: BoundReference[int]) -> _assert_projection_strict(BoundIn(term=bound_reference_decimal, literals=literals), transform, AlwaysFalse) -def test_strict_bucket_string(bound_reference_str: BoundReference[int]) -> None: +def test_strict_bucket_string(bound_reference_str: BoundReference) -> None: value = literal("abcdefg").to(StringType()) transform = BucketTransform(num_buckets=10) @@ -1395,7 +1395,7 @@ def test_strict_bucket_string(bound_reference_str: BoundReference[int]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_str, literals={value, 
other_value}), transform, AlwaysFalse) -def test_strict_bucket_bytes(bound_reference_binary: BoundReference[int]) -> None: +def test_strict_bucket_bytes(bound_reference_binary: BoundReference) -> None: value = literal(str.encode("abcdefg")).to(BinaryType()) transform = BucketTransform(num_buckets=10) @@ -1411,7 +1411,7 @@ def test_strict_bucket_bytes(bound_reference_binary: BoundReference[int]) -> Non _assert_projection_strict(BoundIn(term=bound_reference_binary, literals={value, other_value}), transform, AlwaysFalse) -def test_strict_bucket_uuid(bound_reference_uuid: BoundReference[int]) -> None: +def test_strict_bucket_uuid(bound_reference_uuid: BoundReference) -> None: value = literal("00000000-0000-007b-0000-0000000001c8").to(UUIDType()) transform = BucketTransform(num_buckets=10) @@ -1427,7 +1427,7 @@ def test_strict_bucket_uuid(bound_reference_uuid: BoundReference[int]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_uuid, literals={value, other_value}), transform, AlwaysFalse) -def test_strict_identity_projection(bound_reference_long: BoundReference[int]) -> None: +def test_strict_identity_projection(bound_reference_long: BoundReference) -> None: transform: Transform[Any, Any] = IdentityTransform() predicates = [ BoundNotNull(term=bound_reference_long), @@ -1458,7 +1458,7 @@ def test_strict_identity_projection(bound_reference_long: BoundReference[int]) - ) -def test_truncate_strict_integer_lower_bound(bound_reference_int: BoundReference[int]) -> None: +def test_truncate_strict_integer_lower_bound(bound_reference_int: BoundReference) -> None: value = literal(100).to(IntegerType()) transform = TruncateTransform(10) @@ -1476,7 +1476,7 @@ def test_truncate_strict_integer_lower_bound(bound_reference_int: BoundReference _assert_projection_strict(BoundIn(term=bound_reference_int, literals={value_dec, value, value_inc}), transform, NotIn) -def test_truncate_strict_integer_upper_bound(bound_reference_int: BoundReference[int]) -> None: +def 
test_truncate_strict_integer_upper_bound(bound_reference_int: BoundReference) -> None: value = literal(99).to(IntegerType()) transform = TruncateTransform(10) @@ -1492,7 +1492,7 @@ def test_truncate_strict_integer_upper_bound(bound_reference_int: BoundReference _assert_projection_strict(BoundIn(term=bound_reference_int, literals=literals), transform, NotIn) -def test_truncate_strict_long_lower_bound(bound_reference_long: BoundReference[int]) -> None: +def test_truncate_strict_long_lower_bound(bound_reference_long: BoundReference) -> None: value = literal(100).to(IntegerType()) transform = TruncateTransform(10) @@ -1510,7 +1510,7 @@ def test_truncate_strict_long_lower_bound(bound_reference_long: BoundReference[i _assert_projection_strict(BoundIn(term=bound_reference_long, literals={value_dec, value, value_inc}), transform, NotIn) -def test_truncate_strict_long_upper_bound(bound_reference_long: BoundReference[int]) -> None: +def test_truncate_strict_long_upper_bound(bound_reference_long: BoundReference) -> None: value = literal(99).to(IntegerType()) transform = TruncateTransform(10) @@ -1528,7 +1528,7 @@ def test_truncate_strict_long_upper_bound(bound_reference_long: BoundReference[i _assert_projection_strict(BoundIn(term=bound_reference_long, literals={value_dec, value, value_inc}), transform, NotIn) -def test_truncate_strict_decimal_lower_bound(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_truncate_strict_decimal_lower_bound(bound_reference_decimal: BoundReference) -> None: dec = DecimalType(9, 2) value = literal("100.00").to(dec) transform = TruncateTransform(10) @@ -1549,7 +1549,7 @@ def test_truncate_strict_decimal_lower_bound(bound_reference_decimal: BoundRefer _assert_projection_strict(BoundIn(term=bound_reference_decimal, literals=literals), transform, NotIn) -def test_truncate_strict_decimal_upper_bound(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_truncate_strict_decimal_upper_bound(bound_reference_decimal: 
BoundReference) -> None: dec = DecimalType(9, 2) value = literal("99.99").to(dec) transform = TruncateTransform(10) @@ -1570,7 +1570,7 @@ def test_truncate_strict_decimal_upper_bound(bound_reference_decimal: BoundRefer _assert_projection_strict(BoundIn(term=bound_reference_decimal, literals=literals), transform, NotIn) -def test_string_strict(bound_reference_str: BoundReference[str]) -> None: +def test_string_strict(bound_reference_str: BoundReference) -> None: value = literal("abcdefg").to(StringType()) transform: Transform[Any, Any] = TruncateTransform(width=5) @@ -1585,7 +1585,7 @@ def test_string_strict(bound_reference_str: BoundReference[str]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_str, literals={value, other_value}), transform, NotIn) -def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None: +def test_strict_binary(bound_reference_binary: BoundReference) -> None: value = literal(b"abcdefg").to(BinaryType()) transform: Transform[Any, Any] = TruncateTransform(width=5) From 9c50eeed69e2c1b9232b877fae63de7c7431b9c8 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 17 Nov 2025 18:12:45 +0100 Subject: [PATCH 2/3] A few more --- pyiceberg/expressions/__init__.py | 80 +++++++++++++-------------- tests/expressions/test_evaluator.py | 4 +- tests/expressions/test_expressions.py | 78 +++++++++++++------------- 3 files changed, 81 insertions(+), 81 deletions(-) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index d866b80f22..190a094a84 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -20,7 +20,7 @@ import builtins from abc import ABC, abstractmethod from functools import cached_property -from typing import Any, Callable, Generic, Iterable, Sequence, cast +from typing import Any, Callable, Iterable, Sequence, cast from typing import Literal as TypingLiteral from pydantic import ConfigDict, Field @@ -489,7 +489,7 @@ def __getnewargs__(self) -> 
tuple[BoundTerm]: class BoundIsNull(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 if term.ref().field.required: return AlwaysFalse() return super().__new__(cls) @@ -504,7 +504,7 @@ def as_unbound(self) -> type[IsNull]: class BoundNotNull(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm): # type: ignore # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 if term.ref().field.required: return AlwaysTrue() return super().__new__(cls) @@ -543,7 +543,7 @@ def as_bound(self) -> builtins.type[BoundNotNull]: class BoundIsNaN(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) @@ -559,7 +559,7 @@ def as_unbound(self) -> type[IsNaN]: class BoundNotNaN(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) @@ -673,7 +673,7 @@ def as_unbound(self) -> type[SetPredicate]: ... 
class BoundIn(BoundSetPredicate): - def __new__(cls, term: BoundTerm, literals: set[LiteralValue]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 + def __new__(cls, term: BoundTerm, literals: set[LiteralValue]) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 count = len(literals) if count == 0: return AlwaysFalse() @@ -696,7 +696,7 @@ def as_unbound(self) -> type[In]: class BoundNotIn(BoundSetPredicate): - def __new__( # type: ignore # pylint: disable=W0221 + def __new__( # type: ignore[misc] # pylint: disable=W0221 cls, term: BoundTerm, literals: set[LiteralValue], @@ -721,7 +721,7 @@ def as_unbound(self) -> type[NotIn]: class In(SetPredicate): type: TypingLiteral["in"] = Field(default="in", alias="type") - def __new__( # type: ignore # pylint: disable=W0221 + def __new__( # type: ignore[misc] # pylint: disable=W0221 cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] ) -> BooleanExpression: literals_set: set[LiteralValue] = _to_literal_set(literals) @@ -745,7 +745,7 @@ def as_bound(self) -> builtins.type[BoundIn]: class NotIn(SetPredicate, ABC): type: TypingLiteral["not-in"] = Field(default="not-in", alias="type") - def __new__( # type: ignore # pylint: disable=W0221 + def __new__( # type: ignore[misc] # pylint: disable=W0221 cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] ) -> BooleanExpression: literals_set: set[LiteralValue] = _to_literal_set(literals) @@ -766,17 +766,17 @@ def as_bound(self) -> builtins.type[BoundNotIn]: return BoundNotIn -class LiteralPredicate(IcebergBaseModel, UnboundPredicate, Generic[L], ABC): +class LiteralPredicate(IcebergBaseModel, UnboundPredicate, ABC): type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") term: UnboundTerm - value: Literal[L] = Field() + value: LiteralValue = Field() model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) - 
def __init__(self, term: str | UnboundTerm, literal: L | Literal[L]): + def __init__(self, term: str | UnboundTerm, literal: Any | LiteralValue): super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg] @property - def literal(self) -> Literal[L]: + def literal(self) -> LiteralValue: return self.value def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate: @@ -816,9 +816,9 @@ def as_bound(self) -> builtins.type[BoundLiteralPredicate]: ... class BoundLiteralPredicate(BoundPredicate, ABC): - literal: Literal[Any] + literal: LiteralValue - def __init__(self, term: BoundTerm, literal: Literal[Any]): # pylint: disable=W0621 + def __init__(self, term: BoundTerm, literal: LiteralValue): # pylint: disable=W0621 super().__init__(term) self.literal = literal # pylint: disable=W0621 @@ -834,7 +834,7 @@ def __repr__(self) -> str: @property @abstractmethod - def as_unbound(self) -> type[LiteralPredicate[Any]]: ... + def as_unbound(self) -> type[LiteralPredicate]: ... 
class BoundEqualTo(BoundLiteralPredicate): @@ -843,7 +843,7 @@ def __invert__(self) -> BoundNotEqualTo: return BoundNotEqualTo(self.term, self.literal) @property - def as_unbound(self) -> type[EqualTo[Any]]: + def as_unbound(self) -> type[EqualTo]: return EqualTo @@ -853,7 +853,7 @@ def __invert__(self) -> BoundEqualTo: return BoundEqualTo(self.term, self.literal) @property - def as_unbound(self) -> type[NotEqualTo[Any]]: + def as_unbound(self) -> type[NotEqualTo]: return NotEqualTo @@ -863,7 +863,7 @@ def __invert__(self) -> BoundLessThan: return BoundLessThan(self.term, self.literal) @property - def as_unbound(self) -> type[GreaterThanOrEqual[Any]]: + def as_unbound(self) -> type[GreaterThanOrEqual]: return GreaterThanOrEqual @@ -873,7 +873,7 @@ def __invert__(self) -> BoundLessThanOrEqual: return BoundLessThanOrEqual(self.term, self.literal) @property - def as_unbound(self) -> type[GreaterThan[Any]]: + def as_unbound(self) -> type[GreaterThan]: return GreaterThan @@ -883,7 +883,7 @@ def __invert__(self) -> BoundGreaterThanOrEqual: return BoundGreaterThanOrEqual(self.term, self.literal) @property - def as_unbound(self) -> type[LessThan[Any]]: + def as_unbound(self) -> type[LessThan]: return LessThan @@ -893,7 +893,7 @@ def __invert__(self) -> BoundGreaterThan: return BoundGreaterThan(self.term, self.literal) @property - def as_unbound(self) -> type[LessThanOrEqual[Any]]: + def as_unbound(self) -> type[LessThanOrEqual]: return LessThanOrEqual @@ -903,7 +903,7 @@ def __invert__(self) -> BoundNotStartsWith: return BoundNotStartsWith(self.term, self.literal) @property - def as_unbound(self) -> type[StartsWith[Any]]: + def as_unbound(self) -> type[StartsWith]: return StartsWith @@ -913,14 +913,14 @@ def __invert__(self) -> BoundStartsWith: return BoundStartsWith(self.term, self.literal) @property - def as_unbound(self) -> type[NotStartsWith[Any]]: + def as_unbound(self) -> type[NotStartsWith]: return NotStartsWith -class EqualTo(LiteralPredicate[L]): +class 
EqualTo(LiteralPredicate): type: TypingLiteral["eq"] = Field(default="eq", alias="type") - def __invert__(self) -> NotEqualTo[Any]: + def __invert__(self) -> NotEqualTo: """Transform the Expression into its negated version.""" return NotEqualTo(self.term, self.literal) @@ -929,10 +929,10 @@ def as_bound(self) -> builtins.type[BoundEqualTo]: return BoundEqualTo -class NotEqualTo(LiteralPredicate[L]): +class NotEqualTo(LiteralPredicate): type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type") - def __invert__(self) -> EqualTo[Any]: + def __invert__(self) -> EqualTo: """Transform the Expression into its negated version.""" return EqualTo(self.term, self.literal) @@ -941,10 +941,10 @@ def as_bound(self) -> builtins.type[BoundNotEqualTo]: return BoundNotEqualTo -class LessThan(LiteralPredicate[L]): +class LessThan(LiteralPredicate): type: TypingLiteral["lt"] = Field(default="lt", alias="type") - def __invert__(self) -> GreaterThanOrEqual[Any]: + def __invert__(self) -> GreaterThanOrEqual: """Transform the Expression into its negated version.""" return GreaterThanOrEqual(self.term, self.literal) @@ -953,10 +953,10 @@ def as_bound(self) -> builtins.type[BoundLessThan]: return BoundLessThan -class GreaterThanOrEqual(LiteralPredicate[L]): +class GreaterThanOrEqual(LiteralPredicate): type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type") - def __invert__(self) -> LessThan[Any]: + def __invert__(self) -> LessThan: """Transform the Expression into its negated version.""" return LessThan(self.term, self.literal) @@ -965,10 +965,10 @@ def as_bound(self) -> builtins.type[BoundGreaterThanOrEqual]: return BoundGreaterThanOrEqual -class GreaterThan(LiteralPredicate[L]): +class GreaterThan(LiteralPredicate): type: TypingLiteral["gt"] = Field(default="gt", alias="type") - def __invert__(self) -> LessThanOrEqual[Any]: + def __invert__(self) -> LessThanOrEqual: """Transform the Expression into its negated version.""" return LessThanOrEqual(self.term, 
self.literal) @@ -977,10 +977,10 @@ def as_bound(self) -> builtins.type[BoundGreaterThan]: return BoundGreaterThan -class LessThanOrEqual(LiteralPredicate[L]): +class LessThanOrEqual(LiteralPredicate): type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type") - def __invert__(self) -> GreaterThan[Any]: + def __invert__(self) -> GreaterThan: """Transform the Expression into its negated version.""" return GreaterThan(self.term, self.literal) @@ -989,10 +989,10 @@ def as_bound(self) -> builtins.type[BoundLessThanOrEqual]: return BoundLessThanOrEqual -class StartsWith(LiteralPredicate[L]): +class StartsWith(LiteralPredicate): type: TypingLiteral["starts-with"] = Field(default="starts-with", alias="type") - def __invert__(self) -> NotStartsWith[Any]: + def __invert__(self) -> NotStartsWith: """Transform the Expression into its negated version.""" return NotStartsWith(self.term, self.literal) @@ -1001,10 +1001,10 @@ def as_bound(self) -> builtins.type[BoundStartsWith]: return BoundStartsWith -class NotStartsWith(LiteralPredicate[L]): +class NotStartsWith(LiteralPredicate): type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type") - def __invert__(self) -> StartsWith[Any]: + def __invert__(self) -> StartsWith: """Transform the Expression into its negated version.""" return StartsWith(self.term, self.literal) diff --git a/tests/expressions/test_evaluator.py b/tests/expressions/test_evaluator.py index 07888dd41e..5be3e92be8 100644 --- a/tests/expressions/test_evaluator.py +++ b/tests/expressions/test_evaluator.py @@ -685,7 +685,7 @@ def data_file_nan() -> DataFile: def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None: - operators: tuple[type[LiteralPredicate[Any]], ...] = (LessThan, LessThanOrEqual) + operators: tuple[type[LiteralPredicate], ...] 
= (LessThan, LessThanOrEqual) for operator in operators: should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) assert not should_read, "Should not match: all nan column doesn't contain number" @@ -714,7 +714,7 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal( schema_data_file_nan: Schema, data_file_nan: DataFile ) -> None: - operators: tuple[type[LiteralPredicate[Any]], ...] = (GreaterThan, GreaterThanOrEqual) + operators: tuple[type[LiteralPredicate], ...] = (GreaterThan, GreaterThanOrEqual) for operator in operators: should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) assert not should_read, "Should not match: all nan column doesn't contain number" diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 8a62369775..1fbe8d7a6c 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -64,7 +64,7 @@ from pyiceberg.expressions.literals import Literal, literal from pyiceberg.expressions.visitors import _from_byte_buffer from pyiceberg.schema import Accessor, Schema -from pyiceberg.typedef import L, Record +from pyiceberg.typedef import Record from pyiceberg.types import ( DecimalType, DoubleType, @@ -1091,37 +1091,37 @@ def below_int_min() -> Literal[int]: def test_above_int_bounds_equal_to(int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int]) -> None: - assert EqualTo[int]("a", above_int_max).bind(int_schema) is AlwaysFalse() - assert EqualTo[int]("a", below_int_min).bind(int_schema) is AlwaysFalse() + assert EqualTo("a", above_int_max).bind(int_schema) is AlwaysFalse() + assert EqualTo("a", below_int_min).bind(int_schema) is AlwaysFalse() def test_above_int_bounds_not_equal_to(int_schema: Schema, above_int_max: Literal[int], below_int_min: 
Literal[int]) -> None: - assert NotEqualTo[int]("a", above_int_max).bind(int_schema) is AlwaysTrue() - assert NotEqualTo[int]("a", below_int_min).bind(int_schema) is AlwaysTrue() + assert NotEqualTo("a", above_int_max).bind(int_schema) is AlwaysTrue() + assert NotEqualTo("a", below_int_min).bind(int_schema) is AlwaysTrue() def test_above_int_bounds_less_than(int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int]) -> None: - assert LessThan[int]("a", above_int_max).bind(int_schema) is AlwaysTrue() - assert LessThan[int]("a", below_int_min).bind(int_schema) is AlwaysFalse() + assert LessThan("a", above_int_max).bind(int_schema) is AlwaysTrue() + assert LessThan("a", below_int_min).bind(int_schema) is AlwaysFalse() def test_above_int_bounds_less_than_or_equal( int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int] ) -> None: - assert LessThanOrEqual[int]("a", above_int_max).bind(int_schema) is AlwaysTrue() - assert LessThanOrEqual[int]("a", below_int_min).bind(int_schema) is AlwaysFalse() + assert LessThanOrEqual("a", above_int_max).bind(int_schema) is AlwaysTrue() + assert LessThanOrEqual("a", below_int_min).bind(int_schema) is AlwaysFalse() def test_above_int_bounds_greater_than(int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int]) -> None: - assert GreaterThan[int]("a", above_int_max).bind(int_schema) is AlwaysFalse() - assert GreaterThan[int]("a", below_int_min).bind(int_schema) is AlwaysTrue() + assert GreaterThan("a", above_int_max).bind(int_schema) is AlwaysFalse() + assert GreaterThan("a", below_int_min).bind(int_schema) is AlwaysTrue() def test_above_int_bounds_greater_than_or_equal( int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int] ) -> None: - assert GreaterThanOrEqual[int]("a", above_int_max).bind(int_schema) is AlwaysFalse() - assert GreaterThanOrEqual[int]("a", below_int_min).bind(int_schema) is AlwaysTrue() + assert GreaterThanOrEqual("a", 
above_int_max).bind(int_schema) is AlwaysFalse() + assert GreaterThanOrEqual("a", below_int_min).bind(int_schema) is AlwaysTrue() @pytest.fixture @@ -1142,43 +1142,43 @@ def below_float_min() -> Literal[float]: def test_above_float_bounds_equal_to( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert EqualTo[float]("a", above_float_max).bind(float_schema) is AlwaysFalse() - assert EqualTo[float]("a", below_float_min).bind(float_schema) is AlwaysFalse() + assert EqualTo("a", above_float_max).bind(float_schema) is AlwaysFalse() + assert EqualTo("a", below_float_min).bind(float_schema) is AlwaysFalse() def test_above_float_bounds_not_equal_to( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert NotEqualTo[float]("a", above_float_max).bind(float_schema) is AlwaysTrue() - assert NotEqualTo[float]("a", below_float_min).bind(float_schema) is AlwaysTrue() + assert NotEqualTo("a", above_float_max).bind(float_schema) is AlwaysTrue() + assert NotEqualTo("a", below_float_min).bind(float_schema) is AlwaysTrue() def test_above_float_bounds_less_than( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert LessThan[float]("a", above_float_max).bind(float_schema) is AlwaysTrue() - assert LessThan[float]("a", below_float_min).bind(float_schema) is AlwaysFalse() + assert LessThan("a", above_float_max).bind(float_schema) is AlwaysTrue() + assert LessThan("a", below_float_min).bind(float_schema) is AlwaysFalse() def test_above_float_bounds_less_than_or_equal( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert LessThanOrEqual[float]("a", above_float_max).bind(float_schema) is AlwaysTrue() - assert LessThanOrEqual[float]("a", below_float_min).bind(float_schema) is AlwaysFalse() + assert LessThanOrEqual("a", above_float_max).bind(float_schema) is AlwaysTrue() + assert 
LessThanOrEqual("a", below_float_min).bind(float_schema) is AlwaysFalse() def test_above_float_bounds_greater_than( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert GreaterThan[float]("a", above_float_max).bind(float_schema) is AlwaysFalse() - assert GreaterThan[float]("a", below_float_min).bind(float_schema) is AlwaysTrue() + assert GreaterThan("a", above_float_max).bind(float_schema) is AlwaysFalse() + assert GreaterThan("a", below_float_min).bind(float_schema) is AlwaysTrue() def test_above_float_bounds_greater_than_or_equal( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert GreaterThanOrEqual[float]("a", above_float_max).bind(float_schema) is AlwaysFalse() - assert GreaterThanOrEqual[float]("a", below_float_min).bind(float_schema) is AlwaysTrue() + assert GreaterThanOrEqual("a", above_float_max).bind(float_schema) is AlwaysFalse() + assert GreaterThanOrEqual("a", below_float_min).bind(float_schema) is AlwaysTrue() @pytest.fixture @@ -1197,37 +1197,37 @@ def below_long_min() -> Literal[float]: def test_above_long_bounds_equal_to(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert EqualTo[int]("a", above_long_max).bind(long_schema) is AlwaysFalse() - assert EqualTo[int]("a", below_long_min).bind(long_schema) is AlwaysFalse() + assert EqualTo("a", above_long_max).bind(long_schema) is AlwaysFalse() + assert EqualTo("a", below_long_min).bind(long_schema) is AlwaysFalse() def test_above_long_bounds_not_equal_to(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert NotEqualTo[int]("a", above_long_max).bind(long_schema) is AlwaysTrue() - assert NotEqualTo[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() + assert NotEqualTo("a", above_long_max).bind(long_schema) is AlwaysTrue() + assert NotEqualTo("a", below_long_min).bind(long_schema) is AlwaysTrue() def 
test_above_long_bounds_less_than(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert LessThan[int]("a", above_long_max).bind(long_schema) is AlwaysTrue() - assert LessThan[int]("a", below_long_min).bind(long_schema) is AlwaysFalse() + assert LessThan("a", above_long_max).bind(long_schema) is AlwaysTrue() + assert LessThan("a", below_long_min).bind(long_schema) is AlwaysFalse() def test_above_long_bounds_less_than_or_equal( long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int] ) -> None: - assert LessThanOrEqual[int]("a", above_long_max).bind(long_schema) is AlwaysTrue() - assert LessThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysFalse() + assert LessThanOrEqual("a", above_long_max).bind(long_schema) is AlwaysTrue() + assert LessThanOrEqual("a", below_long_min).bind(long_schema) is AlwaysFalse() def test_above_long_bounds_greater_than(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert GreaterThan[int]("a", above_long_max).bind(long_schema) is AlwaysFalse() - assert GreaterThan[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() + assert GreaterThan("a", above_long_max).bind(long_schema) is AlwaysFalse() + assert GreaterThan("a", below_long_min).bind(long_schema) is AlwaysTrue() def test_above_long_bounds_greater_than_or_equal( long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int] ) -> None: - assert GreaterThanOrEqual[int]("a", above_long_max).bind(long_schema) is AlwaysFalse() - assert GreaterThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() + assert GreaterThanOrEqual("a", above_long_max).bind(long_schema) is AlwaysFalse() + assert GreaterThanOrEqual("a", below_long_min).bind(long_schema) is AlwaysTrue() def test_eq_bound_expression(bound_reference_str: BoundReference) -> None: @@ -1274,8 +1274,8 @@ def test_bind_ambiguous_name() -> None: # |__/ |__/ -def 
_assert_literal_predicate_type(expr: LiteralPredicate[L]) -> None: - assert_type(expr, LiteralPredicate[L]) +def _assert_literal_predicate_type(expr: LiteralPredicate) -> None: + assert_type(expr, LiteralPredicate) _assert_literal_predicate_type(EqualTo("a", "b")) From 474aa0df94202e2d6afb9936bca61056adf415ab Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 17 Nov 2025 20:34:20 +0100 Subject: [PATCH 3/3] One more! --- pyiceberg/expressions/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 190a094a84..ff1b02bf48 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -20,7 +20,7 @@ import builtins from abc import ABC, abstractmethod from functools import cached_property -from typing import Any, Callable, Iterable, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, cast from typing import Literal as TypingLiteral from pydantic import ConfigDict, Field @@ -36,7 +36,10 @@ except ImportError: ConfigDict = dict -LiteralValue = Literal[Any] +if TYPE_CHECKING: + LiteralValue = Literal[Any] +else: + LiteralValue = Literal def _to_unbound_term(term: str | UnboundTerm) -> UnboundTerm: @@ -772,7 +775,7 @@ class LiteralPredicate(IcebergBaseModel, UnboundPredicate, ABC): value: LiteralValue = Field() model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) - def __init__(self, term: str | UnboundTerm, literal: Any | LiteralValue): + def __init__(self, term: str | UnboundTerm, literal: Any): super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg] @property