diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 20df6e548c..ff1b02bf48 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -20,24 +20,12 @@ import builtins from abc import ABC, abstractmethod from functools import cached_property -from typing import ( - Any, - Callable, - Generic, - Iterable, - Sequence, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, cast from typing import Literal as TypingLiteral from pydantic import ConfigDict, Field -from pyiceberg.expressions.literals import ( - AboveMax, - BelowMin, - Literal, - literal, -) +from pyiceberg.expressions.literals import AboveMax, BelowMin, Literal, literal from pyiceberg.schema import Accessor, Schema from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, StructProtocol from pyiceberg.types import DoubleType, FloatType, NestedField @@ -48,8 +36,13 @@ except ImportError: ConfigDict = dict +if TYPE_CHECKING: + LiteralValue = Literal[Any] +else: + LiteralValue = Literal -def _to_unbound_term(term: str | UnboundTerm[Any]) -> UnboundTerm[Any]: + +def _to_unbound_term(term: str | UnboundTerm) -> UnboundTerm: return Reference(term) if isinstance(term, str) else term @@ -125,7 +118,7 @@ def _build_balanced_tree( return operator_(left, right) -class Term(Generic[L], ABC): +class Term: """A simple expression that evaluates to a value.""" @@ -133,33 +126,30 @@ class Bound: """Represents a bound value expression.""" -B = TypeVar("B") - - -class Unbound(Generic[B], ABC): +class Unbound(ABC): """Represents an unbound value expression.""" @abstractmethod - def bind(self, schema: Schema, case_sensitive: bool = True) -> B: ... + def bind(self, schema: Schema, case_sensitive: bool = True) -> Bound | BooleanExpression: ... @property @abstractmethod def as_bound(self) -> type[Bound]: ... -class BoundTerm(Term[L], Bound, ABC): +class BoundTerm(Term, Bound, ABC): """Represents a bound term.""" @abstractmethod - def ref(self) -> BoundReference[L]: + def ref(self) -> BoundReference: """Return the bound reference.""" @abstractmethod - def eval(self, struct: StructProtocol) -> L: # pylint: disable=W0613 + def eval(self, struct: StructProtocol) -> Any: # pylint: disable=W0613 """Return the value at the referenced field's position in an object that abides by the StructProtocol.""" -class BoundReference(BoundTerm[L]): +class BoundReference(BoundTerm): """A reference bound to a field in a schema. Args: @@ -174,7 +164,7 @@ def __init__(self, field: NestedField, accessor: Accessor): self.field = field self.accessor = accessor - def eval(self, struct: StructProtocol) -> L: + def eval(self, struct: StructProtocol) -> Any: """Return the value at the referenced field's position in an object that abides by the StructProtocol. Args: @@ -192,7 +182,7 @@ def __repr__(self) -> str: """Return the string representation of the BoundReference class.""" return f"BoundReference(field={repr(self.field)}, accessor={repr(self.accessor)})" - def ref(self) -> BoundReference[L]: + def ref(self) -> BoundReference: return self def __hash__(self) -> int: @@ -200,14 +190,14 @@ def __hash__(self) -> int: return hash(str(self)) -class UnboundTerm(Term[Any], Unbound[BoundTerm[L]], ABC): +class UnboundTerm(Term, Unbound, ABC): """Represents an unbound term.""" @abstractmethod - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundTerm[L]: ... + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundTerm: ... -class Reference(UnboundTerm[Any], IcebergRootModel[str]): +class Reference(UnboundTerm, IcebergRootModel[str]): """A reference not yet bound to a field in a schema. Args: @@ -230,7 +220,7 @@ def __str__(self) -> str: """Return the string representation of the Reference class.""" return f"Reference(name={repr(self.root)})" - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[L]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference: """Bind the reference to an Iceberg schema. Args: @@ -245,15 +235,15 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[L] """ field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive) accessor = schema.accessor_for_field(field.field_id) - return self.as_bound(field=field, accessor=accessor) # type: ignore + return self.as_bound(field=field, accessor=accessor) @property def name(self) -> str: return self.root @property - def as_bound(self) -> type[BoundReference[L]]: - return BoundReference[L] + def as_bound(self) -> type[BoundReference]: + return BoundReference class And(BooleanExpression): @@ -425,10 +415,10 @@ def __repr__(self) -> str: return "AlwaysFalse()" -class BoundPredicate(Generic[L], Bound, BooleanExpression, ABC): - term: BoundTerm[L] +class BoundPredicate(Bound, BooleanExpression, ABC): + term: BoundTerm - def __init__(self, term: BoundTerm[L]): + def __init__(self, term: BoundTerm): self.term = term def __eq__(self, other: Any) -> bool: @@ -439,13 +429,13 @@ def __eq__(self, other: Any) -> bool: @property @abstractmethod - def as_unbound(self) -> type[UnboundPredicate[Any]]: ... + def as_unbound(self) -> type[UnboundPredicate]: ... -class UnboundPredicate(Generic[L], Unbound[BooleanExpression], BooleanExpression, ABC): - term: UnboundTerm[Any] +class UnboundPredicate(Unbound, BooleanExpression, ABC): + term: UnboundTerm - def __init__(self, term: str | UnboundTerm[Any]): + def __init__(self, term: str | UnboundTerm): self.term = _to_unbound_term(term) def __eq__(self, other: Any) -> bool: @@ -457,15 +447,15 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression @property @abstractmethod - def as_bound(self) -> type[BoundPredicate[L]]: ... + def as_bound(self) -> type[BoundPredicate]: ... -class UnaryPredicate(IcebergBaseModel, UnboundPredicate[Any], ABC): +class UnaryPredicate(IcebergBaseModel, UnboundPredicate, ABC): type: str model_config = {"arbitrary_types_allowed": True} - def __init__(self, term: str | UnboundTerm[Any]): + def __init__(self, term: str | UnboundTerm): unbound = _to_unbound_term(term) super().__init__(term=unbound) @@ -474,7 +464,7 @@ def __str__(self) -> str: # Sort to make it deterministic return f"{str(self.__class__.__name__)}(term={str(self.term)})" - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate[Any]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate: bound_term = self.term.bind(schema, case_sensitive) return self.as_bound(bound_term) # type: ignore @@ -484,10 +474,10 @@ def __repr__(self) -> str: @property @abstractmethod - def as_bound(self) -> type[BoundUnaryPredicate[Any]]: ... # type: ignore + def as_bound(self) -> type[BoundUnaryPredicate]: ... # type: ignore -class BoundUnaryPredicate(BoundPredicate[L], ABC): +class BoundUnaryPredicate(BoundPredicate, ABC): def __repr__(self) -> str: """Return the string representation of the BoundUnaryPredicate class.""" return f"{str(self.__class__.__name__)}(term={repr(self.term)})" @@ -496,18 +486,18 @@ def __repr__(self) -> str: @abstractmethod def as_unbound(self) -> type[UnaryPredicate]: ... - def __getnewargs__(self) -> tuple[BoundTerm[L]]: + def __getnewargs__(self) -> tuple[BoundTerm]: """Pickle the BoundUnaryPredicate class.""" return (self.term,) -class BoundIsNull(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundIsNull(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 if term.ref().field.required: return AlwaysFalse() return super().__new__(cls) - def __invert__(self) -> BoundNotNull[L]: + def __invert__(self) -> BoundNotNull: """Transform the Expression into its negated version.""" return BoundNotNull(self.term) @@ -516,13 +506,13 @@ def as_unbound(self) -> type[IsNull]: return IsNull -class BoundNotNull(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]): # type: ignore # pylint: disable=W0221 +class BoundNotNull(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 if term.ref().field.required: return AlwaysTrue() return super().__new__(cls) - def __invert__(self) -> BoundIsNull[L]: + def __invert__(self) -> BoundIsNull: """Transform the Expression into its negated version.""" return BoundIsNull(self.term) @@ -539,8 +529,8 @@ def __invert__(self) -> NotNull: return NotNull(self.term) @property - def as_bound(self) -> builtins.type[BoundIsNull[L]]: - return BoundIsNull[L] + def as_bound(self) -> builtins.type[BoundIsNull]: + return BoundIsNull class NotNull(UnaryPredicate): @@ -551,18 +541,18 @@ def __invert__(self) -> IsNull: return IsNull(self.term) @property - def as_bound(self) -> builtins.type[BoundNotNull[L]]: - return BoundNotNull[L] + def as_bound(self) -> builtins.type[BoundNotNull]: + return BoundNotNull -class BoundIsNaN(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundIsNaN(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) return AlwaysFalse() - def __invert__(self) -> BoundNotNaN[L]: + def __invert__(self) -> BoundNotNaN: """Transform the Expression into its negated version.""" return BoundNotNaN(self.term) @@ -571,14 +561,14 @@ def as_unbound(self) -> type[IsNaN]: return IsNaN -class BoundNotNaN(BoundUnaryPredicate[L]): - def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundNotNaN(BoundUnaryPredicate): + def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) return AlwaysTrue() - def __invert__(self) -> BoundIsNaN[L]: + def __invert__(self) -> BoundIsNaN: """Transform the Expression into its negated version.""" return BoundIsNaN(self.term) @@ -595,8 +585,8 @@ def __invert__(self) -> NotNaN: return NotNaN(self.term) @property - def as_bound(self) -> builtins.type[BoundIsNaN[L]]: - return BoundIsNaN[L] + def as_bound(self) -> builtins.type[BoundIsNaN]: + return BoundIsNaN class NotNaN(UnaryPredicate): @@ -607,22 +597,25 @@ def __invert__(self) -> IsNaN: return IsNaN(self.term) @property - def as_bound(self) -> builtins.type[BoundNotNaN[L]]: - return BoundNotNaN[L] + def as_bound(self) -> builtins.type[BoundNotNaN]: + return BoundNotNaN -class SetPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): +class SetPredicate(IcebergBaseModel, UnboundPredicate, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) type: TypingLiteral["in", "not-in"] = Field(default="in") - literals: set[Literal[L]] = Field(alias="items") + literals: set[Any] = Field(alias="items") - def __init__(self, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]]): - super().__init__(term=_to_unbound_term(term), items=_to_literal_set(literals)) # type: ignore + def __init__(self, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue]): + literal_set = _to_literal_set(literals) + super().__init__(term=_to_unbound_term(term), items=literal_set) # type: ignore + object.__setattr__(self, "literals", literal_set) - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate[L]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate: bound_term = self.term.bind(schema, case_sensitive) - return self.as_bound(bound_term, {lit.to(bound_term.ref().field.field_type) for lit in self.literals}) + literal_set = cast(set[LiteralValue], self.literals) + return self.as_bound(bound_term, {lit.to(bound_term.ref().field.field_type) for lit in literal_set}) def __str__(self) -> str: """Return the string representation of the SetPredicate class.""" @@ -638,26 +631,25 @@ def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the SetPredicate class.""" return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False - def __getnewargs__(self) -> tuple[UnboundTerm[L], set[Literal[L]]]: + def __getnewargs__(self) -> tuple[UnboundTerm, set[Any]]: """Pickle the SetPredicate class.""" return (self.term, self.literals) @property @abstractmethod - def as_bound(self) -> builtins.type[BoundSetPredicate[L]]: - return BoundSetPredicate[L] + def as_bound(self) -> builtins.type[BoundSetPredicate]: + return BoundSetPredicate -class BoundSetPredicate(BoundPredicate[L], ABC): - literals: set[Literal[L]] +class BoundSetPredicate(BoundPredicate, ABC): + literals: set[LiteralValue] - def __init__(self, term: BoundTerm[L], literals: set[Literal[L]]): - # Since we don't know the type of BoundPredicate[L], we have to ignore this one - super().__init__(term) # type: ignore + def __init__(self, term: BoundTerm, literals: set[LiteralValue]): + super().__init__(term) self.literals = _to_literal_set(literals) # pylint: disable=W0621 @cached_property - def value_set(self) -> set[L]: + def value_set(self) -> set[Any]: return {lit.value for lit in self.literals} def __str__(self) -> str: @@ -674,17 +666,17 @@ def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundSetPredicate class.""" return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False - def __getnewargs__(self) -> tuple[BoundTerm[L], set[Literal[L]]]: + def __getnewargs__(self) -> tuple[BoundTerm, set[LiteralValue]]: """Pickle the BoundSetPredicate class.""" return (self.term, self.literals) @property @abstractmethod - def as_unbound(self) -> type[SetPredicate[L]]: ... + def as_unbound(self) -> type[SetPredicate]: ... -class BoundIn(BoundSetPredicate[L]): - def __new__(cls, term: BoundTerm[L], literals: set[Literal[L]]) -> BooleanExpression: # type: ignore # pylint: disable=W0221 +class BoundIn(BoundSetPredicate): + def __new__(cls, term: BoundTerm, literals: set[LiteralValue]) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 count = len(literals) if count == 0: return AlwaysFalse() @@ -693,7 +685,7 @@ def __new__(cls, term: BoundTerm[L], literals: set[Literal[L]]) -> BooleanExpres else: return super().__new__(cls) - def __invert__(self) -> BoundNotIn[L]: + def __invert__(self) -> BoundNotIn: """Transform the Expression into its negated version.""" return BoundNotIn(self.term, self.literals) @@ -702,15 +694,15 @@ def __eq__(self, other: Any) -> bool: return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False @property - def as_unbound(self) -> type[In[L]]: + def as_unbound(self) -> type[In]: return In -class BoundNotIn(BoundSetPredicate[L]): - def __new__( # type: ignore # pylint: disable=W0221 +class BoundNotIn(BoundSetPredicate): + def __new__( # type: ignore[misc] # pylint: disable=W0221 cls, - term: BoundTerm[L], - literals: set[Literal[L]], + term: BoundTerm, + literals: set[LiteralValue], ) -> BooleanExpression: count = len(literals) if count == 0: @@ -720,22 +712,22 @@ def __new__( # type: ignore # pylint: disable=W0221 else: return super().__new__(cls) - def __invert__(self) -> BoundIn[L]: + def __invert__(self) -> BoundIn: """Transform the Expression into its negated version.""" return BoundIn(self.term, self.literals) @property - def as_unbound(self) -> type[NotIn[L]]: + def as_unbound(self) -> type[NotIn]: return NotIn -class In(SetPredicate[L]): +class In(SetPredicate): type: TypingLiteral["in"] = Field(default="in", alias="type") - def __new__( # type: ignore # pylint: disable=W0221 - cls, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]] + def __new__( # type: ignore[misc] # pylint: disable=W0221 + cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] ) -> BooleanExpression: - literals_set: set[Literal[L]] = _to_literal_set(literals) + literals_set: set[LiteralValue] = _to_literal_set(literals) count = len(literals_set) if count == 0: return AlwaysFalse() @@ -744,22 +736,22 @@ def __new__( # type: ignore # pylint: disable=W0221 else: return super().__new__(cls) - def __invert__(self) -> NotIn[L]: + def __invert__(self) -> NotIn: """Transform the Expression into its negated version.""" - return NotIn[L](self.term, self.literals) + return NotIn(self.term, self.literals) @property - def as_bound(self) -> builtins.type[BoundIn[L]]: - return BoundIn[L] + def as_bound(self) -> builtins.type[BoundIn]: + return BoundIn -class NotIn(SetPredicate[L], ABC): +class NotIn(SetPredicate, ABC): type: TypingLiteral["not-in"] = Field(default="not-in", alias="type") - def __new__( # type: ignore # pylint: disable=W0221 - cls, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]] + def __new__( # type: ignore[misc] # pylint: disable=W0221 + cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] ) -> BooleanExpression: - literals_set: set[Literal[L]] = _to_literal_set(literals) + literals_set: set[LiteralValue] = _to_literal_set(literals) count = len(literals_set) if count == 0: return AlwaysTrue() @@ -768,29 +760,29 @@ def __new__( # type: ignore # pylint: disable=W0221 else: return super().__new__(cls) - def __invert__(self) -> In[L]: + def __invert__(self) -> In: """Transform the Expression into its negated version.""" - return In[L](self.term, self.literals) + return In(self.term, self.literals) @property - def as_bound(self) -> builtins.type[BoundNotIn[L]]: - return BoundNotIn[L] + def as_bound(self) -> builtins.type[BoundNotIn]: + return BoundNotIn -class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): +class LiteralPredicate(IcebergBaseModel, UnboundPredicate, ABC): type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") - term: UnboundTerm[Any] - value: Literal[L] = Field() + term: UnboundTerm + value: LiteralValue = Field() model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) - def __init__(self, term: str | UnboundTerm[Any], literal: L | Literal[L]): + def __init__(self, term: str | UnboundTerm, literal: Any): super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg] @property - def literal(self) -> Literal[L]: + def literal(self) -> LiteralValue: return self.value - def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate: bound_term = self.term.bind(schema, case_sensitive) lit = self.literal.to(bound_term.ref().field.field_type) @@ -823,15 +815,14 @@ def __repr__(self) -> str: @property @abstractmethod - def as_bound(self) -> builtins.type[BoundLiteralPredicate[L]]: ... + def as_bound(self) -> builtins.type[BoundLiteralPredicate]: ... -class BoundLiteralPredicate(BoundPredicate[L], ABC): - literal: Literal[L] +class BoundLiteralPredicate(BoundPredicate, ABC): + literal: LiteralValue - def __init__(self, term: BoundTerm[L], literal: Literal[L]): # pylint: disable=W0621 - # Since we don't know the type of BoundPredicate[L], we have to ignore this one - super().__init__(term) # type: ignore + def __init__(self, term: BoundTerm, literal: LiteralValue): # pylint: disable=W0621 + super().__init__(term) self.literal = literal # pylint: disable=W0621 def __eq__(self, other: Any) -> bool: @@ -846,180 +837,180 @@ def __repr__(self) -> str: @property @abstractmethod - def as_unbound(self) -> type[LiteralPredicate[L]]: ... + def as_unbound(self) -> type[LiteralPredicate]: ... -class BoundEqualTo(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundNotEqualTo[L]: +class BoundEqualTo(BoundLiteralPredicate): + def __invert__(self) -> BoundNotEqualTo: """Transform the Expression into its negated version.""" - return BoundNotEqualTo[L](self.term, self.literal) + return BoundNotEqualTo(self.term, self.literal) @property - def as_unbound(self) -> type[EqualTo[L]]: + def as_unbound(self) -> type[EqualTo]: return EqualTo -class BoundNotEqualTo(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundEqualTo[L]: +class BoundNotEqualTo(BoundLiteralPredicate): + def __invert__(self) -> BoundEqualTo: """Transform the Expression into its negated version.""" - return BoundEqualTo[L](self.term, self.literal) + return BoundEqualTo(self.term, self.literal) @property - def as_unbound(self) -> type[NotEqualTo[L]]: + def as_unbound(self) -> type[NotEqualTo]: return NotEqualTo -class BoundGreaterThanOrEqual(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundLessThan[L]: +class BoundGreaterThanOrEqual(BoundLiteralPredicate): + def __invert__(self) -> BoundLessThan: """Transform the Expression into its negated version.""" - return BoundLessThan[L](self.term, self.literal) + return BoundLessThan(self.term, self.literal) @property - def as_unbound(self) -> type[GreaterThanOrEqual[L]]: - return GreaterThanOrEqual[L] + def as_unbound(self) -> type[GreaterThanOrEqual]: + return GreaterThanOrEqual -class BoundGreaterThan(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundLessThanOrEqual[L]: +class BoundGreaterThan(BoundLiteralPredicate): + def __invert__(self) -> BoundLessThanOrEqual: """Transform the Expression into its negated version.""" return BoundLessThanOrEqual(self.term, self.literal) @property - def as_unbound(self) -> type[GreaterThan[L]]: - return GreaterThan[L] + def as_unbound(self) -> type[GreaterThan]: + return GreaterThan -class BoundLessThan(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundGreaterThanOrEqual[L]: +class BoundLessThan(BoundLiteralPredicate): + def __invert__(self) -> BoundGreaterThanOrEqual: """Transform the Expression into its negated version.""" - return BoundGreaterThanOrEqual[L](self.term, self.literal) + return BoundGreaterThanOrEqual(self.term, self.literal) @property - def as_unbound(self) -> type[LessThan[L]]: - return LessThan[L] + def as_unbound(self) -> type[LessThan]: + return LessThan -class BoundLessThanOrEqual(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundGreaterThan[L]: +class BoundLessThanOrEqual(BoundLiteralPredicate): + def __invert__(self) -> BoundGreaterThan: """Transform the Expression into its negated version.""" - return BoundGreaterThan[L](self.term, self.literal) + return BoundGreaterThan(self.term, self.literal) @property - def as_unbound(self) -> type[LessThanOrEqual[L]]: - return LessThanOrEqual[L] + def as_unbound(self) -> type[LessThanOrEqual]: + return LessThanOrEqual -class BoundStartsWith(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundNotStartsWith[L]: +class BoundStartsWith(BoundLiteralPredicate): + def __invert__(self) -> BoundNotStartsWith: """Transform the Expression into its negated version.""" - return BoundNotStartsWith[L](self.term, self.literal) + return BoundNotStartsWith(self.term, self.literal) @property - def as_unbound(self) -> type[StartsWith[L]]: - return StartsWith[L] + def as_unbound(self) -> type[StartsWith]: + return StartsWith -class BoundNotStartsWith(BoundLiteralPredicate[L]): - def __invert__(self) -> BoundStartsWith[L]: +class BoundNotStartsWith(BoundLiteralPredicate): + def __invert__(self) -> BoundStartsWith: """Transform the Expression into its negated version.""" - return BoundStartsWith[L](self.term, self.literal) + return BoundStartsWith(self.term, self.literal) @property - def as_unbound(self) -> type[NotStartsWith[L]]: - return NotStartsWith[L] + def as_unbound(self) -> type[NotStartsWith]: + return NotStartsWith -class EqualTo(LiteralPredicate[L]): +class EqualTo(LiteralPredicate): type: TypingLiteral["eq"] = Field(default="eq", alias="type") - def __invert__(self) -> NotEqualTo[L]: + def __invert__(self) -> NotEqualTo: """Transform the Expression into its negated version.""" - return NotEqualTo[L](self.term, self.literal) + return NotEqualTo(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundEqualTo[L]]: - return BoundEqualTo[L] + def as_bound(self) -> builtins.type[BoundEqualTo]: + return BoundEqualTo -class NotEqualTo(LiteralPredicate[L]): +class NotEqualTo(LiteralPredicate): type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type") - def __invert__(self) -> EqualTo[L]: + def __invert__(self) -> EqualTo: """Transform the Expression into its negated version.""" - return EqualTo[L](self.term, self.literal) + return EqualTo(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundNotEqualTo[L]]: - return BoundNotEqualTo[L] + def as_bound(self) -> builtins.type[BoundNotEqualTo]: + return BoundNotEqualTo -class LessThan(LiteralPredicate[L]): +class LessThan(LiteralPredicate): type: TypingLiteral["lt"] = Field(default="lt", alias="type") - def __invert__(self) -> GreaterThanOrEqual[L]: + def __invert__(self) -> GreaterThanOrEqual: """Transform the Expression into its negated version.""" - return GreaterThanOrEqual[L](self.term, self.literal) + return GreaterThanOrEqual(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundLessThan[L]]: - return BoundLessThan[L] + def as_bound(self) -> builtins.type[BoundLessThan]: + return BoundLessThan -class GreaterThanOrEqual(LiteralPredicate[L]): +class GreaterThanOrEqual(LiteralPredicate): type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type") - def __invert__(self) -> LessThan[L]: + def __invert__(self) -> LessThan: """Transform the Expression into its negated version.""" - return LessThan[L](self.term, self.literal) + return LessThan(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundGreaterThanOrEqual[L]]: - return BoundGreaterThanOrEqual[L] + def as_bound(self) -> builtins.type[BoundGreaterThanOrEqual]: + return BoundGreaterThanOrEqual -class GreaterThan(LiteralPredicate[L]): +class GreaterThan(LiteralPredicate): type: TypingLiteral["gt"] = Field(default="gt", alias="type") - def __invert__(self) -> LessThanOrEqual[L]: + def __invert__(self) -> LessThanOrEqual: """Transform the Expression into its negated version.""" - return LessThanOrEqual[L](self.term, self.literal) + return LessThanOrEqual(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundGreaterThan[L]]: - return BoundGreaterThan[L] + def as_bound(self) -> builtins.type[BoundGreaterThan]: + return BoundGreaterThan -class LessThanOrEqual(LiteralPredicate[L]): +class LessThanOrEqual(LiteralPredicate): type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type") - def __invert__(self) -> GreaterThan[L]: + def __invert__(self) -> GreaterThan: """Transform the Expression into its negated version.""" - return GreaterThan[L](self.term, self.literal) + return GreaterThan(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundLessThanOrEqual[L]]: - return BoundLessThanOrEqual[L] + def as_bound(self) -> builtins.type[BoundLessThanOrEqual]: + return BoundLessThanOrEqual -class StartsWith(LiteralPredicate[L]): +class StartsWith(LiteralPredicate): type: TypingLiteral["starts-with"] = Field(default="starts-with", alias="type") - def __invert__(self) -> NotStartsWith[L]: + def __invert__(self) -> NotStartsWith: """Transform the Expression into its negated version.""" - return NotStartsWith[L](self.term, self.literal) + return NotStartsWith(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundStartsWith[L]]: - return BoundStartsWith[L] + def as_bound(self) -> builtins.type[BoundStartsWith]: + return BoundStartsWith -class NotStartsWith(LiteralPredicate[L]): +class NotStartsWith(LiteralPredicate): type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type") - def __invert__(self) -> StartsWith[L]: + def __invert__(self) -> StartsWith: """Transform the Expression into its negated version.""" - return StartsWith[L](self.term, self.literal) + return StartsWith(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundNotStartsWith[L]]: - return BoundNotStartsWith[L] + def as_bound(self) -> builtins.type[BoundNotStartsWith]: + return BoundNotStartsWith diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 4c096f1215..1362945a00 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -116,19 +116,19 @@ def visit_or(self, left_result: T, right_result: T) -> T: """ @abstractmethod - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> T: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> T: """Visit method for an unbound predicate in an expression tree. Args: - predicate (UnboundPredicate[L): An instance of an UnboundPredicate. + predicate (UnboundPredicate): An instance of an UnboundPredicate. """ @abstractmethod - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> T: + def visit_bound_predicate(self, predicate: BoundPredicate) -> T: """Visit method for a bound predicate in an expression tree. Args: - predicate (BoundPredicate[L]): An instance of a BoundPredicate. + predicate (BoundPredicate): An instance of a BoundPredicate. """ @@ -176,13 +176,13 @@ def _(obj: And, visitor: BooleanExpressionVisitor[T]) -> T: @visit.register(UnboundPredicate) -def _(obj: UnboundPredicate[L], visitor: BooleanExpressionVisitor[T]) -> T: +def _(obj: UnboundPredicate, visitor: BooleanExpressionVisitor[T]) -> T: """Visit an unbound boolean expression with a concrete BooleanExpressionVisitor.""" return visitor.visit_unbound_predicate(predicate=obj) @visit.register(BoundPredicate) -def _(obj: BoundPredicate[L], visitor: BooleanExpressionVisitor[T]) -> T: +def _(obj: BoundPredicate, visitor: BooleanExpressionVisitor[T]) -> T: """Visit a bound boolean expression with a concrete BooleanExpressionVisitor.""" return visitor.visit_bound_predicate(predicate=obj) @@ -242,60 +242,60 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left=left_result, right=right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: return predicate.bind(self.schema, case_sensitive=self.case_sensitive) - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: raise TypeError(f"Found already bound predicate: {predicate}") class BoundBooleanExpressionVisitor(BooleanExpressionVisitor[T], ABC): @abstractmethod - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> T: + def visit_in(self, term: BoundTerm, literals: set[L]) -> T: """Visit a bound In predicate.""" @abstractmethod - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> T: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> T: """Visit a bound NotIn predicate.""" @abstractmethod - def visit_is_nan(self, term: BoundTerm[L]) -> T: + def visit_is_nan(self, term: BoundTerm) -> T: """Visit a bound IsNan predicate.""" @abstractmethod - def visit_not_nan(self, term: BoundTerm[L]) -> T: + def visit_not_nan(self, term: BoundTerm) -> T: """Visit a bound NotNan predicate.""" @abstractmethod - def visit_is_null(self, term: BoundTerm[L]) -> T: + def visit_is_null(self, term: BoundTerm) -> T: """Visit a bound IsNull predicate.""" @abstractmethod - def visit_not_null(self, term: BoundTerm[L]) -> T: + def visit_not_null(self, term: BoundTerm) -> T: """Visit a bound NotNull predicate.""" @abstractmethod - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound Equal predicate.""" @abstractmethod - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound NotEqual predicate.""" @abstractmethod - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound GreaterThanOrEqual predicate.""" @abstractmethod - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound GreaterThan predicate.""" @abstractmethod - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound LessThan predicate.""" @abstractmethod - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit a bound LessThanOrEqual predicate.""" @abstractmethod @@ -319,105 +319,105 @@ def visit_or(self, left_result: T, right_result: T) -> T: """Visit a bound Or predicate.""" @abstractmethod - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit bound StartsWith predicate.""" @abstractmethod - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> T: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> T: """Visit bound NotStartsWith predicate.""" - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> T: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> T: """Visit an unbound predicate. Args: - predicate (UnboundPredicate[L]): An unbound predicate. + predicate (UnboundPredicate): An unbound predicate. Raises: TypeError: This always raises since an unbound predicate is not expected in a bound boolean expression. """ raise TypeError(f"Not a bound predicate: {predicate}") - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> T: + def visit_bound_predicate(self, predicate: BoundPredicate) -> T: """Visit a bound predicate. Args: - predicate (BoundPredicate[L]): A bound predicate. + predicate (BoundPredicate): A bound predicate. """ return visit_bound_predicate(predicate, self) @singledispatch -def visit_bound_predicate(expr: BoundPredicate[L], _: BooleanExpressionVisitor[T]) -> T: +def visit_bound_predicate(expr: BoundPredicate, _: BooleanExpressionVisitor[T]) -> T: raise TypeError(f"Unknown predicate: {expr}") @visit_bound_predicate.register(BoundIn) -def _(expr: BoundIn[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundIn, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_in(term=expr.term, literals=expr.value_set) @visit_bound_predicate.register(BoundNotIn) -def _(expr: BoundNotIn[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotIn, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_in(term=expr.term, literals=expr.value_set) @visit_bound_predicate.register(BoundIsNaN) -def _(expr: BoundIsNaN[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundIsNaN, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_is_nan(term=expr.term) @visit_bound_predicate.register(BoundNotNaN) -def _(expr: BoundNotNaN[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotNaN, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_nan(term=expr.term) @visit_bound_predicate.register(BoundIsNull) -def _(expr: BoundIsNull[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundIsNull, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_is_null(term=expr.term) @visit_bound_predicate.register(BoundNotNull) -def _(expr: BoundNotNull[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotNull, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_null(term=expr.term) @visit_bound_predicate.register(BoundEqualTo) -def _(expr: BoundEqualTo[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundEqualTo, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundNotEqualTo) -def _(expr: BoundNotEqualTo[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotEqualTo, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundGreaterThanOrEqual) -def _(expr: BoundGreaterThanOrEqual[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundGreaterThanOrEqual, visitor: BoundBooleanExpressionVisitor[T]) -> T: """Visit a bound GreaterThanOrEqual predicate.""" return visitor.visit_greater_than_or_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundGreaterThan) -def _(expr: BoundGreaterThan[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundGreaterThan, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_greater_than(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundLessThan) -def _(expr: BoundLessThan[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundLessThan, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_less_than(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundLessThanOrEqual) -def _(expr: BoundLessThanOrEqual[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundLessThanOrEqual, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_less_than_or_equal(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundStartsWith) -def _(expr: BoundStartsWith[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundStartsWith, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_starts_with(term=expr.term, literal=expr.literal) @visit_bound_predicate.register(BoundNotStartsWith) -def _(expr: BoundNotStartsWith[L], visitor: BoundBooleanExpressionVisitor[T]) -> T: +def _(expr: BoundNotStartsWith, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_starts_with(term=expr.term, literal=expr.literal) @@ -443,10 +443,10 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left=left_result, right=right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: return predicate - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: return predicate @@ -465,53 +465,53 @@ def eval(self, struct: StructProtocol) -> bool: self.struct = struct return visit(self.bound, self) - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, literals: set[L]) -> bool: return term.eval(self.struct) in literals - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: return term.eval(self.struct) not in literals - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: val = term.eval(self.struct) return val != val - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: val = term.eval(self.struct) return val == val - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: return term.eval(self.struct) is None - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: return term.eval(self.struct) is not None - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: return term.eval(self.struct) == literal.value - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: return term.eval(self.struct) != literal.value - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value >= literal.value - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value > literal.value - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value < literal.value - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: value = term.eval(self.struct) return value is not None and value <= literal.value - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: eval_res = term.eval(self.struct) return eval_res is not None and str(eval_res).startswith(str(literal.value)) - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: return not self.visit_starts_with(term, literal) def visit_true(self) -> bool: @@ -558,7 +558,7 @@ def eval(self, manifest: ManifestFile) -> bool: # No partition information return ROWS_MIGHT_MATCH - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, literals: set[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -580,12 +580,12 @@ def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: # because the bounds are not necessarily a min or max value, this cannot be answered using # them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col. return ROWS_MIGHT_MATCH - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -594,7 +594,7 @@ def visit_is_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -603,7 +603,7 @@ def visit_not_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position if self.partition_fields[pos].contains_null is False: @@ -611,7 +611,7 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: pos = term.ref().accessor.position # contains_null encodes whether at least one partition value is null, @@ -628,7 +628,7 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -648,12 +648,12 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # because the bounds are not necessarily a min or max value, this cannot be answered using # them. notEq(col, X) with (X, Y) doesn't guarantee that X is a value in col. return ROWS_MIGHT_MATCH - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -667,7 +667,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) - return ROWS_MIGHT_MATCH - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -681,7 +681,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -695,7 +695,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] @@ -709,7 +709,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b return ROWS_MIGHT_MATCH - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] prefix = str(literal.value) @@ -733,7 +733,7 @@ def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: pos = term.ref().accessor.position field = self.partition_fields[pos] prefix = str(literal.value) @@ -820,12 +820,12 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: raise ValueError(f"Cannot project unbound predicate: {predicate}") class InclusiveProjection(ProjectionEvaluator): - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) result: BooleanExpression = AlwaysTrue() @@ -887,10 +887,10 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left=left_result, right=right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: raise TypeError(f"Expected Bound Predicate, got: {predicate.term}") - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: field = predicate.term.ref().field field_id = field.field_id file_column_name = self.file_schema.find_column_name(field_id) @@ -954,10 +954,10 @@ def visit_and(self, left_result: set[int], right_result: set[int]) -> set[int]: def visit_or(self, left_result: set[int], right_result: set[int]) -> set[int]: return left_result.union(right_result) - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> set[int]: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> set[int]: raise ValueError("Only works on bound records") - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> set[int]: + def visit_bound_predicate(self, predicate: BoundPredicate) -> set[int]: return {predicate.term.ref().field.field_id} @@ -989,10 +989,10 @@ def visit_or( ) -> tuple[BooleanExpression, ...]: return left_result + right_result - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> tuple[BooleanExpression, ...]: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> tuple[BooleanExpression, ...]: return (predicate,) - def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> tuple[BooleanExpression, ...]: + def visit_bound_predicate(self, predicate: BoundPredicate) -> tuple[BooleanExpression, ...]: return (predicate,) @@ -1021,48 +1021,48 @@ def _cast_if_necessary(self, iceberg_type: IcebergType, literal: L | set[L]) -> return conversion_function(literal) # type: ignore return literal - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> list[tuple[str, str, Any]]: + def visit_in(self, term: BoundTerm, literals: set[L]) -> list[tuple[str, str, Any]]: field = term.ref().field return [(term.ref().field.name, "in", self._cast_if_necessary(field.field_type, literals))] - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> list[tuple[str, str, Any]]: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> list[tuple[str, str, Any]]: field = term.ref().field return [(field.name, "not in", self._cast_if_necessary(field.field_type, literals))] - def visit_is_nan(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_is_nan(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "==", float("nan"))] - def visit_not_nan(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_not_nan(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "!=", float("nan"))] - def visit_is_null(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_is_null(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "==", None)] - def visit_not_null(self, term: BoundTerm[L]) -> list[tuple[str, str, Any]]: + def visit_not_null(self, term: BoundTerm) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "!=", None)] - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "==", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "!=", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, ">=", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, ">", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "<", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [(term.ref().field.name, "<=", self._cast_if_necessary(term.ref().field.field_type, literal.value))] - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [] - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> list[tuple[str, str, Any]]: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> list[tuple[str, str, Any]]: return [] def visit_true(self) -> list[tuple[str, str, Any]]: @@ -1192,7 +1192,7 @@ def _contains_nans_only(self, field_id: int) -> bool: return nan_count == value_count return False - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if self.null_counts.get(field_id) == 0: @@ -1200,7 +1200,7 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has no non-null values, the expression cannot match field_id = term.ref().field.field_id @@ -1210,7 +1210,7 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if self.nan_counts.get(field_id) == 0: @@ -1223,7 +1223,7 @@ def visit_is_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if self._contains_nans_only(field_id): @@ -1231,7 +1231,7 @@ def visit_not_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1253,7 +1253,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1274,7 +1274,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b return ROWS_MIGHT_MATCH - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1295,7 +1295,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1316,7 +1316,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) - return ROWS_MIGHT_MATCH - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1346,10 +1346,10 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, literals: set[L]) -> bool: field = term.ref().field field_id = field.field_id @@ -1385,12 +1385,12 @@ def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: # because the bounds are not necessarily a min or max value, this cannot be answered using # them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col. return ROWS_MIGHT_MATCH - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id: int = field.field_id @@ -1419,7 +1419,7 @@ def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_MATCH - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: field = term.ref().field field_id: int = field.field_id @@ -1460,7 +1460,7 @@ def strict_projection( class StrictProjection(ProjectionEvaluator): - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) result: BooleanExpression = AlwaysFalse() @@ -1511,7 +1511,7 @@ def eval(self, file: DataFile) -> bool: return visit(self.expr, self) - def visit_is_null(self, term: BoundTerm[L]) -> bool: + def visit_is_null(self, term: BoundTerm) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has any non-null values, the expression does not match field_id = term.ref().field.field_id @@ -1521,7 +1521,7 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool: else: return ROWS_MIGHT_NOT_MATCH - def visit_not_null(self, term: BoundTerm[L]) -> bool: + def visit_not_null(self, term: BoundTerm) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has any non-null values, the expression does not match field_id = term.ref().field.field_id @@ -1531,7 +1531,7 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool: else: return ROWS_MIGHT_NOT_MATCH - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if self._contains_nans_only(field_id): @@ -1539,7 +1539,7 @@ def visit_is_nan(self, term: BoundTerm[L]) -> bool: else: return ROWS_MIGHT_NOT_MATCH - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm) -> bool: field_id = term.ref().field.field_id if (nan_count := self.nan_counts.get(field_id)) is not None and nan_count == 0: @@ -1550,7 +1550,7 @@ def visit_not_nan(self, term: BoundTerm[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <----------Min----Max---X-------> field_id = term.ref().field.field_id @@ -1567,7 +1567,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <----------Min----Max---X-------> field_id = term.ref().field.field_id @@ -1584,7 +1584,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b return ROWS_MIGHT_NOT_MATCH - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <-------X---Min----Max----------> field_id = term.ref().field.field_id @@ -1606,7 +1606,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when: <-------X---Min----Max----------> field_id = term.ref().field.field_id @@ -1627,7 +1627,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) - return ROWS_MIGHT_NOT_MATCH - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when Min == X == Max field_id = term.ref().field.field_id @@ -1646,7 +1646,7 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> bool: # Rows must match when X < Min or Max < X because it is not in the range field_id = term.ref().field.field_id @@ -1674,7 +1674,7 @@ def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_in(self, term: BoundTerm, literals: set[L]) -> bool: field_id = term.ref().field.field_id if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): @@ -1703,7 +1703,7 @@ def visit_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool: field_id = term.ref().field.field_id if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): @@ -1733,10 +1733,10 @@ def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH def _get_field(self, field_id: int) -> NestedField: @@ -1799,94 +1799,94 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_is_null(self, term: BoundTerm) -> BooleanExpression: if term.eval(self.struct) is None: return AlwaysTrue() else: return AlwaysFalse() - def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_not_null(self, term: BoundTerm) -> BooleanExpression: if term.eval(self.struct) is not None: return AlwaysTrue() else: return AlwaysFalse() - def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_is_nan(self, term: BoundTerm) -> BooleanExpression: val = term.eval(self.struct) if isinstance(val, SupportsFloat) and math.isnan(val): return self.visit_true() else: return self.visit_false() - def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression: + def visit_not_nan(self, term: BoundTerm) -> BooleanExpression: val = term.eval(self.struct) if isinstance(val, SupportsFloat) and not math.isnan(val): return self.visit_true() else: return self.visit_false() - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_less_than(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) < literal.value: return self.visit_true() else: return self.visit_false() - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) <= literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_greater_than(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) > literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) >= literal.value: return self.visit_true() else: return self.visit_false() - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) == literal.value: return self.visit_true() else: return self.visit_false() - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_not_equal(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) != literal.value: return self.visit_true() else: return self.visit_false() - def visit_in(self, term: BoundTerm[L], literals: set[L]) -> BooleanExpression: + def visit_in(self, term: BoundTerm, literals: set[L]) -> BooleanExpression: if term.eval(self.struct) in literals: return self.visit_true() else: return self.visit_false() - def visit_not_in(self, term: BoundTerm[L], literals: set[L]) -> BooleanExpression: + def visit_not_in(self, term: BoundTerm, literals: set[L]) -> BooleanExpression: if term.eval(self.struct) not in literals: return self.visit_true() else: return self.visit_false() - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_starts_with(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: eval_res = term.eval(self.struct) if eval_res is not None and str(eval_res).startswith(str(literal.value)): return AlwaysTrue() else: return AlwaysFalse() - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[L]) -> BooleanExpression: if not self.visit_starts_with(term, literal): return AlwaysTrue() else: return AlwaysFalse() - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: """ If there is no strict projection or if it evaluates to false, then return the predicate. @@ -1940,7 +1940,7 @@ def struct_to_schema(struct: StructType) -> Schema: return predicate - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> BooleanExpression: bound = predicate.bind(self.schema, case_sensitive=self.case_sensitive) if isinstance(bound, BoundPredicate): diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index cd19d43906..0210839df5 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -817,7 +817,7 @@ class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]): def __init__(self, schema: Schema | None = None): self._schema = schema - def _get_field_name(self, term: BoundTerm[Any]) -> str | tuple[str, ...]: + def _get_field_name(self, term: BoundTerm) -> str | tuple[str, ...]: """Get the field name or nested field path for a bound term. For nested struct fields, returns a tuple of field names (e.g., ("mazeMetadata", "run_id")). @@ -837,50 +837,50 @@ def _get_field_name(self, term: BoundTerm[Any]) -> str | tuple[str, ...]: # Fallback to just the field name if schema is not available return term.ref().field.name - def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> pc.Expression: + def visit_in(self, term: BoundTerm, literals: set[Any]) -> pc.Expression: pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) return pc.field(self._get_field_name(term)).isin(pyarrow_literals) - def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> pc.Expression: + def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> pc.Expression: pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) return ~pc.field(self._get_field_name(term)).isin(pyarrow_literals) - def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_is_nan(self, term: BoundTerm) -> pc.Expression: ref = pc.field(self._get_field_name(term)) return pc.is_nan(ref) - def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_not_nan(self, term: BoundTerm) -> pc.Expression: ref = pc.field(self._get_field_name(term)) return ~pc.is_nan(ref) - def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_is_null(self, term: BoundTerm) -> pc.Expression: return pc.field(self._get_field_name(term)).is_null(nan_is_null=False) - def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression: + def visit_not_null(self, term: BoundTerm) -> pc.Expression: return pc.field(self._get_field_name(term)).is_valid() - def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) == _convert_scalar(literal.value, term.ref().field.field_type) - def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_not_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) != _convert_scalar(literal.value, term.ref().field.field_type) - def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) >= _convert_scalar(literal.value, term.ref().field.field_type) - def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) > _convert_scalar(literal.value, term.ref().field.field_type) - def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) < _convert_scalar(literal.value, term.ref().field.field_type) - def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.field(self._get_field_name(term)) <= _convert_scalar(literal.value, term.ref().field.field_type) - def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return pc.starts_with(pc.field(self._get_field_name(term)), literal.value) - def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> pc.Expression: return ~pc.starts_with(pc.field(self._get_field_name(term)), literal.value) def visit_true(self) -> pc.Expression: @@ -901,13 +901,13 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]): # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr. - is_null_or_not_bound_terms: set[BoundTerm[Any]] + is_null_or_not_bound_terms: set[BoundTerm] # The remaining BoundTerms appearing in the boolean expr. - null_unmentioned_bound_terms: set[BoundTerm[Any]] + null_unmentioned_bound_terms: set[BoundTerm] # BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr. - is_nan_or_not_bound_terms: set[BoundTerm[Any]] + is_nan_or_not_bound_terms: set[BoundTerm] # The remaining BoundTerms appearing in the boolean expr. - nan_unmentioned_bound_terms: set[BoundTerm[Any]] + nan_unmentioned_bound_terms: set[BoundTerm] def __init__(self) -> None: super().__init__() @@ -916,81 +916,81 @@ def __init__(self) -> None: self.is_nan_or_not_bound_terms = set() self.nan_unmentioned_bound_terms = set() - def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None: + def _handle_explicit_is_null_or_not(self, term: BoundTerm) -> None: """Handle the predicate case where either is_null or is_not_null is included.""" if term in self.null_unmentioned_bound_terms: self.null_unmentioned_bound_terms.remove(term) self.is_null_or_not_bound_terms.add(term) - def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None: + def _handle_null_unmentioned(self, term: BoundTerm) -> None: """Handle the predicate case where neither is_null or is_not_null is included.""" if term not in self.is_null_or_not_bound_terms: self.null_unmentioned_bound_terms.add(term) - def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None: + def _handle_explicit_is_nan_or_not(self, term: BoundTerm) -> None: """Handle the predicate case where either is_nan or is_not_nan is included.""" if term in self.nan_unmentioned_bound_terms: self.nan_unmentioned_bound_terms.remove(term) self.is_nan_or_not_bound_terms.add(term) - def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None: + def _handle_nan_unmentioned(self, term: BoundTerm) -> None: """Handle the predicate case where neither is_nan or is_not_nan is included.""" if term not in self.is_nan_or_not_bound_terms: self.nan_unmentioned_bound_terms.add(term) - def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> None: + def visit_in(self, term: BoundTerm, literals: set[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> None: + def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_is_nan(self, term: BoundTerm[Any]) -> None: + def visit_is_nan(self, term: BoundTerm) -> None: self._handle_null_unmentioned(term) self._handle_explicit_is_nan_or_not(term) - def visit_not_nan(self, term: BoundTerm[Any]) -> None: + def visit_not_nan(self, term: BoundTerm) -> None: self._handle_null_unmentioned(term) self._handle_explicit_is_nan_or_not(term) - def visit_is_null(self, term: BoundTerm[Any]) -> None: + def visit_is_null(self, term: BoundTerm) -> None: self._handle_explicit_is_null_or_not(term) self._handle_nan_unmentioned(term) - def visit_not_null(self, term: BoundTerm[Any]) -> None: + def visit_not_null(self, term: BoundTerm) -> None: self._handle_explicit_is_null_or_not(term) self._handle_nan_unmentioned(term) - def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_not_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) - def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> None: self._handle_null_unmentioned(term) self._handle_nan_unmentioned(term) @@ -1040,10 +1040,10 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression, schema: Schema collector.collect(expr) # Convert the set of terms to a sorted list so that layout of the expression to build is deterministic. - null_unmentioned_bound_terms: list[BoundTerm[Any]] = sorted( + null_unmentioned_bound_terms: list[BoundTerm] = sorted( collector.null_unmentioned_bound_terms, key=lambda term: term.ref().field.name ) - nan_unmentioned_bound_terms: list[BoundTerm[Any]] = sorted( + nan_unmentioned_bound_terms: list[BoundTerm] = sorted( collector.nan_unmentioned_bound_terms, key=lambda term: term.ref().field.name ) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 98cfac1146..896a04527d 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -183,10 +183,10 @@ def result_type(self, source: IcebergType) -> IcebergType: ... @abstractmethod - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: ... + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: ... @abstractmethod - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: ... + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: ... @property def preserves_order(self) -> bool: @@ -269,7 +269,7 @@ def apply(self, value: S | None) -> int | None: def result_type(self, source: IcebergType) -> IcebergType: return IntegerType() - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): @@ -286,7 +286,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | # For example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected. return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): @@ -420,7 +420,7 @@ def result_type(self, source: IcebergType) -> IntegerType: @abstractmethod def transform(self, source: IcebergType) -> Callable[[Any | None], int | None]: ... - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): return _project_transform_predicate(self, name, pred) @@ -433,7 +433,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | else: return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): return _project_transform_predicate(self, name, pred) @@ -725,7 +725,7 @@ def can_transform(self, source: IcebergType) -> bool: def result_type(self, source: IcebergType) -> IcebergType: return source - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: if isinstance(pred.term, BoundTransform): return _project_transform_predicate(self, name, pred) elif isinstance(pred, BoundUnaryPredicate): @@ -737,7 +737,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | else: return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: if isinstance(pred, BoundUnaryPredicate): return pred.as_unbound(Reference(name)) elif isinstance(pred, BoundLiteralPredicate): @@ -801,7 +801,7 @@ def preserves_order(self) -> bool: def source_type(self) -> IcebergType: return self._source_type - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: field_type = pred.term.ref().field.field_type if isinstance(pred.term, BoundTransform): @@ -819,7 +819,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | return _truncate_array(name, pred, self.transform(field_type)) return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: field_type = pred.term.ref().field.field_type if isinstance(pred.term, BoundTransform): @@ -987,10 +987,10 @@ def can_transform(self, source: IcebergType) -> bool: def result_type(self, source: IcebergType) -> StringType: return StringType() - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None def __repr__(self) -> str: @@ -1015,10 +1015,10 @@ def can_transform(self, _: IcebergType) -> bool: def result_type(self, source: IcebergType) -> IcebergType: return source - def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None - def strict_project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: + def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return None def to_human_string(self, _: IcebergType, value: S | None) -> str: @@ -1038,8 +1038,8 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr def _truncate_number( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)): @@ -1060,8 +1060,8 @@ def _truncate_number( def _truncate_number_strict( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)): @@ -1086,8 +1086,8 @@ def _truncate_number_strict( def _truncate_array_strict( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)): @@ -1101,8 +1101,8 @@ def _truncate_array_strict( def _truncate_array( - name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None] -) -> UnboundPredicate[Any] | None: + name: str, pred: BoundLiteralPredicate, transform: Callable[[Any | None], Any | None] +) -> UnboundPredicate | None: boundary = pred.literal if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)): @@ -1120,15 +1120,15 @@ def _truncate_array( def _project_transform_predicate( - transform: Transform[Any, Any], partition_name: str, pred: BoundPredicate[L] -) -> UnboundPredicate[Any] | None: + transform: Transform[Any, Any], partition_name: str, pred: BoundPredicate +) -> UnboundPredicate | None: term = pred.term if isinstance(term, BoundTransform) and transform == term.transform: return _remove_transform(partition_name, pred) return None -def _remove_transform(partition_name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any]: +def _remove_transform(partition_name: str, pred: BoundPredicate) -> UnboundPredicate: if isinstance(pred, BoundUnaryPredicate): return pred.as_unbound(Reference(partition_name)) elif isinstance(pred, BoundLiteralPredicate): @@ -1139,7 +1139,7 @@ def _remove_transform(partition_name: str, pred: BoundPredicate[L]) -> UnboundPr raise ValueError(f"Cannot replace transform in unknown predicate: {pred}") -def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform: Callable[[L], L]) -> UnboundPredicate[Any]: +def _set_apply_transform(name: str, pred: BoundSetPredicate, transform: Callable[[Any], Any]) -> UnboundPredicate: literals = pred.literals if isinstance(pred, BoundSetPredicate): transformed_literals = {_transform_literal(transform, literal) for literal in literals} @@ -1148,11 +1148,11 @@ def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform: Calla raise ValueError(f"Unknown BoundSetPredicate: {pred}") -class BoundTransform(BoundTerm[L]): +class BoundTransform(BoundTerm): """A transform expression.""" - transform: Transform[L, Any] + transform: Transform[Any, Any] - def __init__(self, term: BoundTerm[L], transform: Transform[L, Any]): - self.term: BoundTerm[L] = term + def __init__(self, term: BoundTerm, transform: Transform[Any, Any]): + self.term: BoundTerm = term self.transform = transform diff --git a/tests/conftest.py b/tests/conftest.py index 947fc00a83..fcd188f6a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2558,17 +2558,17 @@ def table_v2_with_statistics(table_metadata_v2_with_statistics: dict[str, Any]) @pytest.fixture -def bound_reference_str() -> BoundReference[str]: +def bound_reference_str() -> BoundReference: return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_binary() -> BoundReference[str]: +def bound_reference_binary() -> BoundReference: return BoundReference(field=NestedField(1, "field", BinaryType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_uuid() -> BoundReference[str]: +def bound_reference_uuid() -> BoundReference: return BoundReference(field=NestedField(1, "field", UUIDType(), required=False), accessor=Accessor(position=0, inner=None)) diff --git a/tests/expressions/test_evaluator.py b/tests/expressions/test_evaluator.py index 07888dd41e..5be3e92be8 100644 --- a/tests/expressions/test_evaluator.py +++ b/tests/expressions/test_evaluator.py @@ -685,7 +685,7 @@ def data_file_nan() -> DataFile: def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None: - operators: tuple[type[LiteralPredicate[Any]], ...] = (LessThan, LessThanOrEqual) + operators: tuple[type[LiteralPredicate], ...] = (LessThan, LessThanOrEqual) for operator in operators: should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) assert not should_read, "Should not match: all nan column doesn't contain number" @@ -714,7 +714,7 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal( schema_data_file_nan: Schema, data_file_nan: DataFile ) -> None: - operators: tuple[type[LiteralPredicate[Any]], ...] = (GreaterThan, GreaterThanOrEqual) + operators: tuple[type[LiteralPredicate], ...] = (GreaterThan, GreaterThanOrEqual) for operator in operators: should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) assert not should_read, "Should not match: all nan column doesn't contain number" diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index dbee2ca045..1fbe8d7a6c 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -19,7 +19,6 @@ import pickle import uuid from decimal import Decimal -from typing import Any import pytest from typing_extensions import assert_type @@ -65,7 +64,7 @@ from pyiceberg.expressions.literals import Literal, literal from pyiceberg.expressions.visitors import _from_byte_buffer from pyiceberg.schema import Accessor, Schema -from pyiceberg.typedef import L, Record +from pyiceberg.typedef import Record from pyiceberg.types import ( DecimalType, DoubleType, @@ -502,7 +501,7 @@ def test_less_than_or_equal_invert() -> None: LessThanOrEqual(Reference("foo"), "hello"), ], ) -def test_bind(pred: UnboundPredicate[Any], table_schema_simple: Schema) -> None: +def test_bind(pred: UnboundPredicate, table_schema_simple: Schema) -> None: assert pred.bind(table_schema_simple, case_sensitive=True).term.field == table_schema_simple.find_field( # type: ignore pred.term.name, # type: ignore case_sensitive=True, @@ -522,7 +521,7 @@ def test_bind(pred: UnboundPredicate[Any], table_schema_simple: Schema) -> None: LessThanOrEqual(Reference("Bar"), 5), ], ) -def test_bind_case_insensitive(pred: UnboundPredicate[Any], table_schema_simple: Schema) -> None: +def test_bind_case_insensitive(pred: UnboundPredicate, table_schema_simple: Schema) -> None: assert pred.bind(table_schema_simple, case_sensitive=False).term.field == table_schema_simple.find_field( # type: ignore pred.term.name, # type: ignore case_sensitive=False, @@ -683,7 +682,7 @@ def accessor() -> Accessor: @pytest.fixture -def term(field: NestedField, accessor: Accessor) -> BoundReference[Any]: +def term(field: NestedField, accessor: Accessor) -> BoundReference: return BoundReference( field=field, accessor=accessor, @@ -794,14 +793,14 @@ def test_bound_reference_field_property() -> None: assert bound_ref.field == NestedField(field_id=1, name="foo", field_type=StringType(), required=False) -def test_bound_is_null(term: BoundReference[Any]) -> None: +def test_bound_is_null(term: BoundReference) -> None: bound_is_null = BoundIsNull(term) assert str(bound_is_null) == f"BoundIsNull(term={str(term)})" assert repr(bound_is_null) == f"BoundIsNull(term={repr(term)})" assert bound_is_null == eval(repr(bound_is_null)) -def test_bound_is_not_null(term: BoundReference[Any]) -> None: +def test_bound_is_not_null(term: BoundReference) -> None: bound_not_null = BoundNotNull(term) assert str(bound_not_null) == f"BoundNotNull(term={str(term)})" assert repr(bound_not_null) == f"BoundNotNull(term={repr(term)})" @@ -838,7 +837,7 @@ def test_serialize_not_null() -> None: def test_bound_is_nan(accessor: Accessor) -> None: # We need a FloatType here - term = BoundReference[float]( + term = BoundReference( field=NestedField(field_id=1, name="foo", field_type=FloatType(), required=False), accessor=accessor, ) @@ -851,7 +850,7 @@ def test_bound_is_nan(accessor: Accessor) -> None: def test_bound_is_not_nan(accessor: Accessor) -> None: # We need a FloatType here - term = BoundReference[float]( + term = BoundReference( field=NestedField(field_id=1, name="foo", field_type=FloatType(), required=False), accessor=accessor, ) @@ -880,7 +879,7 @@ def test_not_nan() -> None: assert not_nan == pickle.loads(pickle.dumps(not_nan)) -def test_bound_in(term: BoundReference[Any]) -> None: +def test_bound_in(term: BoundReference) -> None: bound_in = BoundIn(term, {literal("a"), literal("b"), literal("c")}) assert str(bound_in) == f"BoundIn({str(term)}, {{a, b, c}})" assert repr(bound_in) == f"BoundIn({repr(term)}, {{literal('a'), literal('b'), literal('c')}})" @@ -888,7 +887,7 @@ def test_bound_in(term: BoundReference[Any]) -> None: assert bound_in == pickle.loads(pickle.dumps(bound_in)) -def test_bound_not_in(term: BoundReference[Any]) -> None: +def test_bound_not_in(term: BoundReference) -> None: bound_not_in = BoundNotIn(term, {literal("a"), literal("b"), literal("c")}) assert str(bound_not_in) == f"BoundNotIn({str(term)}, {{a, b, c}})" assert repr(bound_not_in) == f"BoundNotIn({repr(term)}, {{literal('a'), literal('b'), literal('c')}})" @@ -924,7 +923,7 @@ def test_serialize_not_in() -> None: assert pred.model_dump_json() == '{"term":"foo","type":"not-in","items":[1,2,3]}' -def test_bound_equal_to(term: BoundReference[Any]) -> None: +def test_bound_equal_to(term: BoundReference) -> None: bound_equal_to = BoundEqualTo(term, literal("a")) assert str(bound_equal_to) == f"BoundEqualTo(term={str(term)}, literal=literal('a'))" assert repr(bound_equal_to) == f"BoundEqualTo(term={repr(term)}, literal=literal('a'))" @@ -932,7 +931,7 @@ def test_bound_equal_to(term: BoundReference[Any]) -> None: assert bound_equal_to == pickle.loads(pickle.dumps(bound_equal_to)) -def test_bound_not_equal_to(term: BoundReference[Any]) -> None: +def test_bound_not_equal_to(term: BoundReference) -> None: bound_not_equal_to = BoundNotEqualTo(term, literal("a")) assert str(bound_not_equal_to) == f"BoundNotEqualTo(term={str(term)}, literal=literal('a'))" assert repr(bound_not_equal_to) == f"BoundNotEqualTo(term={repr(term)}, literal=literal('a'))" @@ -940,7 +939,7 @@ def test_bound_not_equal_to(term: BoundReference[Any]) -> None: assert bound_not_equal_to == pickle.loads(pickle.dumps(bound_not_equal_to)) -def test_bound_greater_than_or_equal_to(term: BoundReference[Any]) -> None: +def test_bound_greater_than_or_equal_to(term: BoundReference) -> None: bound_greater_than_or_equal_to = BoundGreaterThanOrEqual(term, literal("a")) assert str(bound_greater_than_or_equal_to) == f"BoundGreaterThanOrEqual(term={str(term)}, literal=literal('a'))" assert repr(bound_greater_than_or_equal_to) == f"BoundGreaterThanOrEqual(term={repr(term)}, literal=literal('a'))" @@ -948,7 +947,7 @@ def test_bound_greater_than_or_equal_to(term: BoundReference[Any]) -> None: assert bound_greater_than_or_equal_to == pickle.loads(pickle.dumps(bound_greater_than_or_equal_to)) -def test_bound_greater_than(term: BoundReference[Any]) -> None: +def test_bound_greater_than(term: BoundReference) -> None: bound_greater_than = BoundGreaterThan(term, literal("a")) assert str(bound_greater_than) == f"BoundGreaterThan(term={str(term)}, literal=literal('a'))" assert repr(bound_greater_than) == f"BoundGreaterThan(term={repr(term)}, literal=literal('a'))" @@ -956,7 +955,7 @@ def test_bound_greater_than(term: BoundReference[Any]) -> None: assert bound_greater_than == pickle.loads(pickle.dumps(bound_greater_than)) -def test_bound_less_than(term: BoundReference[Any]) -> None: +def test_bound_less_than(term: BoundReference) -> None: bound_less_than = BoundLessThan(term, literal("a")) assert str(bound_less_than) == f"BoundLessThan(term={str(term)}, literal=literal('a'))" assert repr(bound_less_than) == f"BoundLessThan(term={repr(term)}, literal=literal('a'))" @@ -964,7 +963,7 @@ def test_bound_less_than(term: BoundReference[Any]) -> None: assert bound_less_than == pickle.loads(pickle.dumps(bound_less_than)) -def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None: +def test_bound_less_than_or_equal(term: BoundReference) -> None: bound_less_than_or_equal = BoundLessThanOrEqual(term, literal("a")) assert str(bound_less_than_or_equal) == f"BoundLessThanOrEqual(term={str(term)}, literal=literal('a'))" assert repr(bound_less_than_or_equal) == f"BoundLessThanOrEqual(term={repr(term)}, literal=literal('a'))" @@ -1092,37 +1091,37 @@ def below_int_min() -> Literal[int]: def test_above_int_bounds_equal_to(int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int]) -> None: - assert EqualTo[int]("a", above_int_max).bind(int_schema) is AlwaysFalse() - assert EqualTo[int]("a", below_int_min).bind(int_schema) is AlwaysFalse() + assert EqualTo("a", above_int_max).bind(int_schema) is AlwaysFalse() + assert EqualTo("a", below_int_min).bind(int_schema) is AlwaysFalse() def test_above_int_bounds_not_equal_to(int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int]) -> None: - assert NotEqualTo[int]("a", above_int_max).bind(int_schema) is AlwaysTrue() - assert NotEqualTo[int]("a", below_int_min).bind(int_schema) is AlwaysTrue() + assert NotEqualTo("a", above_int_max).bind(int_schema) is AlwaysTrue() + assert NotEqualTo("a", below_int_min).bind(int_schema) is AlwaysTrue() def test_above_int_bounds_less_than(int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int]) -> None: - assert LessThan[int]("a", above_int_max).bind(int_schema) is AlwaysTrue() - assert LessThan[int]("a", below_int_min).bind(int_schema) is AlwaysFalse() + assert LessThan("a", above_int_max).bind(int_schema) is AlwaysTrue() + assert LessThan("a", below_int_min).bind(int_schema) is AlwaysFalse() def test_above_int_bounds_less_than_or_equal( int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int] ) -> None: - assert LessThanOrEqual[int]("a", above_int_max).bind(int_schema) is AlwaysTrue() - assert LessThanOrEqual[int]("a", below_int_min).bind(int_schema) is AlwaysFalse() + assert LessThanOrEqual("a", above_int_max).bind(int_schema) is AlwaysTrue() + assert LessThanOrEqual("a", below_int_min).bind(int_schema) is AlwaysFalse() def test_above_int_bounds_greater_than(int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int]) -> None: - assert GreaterThan[int]("a", above_int_max).bind(int_schema) is AlwaysFalse() - assert GreaterThan[int]("a", below_int_min).bind(int_schema) is AlwaysTrue() + assert GreaterThan("a", above_int_max).bind(int_schema) is AlwaysFalse() + assert GreaterThan("a", below_int_min).bind(int_schema) is AlwaysTrue() def test_above_int_bounds_greater_than_or_equal( int_schema: Schema, above_int_max: Literal[int], below_int_min: Literal[int] ) -> None: - assert GreaterThanOrEqual[int]("a", above_int_max).bind(int_schema) is AlwaysFalse() - assert GreaterThanOrEqual[int]("a", below_int_min).bind(int_schema) is AlwaysTrue() + assert GreaterThanOrEqual("a", above_int_max).bind(int_schema) is AlwaysFalse() + assert GreaterThanOrEqual("a", below_int_min).bind(int_schema) is AlwaysTrue() @pytest.fixture @@ -1143,43 +1142,43 @@ def below_float_min() -> Literal[float]: def test_above_float_bounds_equal_to( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert EqualTo[float]("a", above_float_max).bind(float_schema) is AlwaysFalse() - assert EqualTo[float]("a", below_float_min).bind(float_schema) is AlwaysFalse() + assert EqualTo("a", above_float_max).bind(float_schema) is AlwaysFalse() + assert EqualTo("a", below_float_min).bind(float_schema) is AlwaysFalse() def test_above_float_bounds_not_equal_to( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert NotEqualTo[float]("a", above_float_max).bind(float_schema) is AlwaysTrue() - assert NotEqualTo[float]("a", below_float_min).bind(float_schema) is AlwaysTrue() + assert NotEqualTo("a", above_float_max).bind(float_schema) is AlwaysTrue() + assert NotEqualTo("a", below_float_min).bind(float_schema) is AlwaysTrue() def test_above_float_bounds_less_than( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert LessThan[float]("a", above_float_max).bind(float_schema) is AlwaysTrue() - assert LessThan[float]("a", below_float_min).bind(float_schema) is AlwaysFalse() + assert LessThan("a", above_float_max).bind(float_schema) is AlwaysTrue() + assert LessThan("a", below_float_min).bind(float_schema) is AlwaysFalse() def test_above_float_bounds_less_than_or_equal( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert LessThanOrEqual[float]("a", above_float_max).bind(float_schema) is AlwaysTrue() - assert LessThanOrEqual[float]("a", below_float_min).bind(float_schema) is AlwaysFalse() + assert LessThanOrEqual("a", above_float_max).bind(float_schema) is AlwaysTrue() + assert LessThanOrEqual("a", below_float_min).bind(float_schema) is AlwaysFalse() def test_above_float_bounds_greater_than( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert GreaterThan[float]("a", above_float_max).bind(float_schema) is AlwaysFalse() - assert GreaterThan[float]("a", below_float_min).bind(float_schema) is AlwaysTrue() + assert GreaterThan("a", above_float_max).bind(float_schema) is AlwaysFalse() + assert GreaterThan("a", below_float_min).bind(float_schema) is AlwaysTrue() def test_above_float_bounds_greater_than_or_equal( float_schema: Schema, above_float_max: Literal[float], below_float_min: Literal[float] ) -> None: - assert GreaterThanOrEqual[float]("a", above_float_max).bind(float_schema) is AlwaysFalse() - assert GreaterThanOrEqual[float]("a", below_float_min).bind(float_schema) is AlwaysTrue() + assert GreaterThanOrEqual("a", above_float_max).bind(float_schema) is AlwaysFalse() + assert GreaterThanOrEqual("a", below_float_min).bind(float_schema) is AlwaysTrue() @pytest.fixture @@ -1198,40 +1197,40 @@ def below_long_min() -> Literal[float]: def test_above_long_bounds_equal_to(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert EqualTo[int]("a", above_long_max).bind(long_schema) is AlwaysFalse() - assert EqualTo[int]("a", below_long_min).bind(long_schema) is AlwaysFalse() + assert EqualTo("a", above_long_max).bind(long_schema) is AlwaysFalse() + assert EqualTo("a", below_long_min).bind(long_schema) is AlwaysFalse() def test_above_long_bounds_not_equal_to(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert NotEqualTo[int]("a", above_long_max).bind(long_schema) is AlwaysTrue() - assert NotEqualTo[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() + assert NotEqualTo("a", above_long_max).bind(long_schema) is AlwaysTrue() + assert NotEqualTo("a", below_long_min).bind(long_schema) is AlwaysTrue() def test_above_long_bounds_less_than(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert LessThan[int]("a", above_long_max).bind(long_schema) is AlwaysTrue() - assert LessThan[int]("a", below_long_min).bind(long_schema) is AlwaysFalse() + assert LessThan("a", above_long_max).bind(long_schema) is AlwaysTrue() + assert LessThan("a", below_long_min).bind(long_schema) is AlwaysFalse() def test_above_long_bounds_less_than_or_equal( long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int] ) -> None: - assert LessThanOrEqual[int]("a", above_long_max).bind(long_schema) is AlwaysTrue() - assert LessThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysFalse() + assert LessThanOrEqual("a", above_long_max).bind(long_schema) is AlwaysTrue() + assert LessThanOrEqual("a", below_long_min).bind(long_schema) is AlwaysFalse() def test_above_long_bounds_greater_than(long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int]) -> None: - assert GreaterThan[int]("a", above_long_max).bind(long_schema) is AlwaysFalse() - assert GreaterThan[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() + assert GreaterThan("a", above_long_max).bind(long_schema) is AlwaysFalse() + assert GreaterThan("a", below_long_min).bind(long_schema) is AlwaysTrue() def test_above_long_bounds_greater_than_or_equal( long_schema: Schema, above_long_max: Literal[int], below_long_min: Literal[int] ) -> None: - assert GreaterThanOrEqual[int]("a", above_long_max).bind(long_schema) is AlwaysFalse() - assert GreaterThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() + assert GreaterThanOrEqual("a", above_long_max).bind(long_schema) is AlwaysFalse() + assert GreaterThanOrEqual("a", below_long_min).bind(long_schema) is AlwaysTrue() -def test_eq_bound_expression(bound_reference_str: BoundReference[str]) -> None: +def test_eq_bound_expression(bound_reference_str: BoundReference) -> None: assert BoundEqualTo(term=bound_reference_str, literal=literal("a")) != BoundGreaterThanOrEqual( term=bound_reference_str, literal=literal("a") ) @@ -1275,14 +1274,14 @@ def test_bind_ambiguous_name() -> None: # |__/ |__/ -def _assert_literal_predicate_type(expr: LiteralPredicate[L]) -> None: - assert_type(expr, LiteralPredicate[L]) +def _assert_literal_predicate_type(expr: LiteralPredicate) -> None: + assert_type(expr, LiteralPredicate) _assert_literal_predicate_type(EqualTo("a", "b")) _assert_literal_predicate_type(In("a", ("a", "b", "c"))) _assert_literal_predicate_type(In("a", (1, 2, 3))) _assert_literal_predicate_type(NotIn("a", ("a", "b", "c"))) -assert_type(In("a", ("a", "b", "c")), In[str]) -assert_type(In("a", (1, 2, 3)), In[int]) -assert_type(NotIn("a", ("a", "b", "c")), NotIn[str]) +assert_type(In("a", ("a", "b", "c")), In) +assert_type(In("a", (1, 2, 3)), In) +assert_type(NotIn("a", ("a", "b", "c")), NotIn) diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py index 2847859db5..798e9f641e 100644 --- a/tests/expressions/test_visitors.py +++ b/tests/expressions/test_visitors.py @@ -121,11 +121,11 @@ def visit_or(self, left_result: list[str], right_result: list[str]) -> list[str] self.visit_history.append("OR") return self.visit_history - def visit_unbound_predicate(self, predicate: UnboundPredicate[Any]) -> list[str]: + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> list[str]: self.visit_history.append(str(predicate.__class__.__name__).upper()) return self.visit_history - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> list[str]: + def visit_bound_predicate(self, predicate: BoundPredicate) -> list[str]: self.visit_history.append(str(predicate.__class__.__name__).upper()) return self.visit_history @@ -139,51 +139,51 @@ class FooBoundBooleanExpressionVisitor(BoundBooleanExpressionVisitor[list[str]]) def __init__(self) -> None: self.visit_history: list[str] = [] - def visit_in(self, term: BoundTerm[Any], literals: set[Any]) -> list[str]: + def visit_in(self, term: BoundTerm, literals: set[Any]) -> list[str]: self.visit_history.append("IN") return self.visit_history - def visit_not_in(self, term: BoundTerm[Any], literals: set[Any]) -> list[str]: + def visit_not_in(self, term: BoundTerm, literals: set[Any]) -> list[str]: self.visit_history.append("NOT_IN") return self.visit_history - def visit_is_nan(self, term: BoundTerm[Any]) -> list[str]: + def visit_is_nan(self, term: BoundTerm) -> list[str]: self.visit_history.append("IS_NAN") return self.visit_history - def visit_not_nan(self, term: BoundTerm[Any]) -> list[str]: + def visit_not_nan(self, term: BoundTerm) -> list[str]: self.visit_history.append("NOT_NAN") return self.visit_history - def visit_is_null(self, term: BoundTerm[Any]) -> list[str]: + def visit_is_null(self, term: BoundTerm) -> list[str]: self.visit_history.append("IS_NULL") return self.visit_history - def visit_not_null(self, term: BoundTerm[Any]) -> list[str]: + def visit_not_null(self, term: BoundTerm) -> list[str]: self.visit_history.append("NOT_NULL") return self.visit_history - def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("EQUAL") return self.visit_history - def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_not_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("NOT_EQUAL") return self.visit_history - def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_greater_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("GREATER_THAN_OR_EQUAL") return self.visit_history - def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_greater_than(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("GREATER_THAN") return self.visit_history - def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_less_than(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("LESS_THAN") return self.visit_history - def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name + def visit_less_than_or_equal(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: # pylint: disable=redefined-outer-name self.visit_history.append("LESS_THAN_OR_EQUAL") return self.visit_history @@ -207,11 +207,11 @@ def visit_or(self, left_result: list[str], right_result: list[str]) -> list[str] self.visit_history.append("OR") return self.visit_history - def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: + def visit_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: self.visit_history.append("STARTS_WITH") return self.visit_history - def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> list[str]: + def visit_not_starts_with(self, term: BoundTerm, literal: Literal[Any]) -> list[str]: self.visit_history.append("NOT_STARTS_WITH") return self.visit_history @@ -253,7 +253,7 @@ def test_boolean_expression_visit_raise_not_implemented_error() -> None: def test_bind_visitor_already_bound(table_schema_simple: Schema) -> None: - bound = BoundEqualTo[str]( + bound = BoundEqualTo( term=BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)), literal=literal("hello"), ) @@ -315,7 +315,7 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch ), {literal("foo"), literal("bar")}, ), - BoundIn[int]( + BoundIn( BoundReference( field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), accessor=Accessor(position=1, inner=None), @@ -345,7 +345,7 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch {literal("bar"), literal("baz")}, ), And( - BoundEqualTo[int]( + BoundEqualTo( BoundReference( field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), accessor=Accessor(position=1, inner=None), @@ -365,7 +365,7 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch ], ) def test_and_expression_binding( - unbound_and_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_and_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, table_schema_simple: Schema ) -> None: """Test that visiting an unbound AND expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_and_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -388,7 +388,7 @@ def test_and_expression_binding( ), {literal("foo"), literal("bar")}, ), - BoundIn[int]( + BoundIn( BoundReference( field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), accessor=Accessor(position=1, inner=None), @@ -459,7 +459,7 @@ def test_and_expression_binding( ], ) def test_or_expression_binding( - unbound_or_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_or_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, table_schema_simple: Schema ) -> None: """Test that visiting an unbound OR expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_or_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -505,7 +505,7 @@ def test_or_expression_binding( ], ) def test_in_expression_binding( - unbound_in_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_in_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, table_schema_simple: Schema ) -> None: """Test that visiting an unbound IN expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_in_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -556,7 +556,7 @@ def test_in_expression_binding( ], ) def test_not_expression_binding( - unbound_not_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema + unbound_not_expression: UnboundPredicate, expected_bound_expression: BoundPredicate, table_schema_simple: Schema ) -> None: """Test that visiting an unbound NOT expression with a bind-visitor returns the expected bound expression""" bound_expression = visit(unbound_not_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) @@ -1590,16 +1590,16 @@ def test_to_dnf_not_and() -> None: def test_dnf_to_dask(table_schema_simple: Schema) -> None: expr = ( - BoundGreaterThan[str]( + BoundGreaterThan( term=BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)), literal=literal("hello"), ), And( - BoundIn[int]( + BoundIn( term=BoundReference(table_schema_simple.find_field(2), table_schema_simple.accessor_for_field(2)), literals={literal(1), literal(2), literal(3)}, ), - BoundEqualTo[bool]( + BoundEqualTo( term=BoundReference(table_schema_simple.find_field(3), table_schema_simple.accessor_for_field(3)), literal=literal(True), ), diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 5758dbe4e5..3977dc2143 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -633,12 +633,12 @@ def test_list_type_to_pyarrow() -> None: @pytest.fixture -def bound_reference(table_schema_simple: Schema) -> BoundReference[str]: +def bound_reference(table_schema_simple: Schema) -> BoundReference: return BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)) @pytest.fixture -def bound_double_reference() -> BoundReference[float]: +def bound_double_reference() -> BoundReference: schema = Schema( NestedField(field_id=1, name="foo", field_type=DoubleType(), required=False), schema_id=1, @@ -647,68 +647,68 @@ def bound_double_reference() -> BoundReference[float]: return BoundReference(schema.find_field(1), schema.accessor_for_field(1)) -def test_expr_is_null_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_is_null_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundIsNull(bound_reference))) == "" ) -def test_expr_not_null_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_null_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundNotNull(bound_reference))) == "" -def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference[str]) -> None: +def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundIsNaN(bound_double_reference))) == "" -def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference[str]) -> None: +def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundNotNaN(bound_double_reference))) == "" -def test_expr_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundEqualTo(bound_reference, literal("hello")))) == '' ) -def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundNotEqualTo(bound_reference, literal("hello")))) == '' ) -def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundGreaterThanOrEqual(bound_reference, literal("hello")))) == '= "hello")>' ) -def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundGreaterThan(bound_reference, literal("hello")))) == ' "hello")>' ) -def test_expr_less_than_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_less_than_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundLessThan(bound_reference, literal("hello")))) == '' ) -def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundLessThanOrEqual(bound_reference, literal("hello")))) == '' ) -def test_expr_in_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_in_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundIn(bound_reference, {literal("hello"), literal("world")}))) in ( """ None: ) -def test_expr_not_in_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_in_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(BoundNotIn(bound_reference, {literal("hello"), literal("world")}))) in ( """ None: ) -def test_expr_starts_with_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_starts_with_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundStartsWith(bound_reference, literal("he")))) == '' ) -def test_expr_not_starts_with_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_expr_not_starts_with_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(BoundNotStartsWith(bound_reference, literal("he")))) == '' ) -def test_and_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_and_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(And(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) == '' ) -def test_or_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_or_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(Or(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) == '' ) -def test_not_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_not_to_pyarrow(bound_reference: BoundReference) -> None: assert ( repr(expression_to_pyarrow(Not(BoundEqualTo(bound_reference, literal("hello"))))) == '' ) -def test_always_true_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_always_true_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(AlwaysTrue())) == "" -def test_always_false_to_pyarrow(bound_reference: BoundReference[str]) -> None: +def test_always_false_to_pyarrow(bound_reference: BoundReference) -> None: assert repr(expression_to_pyarrow(AlwaysFalse())) == "" diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 9d5772d01c..59a4857699 100644 --- a/tests/io/test_pyarrow_visitor.py +++ b/tests/io/test_pyarrow_visitor.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=protected-access,unused-argument,redefined-outer-name import re -from typing import Any import pyarrow as pa import pytest @@ -717,21 +716,21 @@ def test_pyarrow_schema_round_trip_ensure_large_types_and_then_small_types(pyarr @pytest.fixture -def bound_reference_str() -> BoundReference[Any]: +def bound_reference_str() -> BoundReference: return BoundReference( field=NestedField(1, "string_field", StringType(), required=False), accessor=Accessor(position=0, inner=None) ) @pytest.fixture -def bound_reference_float() -> BoundReference[Any]: +def bound_reference_float() -> BoundReference: return BoundReference( field=NestedField(2, "float_field", FloatType(), required=False), accessor=Accessor(position=1, inner=None) ) @pytest.fixture -def bound_reference_double() -> BoundReference[Any]: +def bound_reference_double() -> BoundReference: return BoundReference( field=NestedField(3, "double_field", DoubleType(), required=False), accessor=Accessor(position=2, inner=None), @@ -739,32 +738,32 @@ def bound_reference_double() -> BoundReference[Any]: @pytest.fixture -def bound_eq_str_field(bound_reference_str: BoundReference[Any]) -> BoundEqualTo[Any]: +def bound_eq_str_field(bound_reference_str: BoundReference) -> BoundEqualTo: return BoundEqualTo(term=bound_reference_str, literal=literal("hello")) @pytest.fixture -def bound_greater_than_float_field(bound_reference_float: BoundReference[Any]) -> BoundGreaterThan[Any]: +def bound_greater_than_float_field(bound_reference_float: BoundReference) -> BoundGreaterThan: return BoundGreaterThan(term=bound_reference_float, literal=literal(100)) @pytest.fixture -def bound_is_nan_float_field(bound_reference_float: BoundReference[Any]) -> BoundIsNaN[Any]: +def bound_is_nan_float_field(bound_reference_float: BoundReference) -> BoundIsNaN: return BoundIsNaN(bound_reference_float) @pytest.fixture -def bound_eq_double_field(bound_reference_double: BoundReference[Any]) -> BoundEqualTo[Any]: +def bound_eq_double_field(bound_reference_double: BoundReference) -> BoundEqualTo: return BoundEqualTo(term=bound_reference_double, literal=literal(False)) @pytest.fixture -def bound_is_null_double_field(bound_reference_double: BoundReference[Any]) -> BoundIsNull[Any]: +def bound_is_null_double_field(bound_reference_double: BoundReference) -> BoundIsNull: return BoundIsNull(bound_reference_double) def test_collect_null_nan_unmentioned_terms( - bound_eq_str_field: BoundEqualTo[Any], bound_is_nan_float_field: BoundIsNaN[Any], bound_is_null_double_field: BoundIsNull[Any] + bound_eq_str_field: BoundEqualTo, bound_is_nan_float_field: BoundIsNaN, bound_is_null_double_field: BoundIsNull ) -> None: bound_expr = And( Or(And(bound_eq_str_field, bound_is_nan_float_field), bound_is_null_double_field), Not(bound_is_nan_float_field) @@ -786,11 +785,11 @@ def test_collect_null_nan_unmentioned_terms( def test_collect_null_nan_unmentioned_terms_with_multiple_predicates_on_the_same_term( - bound_eq_str_field: BoundEqualTo[Any], - bound_greater_than_float_field: BoundGreaterThan[Any], - bound_is_nan_float_field: BoundIsNaN[Any], - bound_eq_double_field: BoundEqualTo[Any], - bound_is_null_double_field: BoundIsNull[Any], + bound_eq_str_field: BoundEqualTo, + bound_greater_than_float_field: BoundGreaterThan, + bound_is_nan_float_field: BoundIsNaN, + bound_eq_double_field: BoundEqualTo, + bound_is_null_double_field: BoundIsNull, ) -> None: """Test a single term appears multiple places in the expression tree""" bound_expr = And( @@ -818,11 +817,11 @@ def test_collect_null_nan_unmentioned_terms_with_multiple_predicates_on_the_same def test_expression_to_complementary_pyarrow( - bound_eq_str_field: BoundEqualTo[Any], - bound_greater_than_float_field: BoundGreaterThan[Any], - bound_is_nan_float_field: BoundIsNaN[Any], - bound_eq_double_field: BoundEqualTo[Any], - bound_is_null_double_field: BoundIsNull[Any], + bound_eq_str_field: BoundEqualTo, + bound_greater_than_float_field: BoundGreaterThan, + bound_is_nan_float_field: BoundIsNaN, + bound_eq_double_field: BoundEqualTo, + bound_is_null_double_field: BoundIsNull, ) -> None: bound_expr = And( Or( diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3d9bfcb555..88f53c51b0 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -91,7 +91,7 @@ YearTransform, parse_transform, ) -from pyiceberg.typedef import UTF8, L +from pyiceberg.typedef import UTF8 from pyiceberg.types import ( BinaryType, BooleanType, @@ -650,146 +650,146 @@ def test_datetime_transform_repr(transform: TimeTransform[Any], transform_repr: @pytest.fixture -def bound_reference_date() -> BoundReference[int]: +def bound_reference_date() -> BoundReference: return BoundReference(field=NestedField(1, "field", DateType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_timestamp() -> BoundReference[int]: +def bound_reference_timestamp() -> BoundReference: return BoundReference( field=NestedField(1, "field", TimestampType(), required=False), accessor=Accessor(position=0, inner=None) ) @pytest.fixture -def bound_reference_decimal() -> BoundReference[Decimal]: +def bound_reference_decimal() -> BoundReference: return BoundReference( field=NestedField(1, "field", DecimalType(8, 2), required=False), accessor=Accessor(position=0, inner=None) ) @pytest.fixture -def bound_reference_int() -> BoundReference[int]: +def bound_reference_int() -> BoundReference: return BoundReference(field=NestedField(1, "field", IntegerType(), required=False), accessor=Accessor(position=0, inner=None)) @pytest.fixture -def bound_reference_long() -> BoundReference[int]: +def bound_reference_long() -> BoundReference: return BoundReference(field=NestedField(1, "field", LongType(), required=False), accessor=Accessor(position=0, inner=None)) -def test_projection_bucket_unary(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_unary(bound_reference_str: BoundReference) -> None: assert BucketTransform(2).project("name", BoundNotNull(term=bound_reference_str)) == NotNull(term=Reference(name="name")) -def test_projection_bucket_literal(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_literal(bound_reference_str: BoundReference) -> None: assert BucketTransform(2).project("name", BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo( term="name", literal=1 ) -def test_projection_bucket_set_same_bucket(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_set_same_bucket(bound_reference_str: BoundReference) -> None: assert BucketTransform(2).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("world")}) ) == EqualTo(term="name", literal=1) -def test_projection_bucket_set_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_set_in(bound_reference_str: BoundReference) -> None: assert BucketTransform(3).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("world")}) ) == In(term="name", literals={1, 2}) -def test_projection_bucket_set_not_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_bucket_set_not_in(bound_reference_str: BoundReference) -> None: assert ( BucketTransform(3).project("name", BoundNotIn(term=bound_reference_str, literals={literal("hello"), literal("world")})) is None ) -def test_projection_year_unary(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_unary(bound_reference_date: BoundReference) -> None: assert YearTransform().project("name", BoundNotNull(term=bound_reference_date)) == NotNull(term="name") -def test_projection_year_literal(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_literal(bound_reference_date: BoundReference) -> None: assert YearTransform().project("name", BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo( term="name", literal=5 ) -def test_projection_year_set_same_year(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_set_same_year(bound_reference_date: BoundReference) -> None: assert YearTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(1926)}) ) == EqualTo(term="name", literal=5) -def test_projection_year_set_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_set_in(bound_reference_date: BoundReference) -> None: assert YearTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)}) ) == In(term="name", literals={8, 5}) -def test_projection_year_set_not_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_year_set_not_in(bound_reference_date: BoundReference) -> None: assert ( YearTransform().project("name", BoundNotIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)})) is None ) -def test_projection_month_unary(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_unary(bound_reference_date: BoundReference) -> None: assert MonthTransform().project("name", BoundNotNull(term=bound_reference_date)) == NotNull(term="name") -def test_projection_month_literal(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_literal(bound_reference_date: BoundReference) -> None: assert MonthTransform().project("name", BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo( term="name", literal=63 ) -def test_projection_month_set_same_month(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_set_same_month(bound_reference_date: BoundReference) -> None: assert MonthTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(1926)}) ) == EqualTo(term="name", literal=63) -def test_projection_month_set_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_month_set_in(bound_reference_date: BoundReference) -> None: assert MonthTransform().project( "name", BoundIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)}) ) == In(term="name", literals={96, 63}) -def test_projection_day_month_not_in(bound_reference_date: BoundReference[int]) -> None: +def test_projection_day_month_not_in(bound_reference_date: BoundReference) -> None: assert ( MonthTransform().project("name", BoundNotIn(term=bound_reference_date, literals={DateLiteral(1925), DateLiteral(2925)})) is None ) -def test_projection_day_unary(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_unary(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project("name", BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name") -def test_projection_day_literal(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_literal(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project( "name", BoundEqualTo(term=bound_reference_timestamp, literal=TimestampLiteral(1667696874000)) ) == EqualTo(term="name", literal=19) -def test_projection_day_set_same_day(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_set_same_day(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project( "name", BoundIn(term=bound_reference_timestamp, literals={TimestampLiteral(1667696874001), TimestampLiteral(1667696874000)}), ) == EqualTo(term="name", literal=19) -def test_projection_day_set_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_set_in(bound_reference_timestamp: BoundReference) -> None: assert DayTransform().project( "name", BoundIn(term=bound_reference_timestamp, literals={TimestampLiteral(1667696874001), TimestampLiteral(1567696874000)}), ) == In(term="name", literals={18, 19}) -def test_projection_day_set_not_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_day_set_not_in(bound_reference_timestamp: BoundReference) -> None: assert ( DayTransform().project( "name", @@ -799,7 +799,7 @@ def test_projection_day_set_not_in(bound_reference_timestamp: BoundReference[int ) -def test_projection_day_human(bound_reference_date: BoundReference[int]) -> None: +def test_projection_day_human(bound_reference_date: BoundReference) -> None: date_literal = DateLiteral(17532) assert DayTransform().project("dt", BoundEqualTo(term=bound_reference_date, literal=date_literal)) == EqualTo( term="dt", literal=17532 @@ -822,7 +822,7 @@ def test_projection_day_human(bound_reference_date: BoundReference[int]) -> None ) # >= 2018, 1, 2 -def test_projection_hour_unary(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_unary(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project("name", BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name") @@ -830,13 +830,13 @@ def test_projection_hour_unary(bound_reference_timestamp: BoundReference[int]) - HOUR_IN_MICROSECONDS = 60 * 60 * 1000 * 1000 -def test_projection_hour_literal(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_literal(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project( "name", BoundEqualTo(term=bound_reference_timestamp, literal=TimestampLiteral(TIMESTAMP_EXAMPLE)) ) == EqualTo(term="name", literal=463249) -def test_projection_hour_set_same_hour(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_set_same_hour(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project( "name", BoundIn( @@ -846,7 +846,7 @@ def test_projection_hour_set_same_hour(bound_reference_timestamp: BoundReference ) == EqualTo(term="name", literal=463249) -def test_projection_hour_set_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_set_in(bound_reference_timestamp: BoundReference) -> None: assert HourTransform().project( "name", BoundIn( @@ -856,7 +856,7 @@ def test_projection_hour_set_in(bound_reference_timestamp: BoundReference[int]) ) == In(term="name", literals={463249, 463250}) -def test_projection_hour_set_not_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_hour_set_not_in(bound_reference_timestamp: BoundReference) -> None: assert ( HourTransform().project( "name", @@ -869,17 +869,17 @@ def test_projection_hour_set_not_in(bound_reference_timestamp: BoundReference[in ) -def test_projection_identity_unary(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_unary(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project("name", BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name") -def test_projection_identity_literal(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_literal(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project( "name", BoundEqualTo(term=bound_reference_timestamp, literal=TimestampLiteral(TIMESTAMP_EXAMPLE)) ) == EqualTo(term="name", literal=TimestampLiteral(TIMESTAMP_EXAMPLE)) -def test_projection_identity_set_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_set_in(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project( "name", BoundIn( @@ -892,7 +892,7 @@ def test_projection_identity_set_in(bound_reference_timestamp: BoundReference[in ) -def test_projection_identity_set_not_in(bound_reference_timestamp: BoundReference[int]) -> None: +def test_projection_identity_set_not_in(bound_reference_timestamp: BoundReference) -> None: assert IdentityTransform().project( "name", BoundNotIn( @@ -905,78 +905,78 @@ def test_projection_identity_set_not_in(bound_reference_timestamp: BoundReferenc ) -def test_projection_truncate_string_unary(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_unary(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project("name", BoundNotNull(term=bound_reference_str)) == NotNull(term="name") -def test_projection_truncate_string_literal_eq(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_eq(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project("name", BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo( term="name", literal=literal("da") ) -def test_projection_truncate_string_literal_gt(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_gt(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThan(term=bound_reference_str, literal=literal("data")) ) == GreaterThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_literal_gte(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_gte(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThanOrEqual(term=bound_reference_str, literal=literal("data")) ) == GreaterThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_literal_lt(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_lt(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundLessThan(term=bound_reference_str, literal=literal("data")) ) == LessThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_literal_lte(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_literal_lte(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundLessThanOrEqual(term=bound_reference_str, literal=literal("data")) ) == LessThanOrEqual(term="name", literal=literal("da")) -def test_projection_truncate_string_set_same_result(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_set_same_result(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("helloworld")}) ) == EqualTo(term="name", literal=literal("he")) -def test_projection_truncate_string_set_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_set_in(bound_reference_str: BoundReference) -> None: assert TruncateTransform(3).project( "name", BoundIn(term=bound_reference_str, literals={literal("hello"), literal("world")}) ) == In(term="name", literals={literal("hel"), literal("wor")}) # codespell:ignore hel -def test_projection_truncate_string_set_not_in(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_set_not_in(bound_reference_str: BoundReference) -> None: assert ( TruncateTransform(3).project("name", BoundNotIn(term=bound_reference_str, literals={literal("hello"), literal("world")})) is None ) -def test_projection_truncate_decimal_literal_eq(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_literal_eq(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundEqualTo(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == EqualTo(term="name", literal=Decimal("19.24")) -def test_projection_truncate_decimal_literal_gt(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_literal_gt(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThan(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26")) -def test_projection_truncate_decimal_literal_gte(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_literal_gte(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThanOrEqual(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24")) -def test_projection_truncate_decimal_in(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_decimal_in(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundIn(term=bound_reference_decimal, literals={literal(Decimal(19.25)), literal(Decimal(18.15))}) ) == In( @@ -988,25 +988,25 @@ def test_projection_truncate_decimal_in(bound_reference_decimal: BoundReference[ ) -def test_projection_truncate_long_literal_eq(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_long_literal_eq(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundEqualTo(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == EqualTo(term="name", literal=Decimal("19.24")) -def test_projection_truncate_long_literal_gt(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_long_literal_gt(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThan(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26")) -def test_projection_truncate_long_literal_gte(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_long_literal_gte(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundGreaterThanOrEqual(term=bound_reference_decimal, literal=DecimalLiteral(Decimal(19.25))) ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24")) -def test_projection_truncate_long_in(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_projection_truncate_long_in(bound_reference_decimal: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundIn(term=bound_reference_decimal, literals={DecimalLiteral(Decimal(19.25)), DecimalLiteral(Decimal(18.15))}) ) == In( @@ -1018,19 +1018,19 @@ def test_projection_truncate_long_in(bound_reference_decimal: BoundReference[Dec ) -def test_projection_truncate_string_starts_with(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_starts_with(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundStartsWith(term=bound_reference_str, literal=literal("hello")) ) == StartsWith(term="name", literal=literal("he")) -def test_projection_truncate_string_not_starts_with(bound_reference_str: BoundReference[str]) -> None: +def test_projection_truncate_string_not_starts_with(bound_reference_str: BoundReference) -> None: assert TruncateTransform(2).project( "name", BoundNotStartsWith(term=bound_reference_str, literal=literal("hello")) ) == NotStartsWith(term="name", literal=literal("he")) -def _test_projection(lhs: UnboundPredicate[L] | None, rhs: UnboundPredicate[L] | None) -> None: +def _test_projection(lhs: UnboundPredicate | None, rhs: UnboundPredicate | None) -> None: assert type(lhs) is type(lhs), f"Different classes: {type(lhs)} != {type(rhs)}" if lhs is None and rhs is None: # Both null @@ -1068,7 +1068,7 @@ def _assert_projection_strict( assert actual_human_str == expected_human_str -def test_month_projection_strict_epoch(bound_reference_date: BoundReference[int]) -> None: +def test_month_projection_strict_epoch(bound_reference_date: BoundReference) -> None: date = literal("1970-01-01").to(DateType()) transform = MonthTransform() _assert_projection_strict(BoundLessThan(term=bound_reference_date, literal=date), transform, LessThan, "1970-01") @@ -1091,7 +1091,7 @@ def test_month_projection_strict_epoch(bound_reference_date: BoundReference[int] ) -def test_month_projection_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_month_projection_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-01-01").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1115,7 +1115,7 @@ def test_month_projection_strict_lower_bound(bound_reference_date: BoundReferenc ) -def test_negative_month_projection_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_month_projection_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("1969-01-01").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1140,7 +1140,7 @@ def test_negative_month_projection_strict_lower_bound(bound_reference_date: Boun ) -def test_month_projection_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_month_projection_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-12-31").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1164,7 +1164,7 @@ def test_month_projection_strict_upper_bound(bound_reference_date: BoundReferenc ) -def test_negative_month_projection_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_month_projection_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("1969-12-31").to(DateType()) # == 564 months since epoch transform = MonthTransform() @@ -1188,7 +1188,7 @@ def test_negative_month_projection_strict_upper_bound(bound_reference_date: Boun ) -def test_day_strict(bound_reference_date: BoundReference[int]) -> None: +def test_day_strict(bound_reference_date: BoundReference) -> None: date = literal("2017-01-01").to(DateType()) transform = DayTransform() @@ -1216,7 +1216,7 @@ def test_day_strict(bound_reference_date: BoundReference[int]) -> None: ) -def test_day_negative_strict(bound_reference_date: BoundReference[int]) -> None: +def test_day_negative_strict(bound_reference_date: BoundReference) -> None: date = literal("1969-12-30").to(DateType()) transform = DayTransform() @@ -1244,7 +1244,7 @@ def test_day_negative_strict(bound_reference_date: BoundReference[int]) -> None: ) -def test_year_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_year_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-01-01").to(DateType()) transform = YearTransform() @@ -1265,7 +1265,7 @@ def test_year_strict_lower_bound(bound_reference_date: BoundReference[int]) -> N ) -def test_negative_year_strict_lower_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_year_strict_lower_bound(bound_reference_date: BoundReference) -> None: date = literal("1970-01-01").to(DateType()) transform = YearTransform() @@ -1289,7 +1289,7 @@ def test_negative_year_strict_lower_bound(bound_reference_date: BoundReference[i ) -def test_year_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_year_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-12-31").to(DateType()) transform = YearTransform() @@ -1313,7 +1313,7 @@ def test_year_strict_upper_bound(bound_reference_date: BoundReference[int]) -> N ) -def test_negative_year_strict_upper_bound(bound_reference_date: BoundReference[int]) -> None: +def test_negative_year_strict_upper_bound(bound_reference_date: BoundReference) -> None: date = literal("2017-12-31").to(DateType()) transform = YearTransform() @@ -1330,7 +1330,7 @@ def test_negative_year_strict_upper_bound(bound_reference_date: BoundReference[i _assert_projection_strict(BoundIn(term=bound_reference_date, literals={date, another_date}), transform, NotIn) -def test_strict_bucket_integer(bound_reference_int: BoundReference[int]) -> None: +def test_strict_bucket_integer(bound_reference_int: BoundReference) -> None: value = literal(100).to(IntegerType()) transform = BucketTransform(num_buckets=10) @@ -1346,7 +1346,7 @@ def test_strict_bucket_integer(bound_reference_int: BoundReference[int]) -> None _assert_projection_strict(BoundIn(term=bound_reference_int, literals=literals), transform, AlwaysFalse) -def test_strict_bucket_long(bound_reference_long: BoundReference[int]) -> None: +def test_strict_bucket_long(bound_reference_long: BoundReference) -> None: value = literal(100).to(LongType()) transform = BucketTransform(num_buckets=10) @@ -1362,7 +1362,7 @@ def test_strict_bucket_long(bound_reference_long: BoundReference[int]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_long, literals=literals), transform, AlwaysFalse) -def test_strict_bucket_decimal(bound_reference_decimal: BoundReference[int]) -> None: +def test_strict_bucket_decimal(bound_reference_decimal: BoundReference) -> None: dec = DecimalType(9, 2) value = literal("100.00").to(dec) transform = BucketTransform(num_buckets=10) @@ -1379,7 +1379,7 @@ def test_strict_bucket_decimal(bound_reference_decimal: BoundReference[int]) -> _assert_projection_strict(BoundIn(term=bound_reference_decimal, literals=literals), transform, AlwaysFalse) -def test_strict_bucket_string(bound_reference_str: BoundReference[int]) -> None: +def test_strict_bucket_string(bound_reference_str: BoundReference) -> None: value = literal("abcdefg").to(StringType()) transform = BucketTransform(num_buckets=10) @@ -1395,7 +1395,7 @@ def test_strict_bucket_string(bound_reference_str: BoundReference[int]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_str, literals={value, other_value}), transform, AlwaysFalse) -def test_strict_bucket_bytes(bound_reference_binary: BoundReference[int]) -> None: +def test_strict_bucket_bytes(bound_reference_binary: BoundReference) -> None: value = literal(str.encode("abcdefg")).to(BinaryType()) transform = BucketTransform(num_buckets=10) @@ -1411,7 +1411,7 @@ def test_strict_bucket_bytes(bound_reference_binary: BoundReference[int]) -> Non _assert_projection_strict(BoundIn(term=bound_reference_binary, literals={value, other_value}), transform, AlwaysFalse) -def test_strict_bucket_uuid(bound_reference_uuid: BoundReference[int]) -> None: +def test_strict_bucket_uuid(bound_reference_uuid: BoundReference) -> None: value = literal("00000000-0000-007b-0000-0000000001c8").to(UUIDType()) transform = BucketTransform(num_buckets=10) @@ -1427,7 +1427,7 @@ def test_strict_bucket_uuid(bound_reference_uuid: BoundReference[int]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_uuid, literals={value, other_value}), transform, AlwaysFalse) -def test_strict_identity_projection(bound_reference_long: BoundReference[int]) -> None: +def test_strict_identity_projection(bound_reference_long: BoundReference) -> None: transform: Transform[Any, Any] = IdentityTransform() predicates = [ BoundNotNull(term=bound_reference_long), @@ -1458,7 +1458,7 @@ def test_strict_identity_projection(bound_reference_long: BoundReference[int]) - ) -def test_truncate_strict_integer_lower_bound(bound_reference_int: BoundReference[int]) -> None: +def test_truncate_strict_integer_lower_bound(bound_reference_int: BoundReference) -> None: value = literal(100).to(IntegerType()) transform = TruncateTransform(10) @@ -1476,7 +1476,7 @@ def test_truncate_strict_integer_lower_bound(bound_reference_int: BoundReference _assert_projection_strict(BoundIn(term=bound_reference_int, literals={value_dec, value, value_inc}), transform, NotIn) -def test_truncate_strict_integer_upper_bound(bound_reference_int: BoundReference[int]) -> None: +def test_truncate_strict_integer_upper_bound(bound_reference_int: BoundReference) -> None: value = literal(99).to(IntegerType()) transform = TruncateTransform(10) @@ -1492,7 +1492,7 @@ def test_truncate_strict_integer_upper_bound(bound_reference_int: BoundReference _assert_projection_strict(BoundIn(term=bound_reference_int, literals=literals), transform, NotIn) -def test_truncate_strict_long_lower_bound(bound_reference_long: BoundReference[int]) -> None: +def test_truncate_strict_long_lower_bound(bound_reference_long: BoundReference) -> None: value = literal(100).to(IntegerType()) transform = TruncateTransform(10) @@ -1510,7 +1510,7 @@ def test_truncate_strict_long_lower_bound(bound_reference_long: BoundReference[i _assert_projection_strict(BoundIn(term=bound_reference_long, literals={value_dec, value, value_inc}), transform, NotIn) -def test_truncate_strict_long_upper_bound(bound_reference_long: BoundReference[int]) -> None: +def test_truncate_strict_long_upper_bound(bound_reference_long: BoundReference) -> None: value = literal(99).to(IntegerType()) transform = TruncateTransform(10) @@ -1528,7 +1528,7 @@ def test_truncate_strict_long_upper_bound(bound_reference_long: BoundReference[i _assert_projection_strict(BoundIn(term=bound_reference_long, literals={value_dec, value, value_inc}), transform, NotIn) -def test_truncate_strict_decimal_lower_bound(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_truncate_strict_decimal_lower_bound(bound_reference_decimal: BoundReference) -> None: dec = DecimalType(9, 2) value = literal("100.00").to(dec) transform = TruncateTransform(10) @@ -1549,7 +1549,7 @@ def test_truncate_strict_decimal_lower_bound(bound_reference_decimal: BoundRefer _assert_projection_strict(BoundIn(term=bound_reference_decimal, literals=literals), transform, NotIn) -def test_truncate_strict_decimal_upper_bound(bound_reference_decimal: BoundReference[Decimal]) -> None: +def test_truncate_strict_decimal_upper_bound(bound_reference_decimal: BoundReference) -> None: dec = DecimalType(9, 2) value = literal("99.99").to(dec) transform = TruncateTransform(10) @@ -1570,7 +1570,7 @@ def test_truncate_strict_decimal_upper_bound(bound_reference_decimal: BoundRefer _assert_projection_strict(BoundIn(term=bound_reference_decimal, literals=literals), transform, NotIn) -def test_string_strict(bound_reference_str: BoundReference[str]) -> None: +def test_string_strict(bound_reference_str: BoundReference) -> None: value = literal("abcdefg").to(StringType()) transform: Transform[Any, Any] = TruncateTransform(width=5) @@ -1585,7 +1585,7 @@ def test_string_strict(bound_reference_str: BoundReference[str]) -> None: _assert_projection_strict(BoundIn(term=bound_reference_str, literals={value, other_value}), transform, NotIn) -def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None: +def test_strict_binary(bound_reference_binary: BoundReference) -> None: value = literal(b"abcdefg").to(BinaryType()) transform: Transform[Any, Any] = TruncateTransform(width=5)