diff --git a/pyproject.toml b/pyproject.toml index 677ef57..c1ddd51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.10" -dependencies = ["protobuf >=3.19.1,<6"] +dependencies = ["protobuf >=3.19.1,<6", "typing_extensions"] dynamic = ["version"] [tool.setuptools_scm] diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index a4a2180..caae375 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -5,8 +5,8 @@ See `examples/builder_example.py` for usage. """ -from typing import Iterable, Optional, Union, Callable +from typing import Callable, Iterable, Optional, TypedDict, Union import substrait.gen.proto.algebra_pb2 as stalg from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension import substrait.gen.proto.plan_pb2 as stp @@ -20,16 +20,23 @@ from substrait.type_inference import infer_plan_schema from substrait.utils import ( merge_extension_declarations, - merge_extension_urns, merge_extension_uris, + merge_extension_urns, ) UnboundPlan = Callable[[ExtensionRegistry], stp.Plan] PlanOrUnbound = Union[stp.Plan, UnboundPlan] +_ExtensionDict = TypedDict( + "_ExtensionDict", + {"extension_uris": list, "extension_urns": list, "extensions": list}, +) + -def _merge_extensions(*objs): +def _merge_extensions( + *objs, +) -> _ExtensionDict: """Merge extension URIs, URNs, and declarations from multiple plan/expression objects. During the URI -> URN migration period, we maintain both URI and URN references diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index 39ed5e6..c4877fe 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -221,6 +221,14 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: ) ) +def timestamp(nullable=True) -> stt.Type: + return stt.Type( + timestamp=stt.Type.Timestamp( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: return stt.Type( diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index f4d18d7..d3dc0ac 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -1,5 +1,7 @@ from typing import Optional -from antlr4 import InputStream, CommonTokenStream + +from antlr4 import CommonTokenStream, InputStream + from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser from substrait.gen.proto.type_pb2 import Type @@ -9,8 +11,9 @@ def _evaluate(x, values: dict): if isinstance(x, SubstraitTypeParser.BinaryExprContext): left = _evaluate(x.left, values) right = _evaluate(x.right, values) - - if x.op.text == "+": + if x.op is None: + raise Exception("Undefined operator op") + elif x.op.text == "+": return left + right elif x.op.text == "-": return left - right @@ -65,22 +68,122 @@ def _evaluate(x, values: dict): return Type(fp64=Type.FP64(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext): return Type(bool=Type.Boolean(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.StringContext): + return Type(string=Type.String(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext): + return Type(timestamp=Type.Timestamp(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.DateContext): + return Type(date=Type.Date(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext): + return Type(interval_year=Type.IntervalYear(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.UuidContext): + return Type(uuid=Type.UUID(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext): + return Type(binary=Type.Binary(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimeContext): + return Type(time=Type.Time(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext): + return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability)) else: raise Exception(f"Unknown scalar type {type(scalar_type)}") elif parametrized_type: + nullability = ( + Type.NULLABILITY_NULLABLE + if parametrized_type.isnull + else Type.NULLABILITY_REQUIRED + ) if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext): precision = _evaluate(parametrized_type.precision, values) scale = _evaluate(parametrized_type.scale, values) - nullability = ( - Type.NULLABILITY_NULLABLE - if parametrized_type.isnull - else Type.NULLABILITY_REQUIRED - ) return Type( decimal=Type.Decimal( precision=precision, scale=scale, nullability=nullability ) ) + elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext): + length = _evaluate(parametrized_type.length, values) + return Type( + varchar=Type.VarChar( + length=length, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext): + length = _evaluate(parametrized_type.length, values) + return Type( + fixed_char=Type.FixedChar( + length=length, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext): + length = _evaluate(parametrized_type.length, values) + return Type( + fixed_binary=Type.FixedBinary( + length=length, + nullability=nullability, + ) + ) + elif isinstance( + parametrized_type, SubstraitTypeParser.PrecisionTimestampContext + ): + precision = _evaluate(parametrized_type.precision, values) + return Type( + precision_timestamp=Type.PrecisionTimestamp( + precision=precision, + nullability=nullability, + ) + ) + elif isinstance( + parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext + ): + precision = _evaluate(parametrized_type.precision, values) + return Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + precision=precision, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext): + return Type( + interval_year=Type.IntervalYear( + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.StructContext): + types = list( + map(lambda x: _evaluate(x, values), parametrized_type.expr()) + ) + return Type( + struct=Type.Struct( + types=types, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.ListContext): + child_type = _evaluate(parametrized_type.expr(), values) + return Type( + list=Type.List( + type=child_type, + nullability=nullability, + ) + ) + + elif isinstance(parametrized_type, SubstraitTypeParser.MapContext): + return Type( + map=Type.Map( + key=_evaluate(parametrized_type.key, values), + value=_evaluate(parametrized_type.value, values), + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext): + # it gives me a parser error i may have to update the parser + # string `evaluate("NSTRUCT")` from the docs https://substrait.io/types/type_classes/ + # line 1:17 extraneous input ':' + raise NotImplementedError("Named structure type not implemented yet") + # elif isinstance(parametrized_type, SubstraitTypeParser.UserDefinedContext): + raise Exception(f"Unknown parametrized type {type(parametrized_type)}") elif any_type: any_var = any_type.AnyVar() diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index c854c02..9791b3f 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -1,17 +1,19 @@ -import yaml import itertools import re -from substrait.gen.proto.type_pb2 import Type -from importlib.resources import files as importlib_files from collections import defaultdict +from importlib.resources import files as importlib_files from pathlib import Path from typing import Optional, Union -from .derivation_expression import evaluate, _evaluate, _parse + +import yaml + from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser from substrait.gen.json import simple_extensions as se +from substrait.gen.proto.type_pb2 import Type from substrait.simple_extension_utils import build_simple_extensions -from .bimap import UriUrnBiDiMap +from .bimap import UriUrnBiDiMap +from .derivation_expression import _evaluate, _parse, evaluate DEFAULT_URN_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" @@ -68,21 +70,24 @@ def normalize_substrait_type_names(typ: str) -> str: raise Exception(f"Unrecognized substrait type {typ}") -def violates_integer_option(actual: int, option, parameters: dict): +def violates_integer_option(actual: int, option, parameters: dict, subset=False): + option_numeric = None if isinstance(option, SubstraitTypeParser.NumericLiteralContext): - return actual != int(str(option.Number())) + option_numeric = int(str(option.Number())) elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext): parameter_name = str(option.Identifier()) - if parameter_name in parameters and parameters[parameter_name] != actual: - return True - else: + + if parameter_name not in parameters: parameters[parameter_name] = actual + option_numeric = parameters[parameter_name] else: raise Exception( f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" ) - - return False + if subset: + return actual < option_numeric + else: + return actual != option_numeric def types_equal(type1: Type, type2: Type, check_nullability=False): @@ -112,6 +117,27 @@ def handle_parameter_cover( return True +def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool: + if not check_nullability: + return True + # The ANTLR context stores a Token called ``isnull`` – it is + # present when the type is declared as nullable. + nullability = ( + Type.Nullability.NULLABILITY_NULLABLE + if getattr(parameterized_type, "isnull", None) is not None + else Type.Nullability.NULLABILITY_REQUIRED + ) + # if nullability == Type.Nullability.NULLABILITY_NULLABLE: + # return True # is still true even if the covered is required + # The protobuf message stores its own enum – we compare the two. + covered_nullability = getattr( + getattr(covered, kind), # e.g. covered.varchar + "nullability", + None, + ) + return nullability == covered_nullability + + def covers( covered: Type, covering: SubstraitTypeParser.TypeLiteralContext, @@ -123,7 +149,6 @@ def covers( return handle_parameter_cover( covered, parameter_name, parameters, check_nullability ) - covering: SubstraitTypeParser.TypeDefContext = covering.typeDef() any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType() @@ -142,31 +167,99 @@ def covers( parameterized_type = covering.parameterizedType() if parameterized_type: - if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): - if covered.WhichOneof("kind") != "decimal": + kind = covered.WhichOneof("kind") + if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): + if kind != "varchar": + return False + if hasattr(parameterized_type, "length") and violates_integer_option( + covered.varchar.length, parameterized_type.length, parameters + ): return False - nullability = ( - Type.NULLABILITY_NULLABLE - if parameterized_type.isnull - else Type.NULLABILITY_REQUIRED + return _check_nullability( + check_nullability, parameterized_type, covered, kind ) - - if ( - check_nullability - and nullability - != covered.__getattribute__(covered.WhichOneof("kind")).nullability + if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): + if kind != "fixed_char": + return False + if hasattr(parameterized_type, "length") and violates_integer_option( + covered.fixed_char.length, parameterized_type.length, parameters ): return False + return _check_nullability( + check_nullability, parameterized_type, covered, kind + ) + if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): + if kind != "fixed_binary": + return False + if hasattr(parameterized_type, "length") and violates_integer_option( + covered.fixed_binary.length, parameterized_type.length, parameters + ): + return False + # return True + return _check_nullability( + check_nullability, parameterized_type, covered, kind + ) + if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): + if kind != "decimal": + return False + if not _check_nullability( + check_nullability, parameterized_type, covered, kind + ): + return False + # precision / scale are both optional – a missing value means “no limit”. + covered_scale = getattr(covered.decimal, "scale", 0) + param_scale = getattr(parameterized_type, "scale", 0) + covered_prec = getattr(covered.decimal, "precision", 0) + param_prec = getattr(parameterized_type, "precision", 0) return not ( - violates_integer_option( - covered.decimal.scale, parameterized_type.scale, parameters - ) - or violates_integer_option( - covered.decimal.precision, parameterized_type.precision, parameters - ) + violates_integer_option(covered_scale, param_scale, parameters) + or violates_integer_option(covered_prec, param_prec, parameters) ) + if isinstance( + parameterized_type, SubstraitTypeParser.PrecisionTimestampContext + ): + if kind != "precision_timestamp": + return False + if not _check_nullability( + check_nullability, parameterized_type, covered, kind + ): + return False + # return True + covered_prec = getattr(covered.precision_timestamp, "precision", 0) + param_prec = getattr(parameterized_type, "precision", 0) + return not violates_integer_option(covered_prec, param_prec, parameters) + + if isinstance( + parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext + ): + if kind != "precision_timestamp_tz": + return False + if not _check_nullability( + check_nullability, parameterized_type, covered, kind + ): + return False + # return True + covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0) + param_prec = getattr(parameterized_type, "precision", 0) + return not violates_integer_option(covered_prec, param_prec, parameters) + + kind_mapping = { + SubstraitTypeParser.ListContext: "list", + SubstraitTypeParser.MapContext: "map", + SubstraitTypeParser.StructContext: "struct", + SubstraitTypeParser.UserDefinedContext: "user_defined", + SubstraitTypeParser.PrecisionIntervalDayContext: "interval_day", + } + + for ctx_cls, expected_kind in kind_mapping.items(): + if isinstance(parameterized_type, ctx_cls): + if kind != expected_kind: + return False + return _check_nullability( + check_nullability, parameterized_type, covered, kind + ) else: raise Exception(f"Unhandled type {type(parameterized_type)}") @@ -199,7 +292,7 @@ def __init__( def __repr__(self) -> str: return f"{self.name}:{'_'.join(self.normalized_inputs)}" - def satisfies_signature(self, signature: tuple) -> Optional[str]: + def satisfies_signature(self, signature: tuple | list) -> Optional[str]: if self.impl.variadic: min_args_allowed = self.impl.variadic.min or 0 if len(signature) < min_args_allowed: @@ -231,14 +324,12 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: output_type = evaluate(self.impl.return_, parameters) if self.nullability == se.NullabilityHandling.MIRROR: - sig_contains_nullable = any( - [ - p.__getattribute__(p.WhichOneof("kind")).nullability - == Type.NULLABILITY_NULLABLE - for p in signature - if isinstance(p, Type) - ] - ) + sig_contains_nullable = any([ + p.__getattribute__(p.WhichOneof("kind")).nullability + == Type.NULLABILITY_NULLABLE + for p in signature + if isinstance(p, Type) + ]) output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( Type.NULLABILITY_NULLABLE if sig_contains_nullable @@ -326,7 +417,7 @@ def register_extension_dict(self, definitions: dict, uri: str) -> None: # TODO add an optional return type check def lookup_function( - self, urn: str, function_name: str, signature: tuple + self, urn: str, function_name: str, signature: tuple[Type] | list[Type] ) -> Optional[tuple[FunctionEntry, Type]]: if ( urn not in self._function_mapping diff --git a/src/substrait/gen/json/simple_extensions.py b/src/substrait/gen/json/simple_extensions.py index 2885bb4..e323ac5 100644 --- a/src/substrait/gen/json/simple_extensions.py +++ b/src/substrait/gen/json/simple_extensions.py @@ -7,13 +7,15 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union +from typing_extensions import TypeAlias + class Functions(Enum): INHERITS = 'INHERITS' SEPARATE = 'SEPARATE' -Type = Union[str, Dict[str, Any]] +Type: TypeAlias = Union[str, Dict[str, Any]] class Type1(Enum): @@ -24,7 +26,7 @@ class Type1(Enum): string = 'string' -EnumOptions = List[str] +EnumOptions: TypeAlias = List[str] @dataclass @@ -49,7 +51,7 @@ class TypeArg: description: Optional[str] = None -Arguments = List[Union[EnumerationArg, ValueArg, TypeArg]] +Arguments: TypeAlias = List[Union[EnumerationArg, ValueArg, TypeArg]] @dataclass @@ -58,7 +60,7 @@ class Options1: description: Optional[str] = None -Options = Dict[str, Options1] +Options: TypeAlias = Dict[str, Options1] class ParameterConsistency(Enum): @@ -73,10 +75,10 @@ class VariadicBehavior: parameterConsistency: Optional[ParameterConsistency] = None -Deterministic = bool +Deterministic: TypeAlias = bool -SessionDependent = bool +SessionDependent: TypeAlias = bool class NullabilityHandling(Enum): @@ -85,13 +87,13 @@ class NullabilityHandling(Enum): DISCRETE = 'DISCRETE' -ReturnValue = Type +ReturnValue: TypeAlias = Type -Implementation = Dict[str, str] +Implementation: TypeAlias = Dict[str, str] -Intermediate = Type +Intermediate: TypeAlias = Type class Decomposable(Enum): @@ -100,10 +102,10 @@ class Decomposable(Enum): MANY = 'MANY' -Maxset = float +Maxset: TypeAlias = float -Ordered = bool +Ordered: TypeAlias = bool @dataclass @@ -196,7 +198,7 @@ class TypeParamDef: optional: Optional[bool] = None -TypeParamDefs = List[TypeParamDef] +TypeParamDefs: TypeAlias = List[TypeParamDef] @dataclass diff --git a/tests/test_extension_registry.py b/tests/test_extension_registry.py index f9d63bd..342e8da 100644 --- a/tests/test_extension_registry.py +++ b/tests/test_extension_registry.py @@ -4,6 +4,11 @@ from substrait.gen.proto.type_pb2 import Type from substrait.extension_registry import ExtensionRegistry, covers from substrait.derivation_expression import _parse +from substrait.builders.type import ( + i8, + i16, + decimal, +) content = """%YAML 1.2 --- @@ -104,10 +109,19 @@ value: decimal nullability: DISCRETE return: decimal? + - name: "equal_test" + impls: + - args: + - name: x + value: any + - name: y + value: any + nullability: DISCRETE + return: any """ -registry = ExtensionRegistry() +registry = ExtensionRegistry(load_default_extensions=True) registry.register_extension_dict( yaml.safe_load(content), @@ -115,52 +129,11 @@ ) -def i8(nullable=False): - return Type( - i8=Type.I8( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def i16(nullable=False): - return Type( - i16=Type.I16( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def bool(nullable=False): - return Type( - bool=Type.Boolean( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def decimal(precision, scale, nullable=False): - return Type( - decimal=Type.Decimal( - scale=scale, - precision=precision, - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE, - ) - ) - def test_non_existing_urn(): assert ( registry.lookup_function( - urn="non_existent", function_name="add", signature=[i8(), i8()] + urn="non_existent", function_name="add", signature=[i8(nullable=False), i8(nullable=False)] ) is None ) @@ -169,7 +142,8 @@ def test_non_existing_urn(): def test_non_existing_function(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="sub", signature=[i8(), i8()] + + urn="extension:test:functions", function_name="sub", signature=[i8(nullable=False), i8(nullable=False)] ) is None ) @@ -178,7 +152,7 @@ def test_non_existing_function(): def test_non_existing_function_signature(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8()] + urn="extension:test:functions", function_name="add", signature=[i8(nullable=False)] ) is None ) @@ -186,7 +160,7 @@ def test_non_existing_function_signature(): def test_exact_match(): assert registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8(), i8()] + urn="extension:test:functions", function_name="add", signature=[i8(nullable=False), i8(nullable=False)] )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) @@ -194,7 +168,7 @@ def test_wildcard_match(): assert registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(), i8(), bool()], + signature=[i8(nullable=False), i8(nullable=False), bool()], )[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) @@ -203,7 +177,7 @@ def test_wildcard_match_fails_with_constraits(): registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(), i16(), i16()], + signature=[i8(nullable=False), i16(nullable=False), i16(nullable=False)], ) is None ) @@ -214,9 +188,9 @@ def test_wildcard_match_with_constraits(): registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i16(), i16(), i8()], + signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)], )[1] - == i8() + == i8(nullable=False) ) @@ -225,9 +199,9 @@ def test_variadic(): registry.lookup_function( urn="extension:test:functions", function_name="test_fn", - signature=[i8(), i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)], )[1] - == i8() + == i8(nullable=False) ) @@ -236,16 +210,17 @@ def test_variadic_any(): registry.lookup_function( urn="extension:test:functions", function_name="test_fn_variadic_any", - signature=[i16(), i16(), i16()], + signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)], )[1] - == i16() + == i16(nullable=False) ) def test_variadic_fails_min_constraint(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="test_fn", signature=[i8()] + + urn="extension:test:functions", function_name="test_fn", signature=[i8(nullable=False)] ) is None ) @@ -255,8 +230,8 @@ def test_decimal_happy_path(): assert registry.lookup_function( urn="extension:test:functions", function_name="test_decimal", - signature=[decimal(10, 8), decimal(8, 6)], - )[1] == decimal(11, 7) + signature=[decimal(8, 10, nullable=False), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=False) def test_decimal_violates_constraint(): @@ -264,7 +239,7 @@ def test_decimal_violates_constraint(): registry.lookup_function( urn="extension:test:functions", function_name="test_decimal", - signature=[decimal(10, 8), decimal(12, 10)], + signature=[decimal(8, 10, nullable=False), decimal(10, 12, nullable=False)], ) is None ) @@ -274,8 +249,8 @@ def test_decimal_happy_path_discrete(): assert registry.lookup_function( urn="extension:test:functions", function_name="test_decimal_discrete", - signature=[decimal(10, 8, nullable=True), decimal(8, 6)], - )[1] == decimal(11, 7, nullable=True) + signature=[decimal(8, 10, nullable=True), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=True) def test_enum_with_valid_option(): @@ -283,9 +258,9 @@ def test_enum_with_valid_option(): registry.lookup_function( urn="extension:test:functions", function_name="test_enum", - signature=["FLIP", i8()], + signature=["FLIP", i8(nullable=False)], )[1] - == i8() + == i8(nullable=False) ) @@ -294,7 +269,7 @@ def test_enum_with_nonexistent_option(): registry.lookup_function( urn="extension:test:functions", function_name="test_enum", - signature=["NONEXISTENT", i8()], + signature=["NONEXISTENT", i8(nullable=False)], ) is None ) @@ -304,7 +279,7 @@ def test_function_with_nullable_args(): assert registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(nullable=True), i8()], + signature=[i8(nullable=True), i8(nullable=False)], )[1] == i8(nullable=True) @@ -312,7 +287,7 @@ def test_function_with_declared_output_nullability(): assert registry.lookup_function( urn="extension:test:functions", function_name="add_declared", - signature=[i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False)], )[1] == i8(nullable=True) @@ -320,7 +295,7 @@ def test_function_with_discrete_nullability(): assert registry.lookup_function( urn="extension:test:functions", function_name="add_discrete", - signature=[i8(nullable=True), i8()], + signature=[i8(nullable=True), i8(nullable=False)], )[1] == i8(nullable=True) @@ -329,7 +304,7 @@ def test_function_with_discrete_nullability_nonexisting(): registry.lookup_function( urn="extension:test:functions", function_name="add_discrete", - signature=[i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False)], ) is None ) @@ -337,7 +312,7 @@ def test_function_with_discrete_nullability_nonexisting(): def test_covers(): params = {} - assert covers(i8(), _parse("i8"), params) + assert covers(i8(nullable=False), _parse("i8"), params) assert params == {} @@ -346,18 +321,127 @@ def test_covers_nullability(): assert covers(i8(nullable=True), _parse("i8?"), {}, check_nullability=True) -def test_covers_decimal(): - assert not covers(decimal(10, 8), _parse("decimal<11, A>"), {}) +def test_covers_decimal(nullable=False): + assert not covers(decimal(8, 10), _parse("decimal<11, A>"), {}) def test_covers_decimal_happy_path(): params = {} - assert covers(decimal(10, 8), _parse("decimal<10, A>"), params) + assert covers(decimal(8, 10), _parse("decimal<10, A>"), params) assert params == {"A": 8} def test_covers_any(): - assert covers(decimal(10, 8), _parse("any"), {}) + assert covers(decimal(8, 10), _parse("any"), {}) + + +def test_covers_varchar_length_ok(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=15) + ) + param_ctx = _parse("varchar<15>") + assert covers(covered, param_ctx, {}, check_nullability=True) + + +def test_covers_varchar_length_fail(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("varchar<5>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_varchar_nullability(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_tx = _parse("varchar?<10>") + assert covers(covered, param_tx, {}) + assert not covers(covered, param_tx, {}, True) + param_ctx2 = _parse("varchar<10>") + assert covers(covered, param_ctx2, {}, True) + + +def test_covers_fixed_char_length_ok(): + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) + ) + param_ctx = _parse("fixedchar<8>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_char_length_fail(): + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) + ) + param_ctx = _parse("fixedchar<4>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_ok(): + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) + ) + param_ctx = _parse("fixedbinary<16>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_fail(): + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) + ) + param_ctx = _parse("fixedbinary<10>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_decimal_precision_scale_fail(): + covered = decimal(8, 10, nullable=False) + param_ctx = _parse("decimal<6, 5>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_ok(): + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=5 + ) + ) + param_ctx = _parse("precision_timestamp<5>") + assert covers(covered, param_ctx, {}) + param_ctx = _parse("precision_timestamp") + assert covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_fail(): + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=3 + ) + ) + param_ctx = _parse("precision_timestamp<2>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_tz_ok(): + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=4 + ) + ) + param_ctx = _parse("precision_timestamp_tz<4>") + assert covers(covered, param_ctx, {}) + param_ctx = _parse("precision_timestamp_tz") + assert covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_tz_fail(): + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=4 + ) + ) + param_ctx = _parse("precision_timestamp_tz<3>") + assert not covers(covered, param_ctx, {}) def test_registry_uri_urn(): @@ -488,4 +572,4 @@ def test_register_requires_uri(): # During migration, URI is required - this should fail with TypeError with pytest.raises(TypeError): - registry.register_extension_dict(yaml.safe_load(content)) + registry.register_extension_dict(yaml.safe_load(content)) \ No newline at end of file