diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index 92b560a..9e715a4 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -223,6 +223,7 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: ) + def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: return stt.Type( struct=stt.Type.Struct( diff --git a/src/substrait/dataframe/__init__.py b/src/substrait/dataframe/__init__.py index 8f2271b..34b5a4f 100644 --- a/src/substrait/dataframe/__init__.py +++ b/src/substrait/dataframe/__init__.py @@ -11,6 +11,7 @@ def col(name: str) -> Expression: """Column selection.""" return Expression(column(name)) + # TODO handle str_as_lit argument def parse_into_expr(expr, str_as_lit: bool): return expr._to_compliant_expr(substrait.dataframe) diff --git a/src/substrait/dataframe/dataframe.py b/src/substrait/dataframe/dataframe.py index f5d25d2..916c0e7 100644 --- a/src/substrait/dataframe/dataframe.py +++ b/src/substrait/dataframe/dataframe.py @@ -29,7 +29,7 @@ def select( expressions = [e.expr for e in exprs] + [ expr.alias(alias).expr for alias, expr in named_exprs.items() ] - return DataFrame(select(self.plan, expressions=expressions)) + return DataFrame(select(self.plan, expressions=expressions)) # TODO handle version def _with_version(self, version): diff --git a/src/substrait/dataframe/expression.py b/src/substrait/dataframe/expression.py index 011b625..00196b2 100644 --- a/src/substrait/dataframe/expression.py +++ b/src/substrait/dataframe/expression.py @@ -2,7 +2,7 @@ UnboundExtendedExpression, ExtendedExpressionOrUnbound, resolve_expression, - scalar_function + scalar_function, ) import substrait.gen.proto.type_pb2 as stp import substrait.gen.proto.extended_expression_pb2 as stee @@ -30,7 +30,9 @@ def __init__(self, expr: UnboundExtendedExpression): def alias(self, alias: str): self.expr = _alias(self.expr, alias) return self - + def abs(self): - self.expr = 
scalar_function("functions_arithmetic.yaml", "abs", expressions=[self.expr]) + self.expr = scalar_function( + "functions_arithmetic.yaml", "abs", expressions=[self.expr] + ) return self diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py deleted file mode 100644 index 2770efd..0000000 --- a/src/substrait/extension_registry.py +++ /dev/null @@ -1,490 +0,0 @@ -import itertools -import re -from collections import defaultdict -from importlib.resources import files as importlib_files -from pathlib import Path -from typing import Optional, Union - -import yaml - -from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser -from substrait.gen.json import simple_extensions as se -from substrait.gen.proto.type_pb2 import Type -from substrait.simple_extension_utils import build_simple_extensions - -from .bimap import UriUrnBiDiMap -from .derivation_expression import _evaluate, _parse, evaluate - -DEFAULT_URN_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" - -# Format: extension:: -# Example: extension:io.substrait:functions_arithmetic -URN_PATTERN = re.compile(r"^extension:[^:]+:[^:]+$") - - -# mapping from argument types to shortened signature names: https://substrait.io/extensions/#function-signature-compound-names -_normalized_key_names = { - "i8": "i8", - "i16": "i16", - "i32": "i32", - "i64": "i64", - "fp32": "fp32", - "fp64": "fp64", - "string": "str", - "binary": "vbin", - "boolean": "bool", - "timestamp": "ts", - "timestamp_tz": "tstz", - "date": "date", - "time": "time", - "interval_year": "iyear", - "interval_day": "iday", - "interval_compound": "icompound", - "uuid": "uuid", - "fixedchar": "fchar", - "varchar": "vchar", - "fixedbinary": "fbin", - "decimal": "dec", - "precision_time": "pt", - "precision_timestamp": "pts", - "precision_timestamp_tz": "ptstz", - "struct": "struct", - "list": "list", - "map": "map", -} - - -def normalize_substrait_type_names(typ: str) -> str: - # Strip type 
specifiers - typ = typ.split("<")[0] - # First strip nullability marker - typ = typ.strip("?").lower() - - if typ.startswith("any"): - return "any" - elif typ.startswith("u!"): - return typ - elif typ in _normalized_key_names: - return _normalized_key_names[typ] - else: - raise Exception(f"Unrecognized substrait type {typ}") - - -def violates_integer_option(actual: int, option, parameters: dict): - option_numeric = None - if isinstance(option, SubstraitTypeParser.NumericLiteralContext): - option_numeric = int(str(option.Number())) - elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext): - parameter_name = str(option.Identifier()) - - if parameter_name not in parameters: - parameters[parameter_name] = actual - option_numeric = parameters[parameter_name] - else: - raise Exception( - f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" - ) - return actual != option_numeric - - -def types_equal(type1: Type, type2: Type, check_nullability=False): - if check_nullability: - return type1 == type2 - else: - x, y = Type(), Type() - x.CopyFrom(type1) - y.CopyFrom(type2) - x.__getattribute__( - x.WhichOneof("kind") - ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED - y.__getattribute__( - y.WhichOneof("kind") - ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED - return x == y - - -def handle_parameter_cover( - covered: Type, parameter_name: str, parameters: dict, check_nullability: bool -): - if parameter_name in parameters: - covering = parameters[parameter_name] - return types_equal(covering, covered, check_nullability) - else: - parameters[parameter_name] = covered - return True - - -def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool: - if not check_nullability: - return True - # The ANTLR context stores a Token called ``isnull`` – it is - # present when the type is declared as nullable. 
- nullability = ( - Type.Nullability.NULLABILITY_NULLABLE - if getattr(parameterized_type, "isnull", None) is not None - else Type.Nullability.NULLABILITY_REQUIRED - ) - # if nullability == Type.Nullability.NULLABILITY_NULLABLE: - # return True # is still true even if the covered is required - # The protobuf message stores its own enum – we compare the two. - covered_nullability = getattr( - getattr(covered, kind), # e.g. covered.varchar - "nullability", - None, - ) - return nullability == covered_nullability - - -def covers( - covered: Type, - covering: SubstraitTypeParser.TypeLiteralContext, - parameters: dict, - check_nullability=False, -): - if isinstance(covering, SubstraitTypeParser.ParameterNameContext): - parameter_name = str(covering.Identifier()) - return handle_parameter_cover( - covered, parameter_name, parameters, check_nullability - ) - covering: SubstraitTypeParser.TypeDefContext = covering.typeDef() - - any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType() - if any_type: - if any_type.AnyVar(): - return handle_parameter_cover( - covered, any_type.AnyVar().symbol.text, parameters, check_nullability - ) - else: - return True - - scalar_type = covering.scalarType() - if scalar_type: - covering = _evaluate(covering, {}) - return types_equal(covering, covered, check_nullability) - - parameterized_type = covering.parameterizedType() - if parameterized_type: - return _cover_parametrized_type( - covered, parameterized_type, parameters, check_nullability - ) - - -def check_violates_integer_option_parameters( - covered, parameterized_type, attributes, parameters -): - for attr in attributes: - if not hasattr(covered, attr) and not hasattr(parameterized_type, attr): - return False - covered_attr = getattr(covered, attr) - param_attr = getattr(parameterized_type, attr) - if violates_integer_option(covered_attr, param_attr, parameters): - return True - return False - - -def _cover_parametrized_type( - covered: Type, - parameterized_type: 
SubstraitTypeParser.ParameterizedTypeContext, - parameters: dict, - check_nullability=False, -): - kind = covered.WhichOneof("kind") - - if not _check_nullability(check_nullability, parameterized_type, covered, kind): - return False - - if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): - return kind == "varchar" and not check_violates_integer_option_parameters( - covered.varchar, parameterized_type, ["length"], parameters - ) - - if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): - return kind == "fixed_char" and not check_violates_integer_option_parameters( - covered.fixed_char, parameterized_type, ["length"], parameters - ) - - if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): - return kind == "fixed_binary" and not check_violates_integer_option_parameters( - covered.fixed_binary, parameterized_type, ["length"], parameters - ) - - if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): - return kind == "decimal" and not check_violates_integer_option_parameters( - covered.decimal, parameterized_type, ["scale", "precision"], parameters - ) - - if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext): - return ( - kind == "precision_timestamp" - and not check_violates_integer_option_parameters( - covered.precision_timestamp, - parameterized_type, - ["precision"], - parameters, - ) - ) - - if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext): - return ( - kind == "precision_timestamp_tz" - and not check_violates_integer_option_parameters( - covered.precision_timestamp_tz, - parameterized_type, - ["precision"], - parameters, - ) - ) - - if isinstance(parameterized_type, SubstraitTypeParser.ListContext): - return kind == "list" and covers( - covered.list.type, - parameterized_type.expr(), - parameters, - check_nullability, - ) - - if isinstance(parameterized_type, SubstraitTypeParser.MapContext): - return ( - kind == "map" - and 
covers( - covered.map.key, parameterized_type.key, parameters, check_nullability - ) - and covers( - covered.map.value, - parameterized_type.value, - parameters, - check_nullability, - ) - ) - - if isinstance(parameterized_type, SubstraitTypeParser.StructContext): - if kind != "struct": - return False - covered_types = covered.struct.types - param_types = parameterized_type.expr() or [] - if not isinstance(param_types, list): - param_types = [param_types] - if len(covered_types) != len(param_types): - return False - for covered_field, param_field_ctx in zip(covered_types, param_types): - if not covers( - covered_field, - param_field_ctx, - parameters, - check_nullability, # type: ignore - ): - return False - return True - - raise Exception(f"Unhandled type {type(parameterized_type)}") - - -class FunctionEntry: - def __init__( - self, urn: str, name: str, impl: Union[se.Impl, se.Impl1, se.Impl2], anchor: int - ) -> None: - self.name = name - self.impl = impl - self.normalized_inputs: list = [] - self.urn: str = urn - self.anchor = anchor - self.arguments = [] - self.nullability = ( - impl.nullability if impl.nullability else se.NullabilityHandling.MIRROR - ) - - if impl.args: - for arg in impl.args: - if isinstance(arg, se.ValueArg): - self.arguments.append(_parse(arg.value)) - self.normalized_inputs.append( - normalize_substrait_type_names(arg.value) - ) - elif isinstance(arg, se.EnumerationArg): - self.arguments.append(arg.options) - self.normalized_inputs.append("req") - - def __repr__(self) -> str: - return f"{self.name}:{'_'.join(self.normalized_inputs)}" - - def satisfies_signature(self, signature: tuple) -> Optional[str]: - if self.impl.variadic: - min_args_allowed = self.impl.variadic.min or 0 - if len(signature) < min_args_allowed: - return None - inputs = [self.arguments[0]] * len(signature) - else: - inputs = self.arguments - if len(inputs) != len(signature): - return None - - zipped_args = list(zip(inputs, signature)) - - parameters = {} - - for x, y in 
zipped_args: - if isinstance(y, str): - if y not in x: - return None - else: - if not covers( - y, - x, - parameters, - check_nullability=self.nullability - == se.NullabilityHandling.DISCRETE, - ): - return None - - output_type = evaluate(self.impl.return_, parameters) - - if self.nullability == se.NullabilityHandling.MIRROR: - sig_contains_nullable = any( - [ - p.__getattribute__(p.WhichOneof("kind")).nullability - == Type.NULLABILITY_NULLABLE - for p in signature - if isinstance(p, Type) - ] - ) - output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( - Type.NULLABILITY_NULLABLE - if sig_contains_nullable - else Type.NULLABILITY_REQUIRED - ) - - return output_type - - -class ExtensionRegistry: - def __init__(self, load_default_extensions=True) -> None: - self._urn_mapping: dict = defaultdict(dict) # URN -> anchor ID - # NOTE: during the URI -> URN migration, we only need an id generator for URN. We can use the same anchor for plan construction for URIs. - self._urn_id_generator = itertools.count(1) - - self._function_mapping: dict = defaultdict(dict) - self._id_generator = itertools.count(1) - - # Bidirectional URI <-> URN mapping (temporary during migration) - self._uri_urn_bimap = UriUrnBiDiMap() - - if load_default_extensions: - for fpath in importlib_files("substrait.extensions").glob( # type: ignore - "functions*.yaml" - ): - # Derive URI from DEFAULT_URN_PREFIX and filename - uri = f"{DEFAULT_URN_PREFIX}/{fpath.name}" - self.register_extension_yaml(fpath, uri=uri) - - def register_extension_yaml( - self, - fname: Union[str, Path], - uri: str, - ) -> None: - """Register extensions from a YAML file. 
- - Args: - fname: Path to the YAML file - uri: URI for the extension (this is required during the URI -> URN migration) - """ - fname = Path(fname) - with open(fname) as f: # type: ignore - extension_definitions = yaml.safe_load(f) - - self.register_extension_dict(extension_definitions, uri=uri) - - def register_extension_dict(self, definitions: dict, uri: str) -> None: - """Register extensions from a dictionary (parsed YAML). - - Args: - definitions: The extension definitions dictionary - uri: URI for the extension (for URI/URN bimap) - """ - urn = definitions.get("urn") - if not urn: - raise ValueError("Extension definitions must contain a 'urn' field") - - self._validate_urn_format(urn) - - self._urn_mapping[urn] = next(self._urn_id_generator) - - self._uri_urn_bimap.put(uri, urn) - - simple_extensions = build_simple_extensions(definitions) - - functions = ( - (simple_extensions.scalar_functions or []) - + (simple_extensions.aggregate_functions or []) - + (simple_extensions.window_functions or []) - ) - - if functions: - for function in functions: - for impl in function.impls: - func = FunctionEntry( - urn, function.name, impl, next(self._id_generator) - ) - if ( - func.urn in self._function_mapping - and function.name in self._function_mapping[func.urn] - ): - self._function_mapping[func.urn][function.name].append(func) - else: - self._function_mapping[func.urn][function.name] = [func] - - # TODO add an optional return type check - def lookup_function( - self, urn: str, function_name: str, signature: tuple - ) -> Optional[tuple[FunctionEntry, Type]]: - if ( - urn not in self._function_mapping - or function_name not in self._function_mapping[urn] - ): - return None - functions = self._function_mapping[urn][function_name] - for f in functions: - assert isinstance(f, FunctionEntry) - rtn = f.satisfies_signature(signature) - if rtn is not None: - return (f, rtn) - - return None - - def lookup_urn(self, urn: str) -> Optional[int]: - return 
self._urn_mapping.get(urn, None) - - def lookup_uri_anchor(self, uri: str) -> Optional[int]: - """Look up the anchor ID for a URI. - - During the migration period, URI and URN share the same anchor. - This method converts the URI to its URN and returns the URN's anchor. - - Args: - uri: The extension URI to look up - - Returns: - The anchor ID for the URI (same as its corresponding URN), or None if not found - """ - urn = self._uri_urn_bimap.get_urn(uri) - if urn: - return self._urn_mapping.get(urn) - return None - - def _validate_urn_format(self, urn: str) -> None: - """Validate that a URN follows the expected format. - - Expected format: extension:: - Example: extension:io.substrait:functions_arithmetic - - Args: - urn: The URN to validate - - Raises: - ValueError: If the URN format is invalid - """ - if not URN_PATTERN.match(urn): - raise ValueError( - f"Invalid URN format: '{urn}'. " - f"Expected format: extension:: " - f"(e.g., 'extension:io.substrait:functions_arithmetic')" - ) diff --git a/src/substrait/extension_registry/__init__.py b/src/substrait/extension_registry/__init__.py new file mode 100644 index 0000000..6024a0f --- /dev/null +++ b/src/substrait/extension_registry/__init__.py @@ -0,0 +1,25 @@ +"""Extension Registry module.""" + +from .registry import ExtensionRegistry +from .function_entry import FunctionEntry, FunctionType +from .signature_checker_helpers import ( + normalize_substrait_type_names, + _check_integer_constraint, + types_equal, + _bind_type_parameter, + covers, +) +from .exceptions import UnrecognizedSubstraitTypeError, UnhandledParameterizedTypeError + +__all__ = [ + "ExtensionRegistry", + "FunctionEntry", + "FunctionType", + "normalize_substrait_type_names", + "_check_integer_constraint", + "types_equal", + "_bind_type_parameter", + "covers", + "UnrecognizedSubstraitTypeError", + "UnhandledParameterizedTypeError", +] diff --git a/src/substrait/extension_registry/exceptions.py b/src/substrait/extension_registry/exceptions.py new 
file mode 100644 index 0000000..37b05be --- /dev/null +++ b/src/substrait/extension_registry/exceptions.py @@ -0,0 +1,12 @@ +# Custom exceptions +# +class UnrecognizedSubstraitTypeError(Exception): + """Raised when an unrecognized Substrait type is encountered.""" + + pass + + +class UnhandledParameterizedTypeError(Exception): + """Raised when an unhandled ANTLR parameterized type context is encountered.""" + + pass diff --git a/src/substrait/extension_registry/function_entry.py b/src/substrait/extension_registry/function_entry.py new file mode 100644 index 0000000..aae66ab --- /dev/null +++ b/src/substrait/extension_registry/function_entry.py @@ -0,0 +1,96 @@ +"""Function entry class for extension registry.""" + +from enum import Enum +from typing import Optional, Union + +from substrait.gen.json import simple_extensions as se +from substrait.derivation_expression import _parse, evaluate +from substrait.gen.proto.type_pb2 import Type + +from .signature_checker_helpers import normalize_substrait_type_names, covers + + +class FunctionType(Enum): + SCALAR = "scalar" + AGGREGATE = "aggregate" + WINDOW = "window" + + +class FunctionEntry: + def __init__( + self, + urn: str, + name: str, + impl: Union[se.Impl, se.Impl1, se.Impl2], + anchor: int, + function_type: FunctionType = FunctionType.SCALAR, + ) -> None: + self.name = name + self.impl = impl + self.normalized_inputs: list = [] + self.urn: str = urn + self.anchor = anchor + self.function_type = function_type + self.arguments = [] + self.nullability = ( + impl.nullability if impl.nullability else se.NullabilityHandling.MIRROR + ) + if impl.args: + for arg in impl.args: + if isinstance(arg, se.ValueArg): + self.arguments.append(_parse(arg.value)) + self.normalized_inputs.append( + normalize_substrait_type_names(arg.value) + ) + elif isinstance(arg, se.EnumerationArg): + self.arguments.append(arg.options) + self.normalized_inputs.append("req") + + def __repr__(self) -> str: + return 
f"{self.name}:{'_'.join(self.normalized_inputs)}" + + def satisfies_signature(self, signature: tuple | list) -> Optional[str]: + if self.impl.variadic: + min_args_allowed = self.impl.variadic.min or 0 + if len(signature) < min_args_allowed: + return None + inputs = [self.arguments[0]] * len(signature) + else: + inputs = self.arguments + if len(inputs) != len(signature): + return None + zipped_args = list(zip(inputs, signature)) + parameters = {} + for x, y in zipped_args: + if isinstance(y, str): + if y not in x: + return None + else: + if not covers( + y, + x, + parameters, + check_nullability=self.nullability + == se.NullabilityHandling.DISCRETE, + ): + return None + output_type = evaluate(self.impl.return_, parameters) + if self.nullability == se.NullabilityHandling.MIRROR and isinstance( + output_type, Type + ): + sig_contains_nullable = any( + [ + p.__getattribute__(p.WhichOneof("kind")).nullability + == Type.NULLABILITY_NULLABLE + for p in signature + if isinstance(p, Type) + ] + ) + kind = output_type.WhichOneof("kind") + if kind is not None: + output_type.__getattribute__(kind).nullability = ( + Type.NULLABILITY_NULLABLE + if sig_contains_nullable + else Type.NULLABILITY_REQUIRED + ) + return output_type diff --git a/src/substrait/extension_registry/registry.py b/src/substrait/extension_registry/registry.py new file mode 100644 index 0000000..e8bbc67 --- /dev/null +++ b/src/substrait/extension_registry/registry.py @@ -0,0 +1,184 @@ +"""Extension Registry class.""" + +import re +import itertools +from collections import defaultdict +from importlib.resources import files as importlib_files +from pathlib import Path +from typing import Optional, Union + +import yaml + +from substrait.bimap import UriUrnBiDiMap +from substrait.gen.proto.type_pb2 import Type +from substrait.simple_extension_utils import build_simple_extensions + +from .function_entry import FunctionEntry, FunctionType + + +# Constants +DEFAULT_URN_PREFIX = 
"https://github.com/substrait-io/substrait/blob/main/extensions" +# Format: extension:: +# Example: extension:io.substrait:functions_arithmetic +URN_PATTERN = re.compile(r"^extension:[^:]+:[^:]+$") + + +class ExtensionRegistry: + def __init__(self, load_default_extensions=True) -> None: + self._urn_mapping: dict = defaultdict(dict) # URN -> anchor ID + # NOTE: during the URI -> URN migration, we only need an id generator for URN. We can use the same anchor for plan construction for URIs. + self._urn_id_generator = itertools.count(1) + self._function_mapping: dict = defaultdict(lambda: defaultdict(list)) + self._id_generator = itertools.count(1) + # Bidirectional URI <-> URN mapping (temporary during migration) + self._uri_urn_bimap = UriUrnBiDiMap() + if load_default_extensions: + for fpath in importlib_files("substrait.extensions").glob( # type: ignore + "functions*.yaml" + ): + # Derive URI from DEFAULT_URN_PREFIX and filename + uri = f"{DEFAULT_URN_PREFIX}/{fpath.name}" + self.register_extension_yaml(fpath, uri=uri) + + def register_extension_yaml( + self, + fname: Union[str, Path], + uri: str, + ) -> None: + """Register extensions from a YAML file. + Args: + fname: Path to the YAML file + uri: URI for the extension (this is required during the URI -> URN migration) + """ + fname = Path(fname) + with open(fname) as f: # type: ignore + extension_definitions = yaml.safe_load(f) + self.register_extension_dict(extension_definitions, uri=uri) + + def register_extension_dict(self, definitions: dict, uri: str) -> None: + """Register extensions from a dictionary (parsed YAML). 
+ Args: + definitions: The extension definitions dictionary + uri: URI for the extension (for URI/URN bimap) + """ + unverified_urn = definitions.get("urn") + if not unverified_urn: + raise ValueError("Extension definitions must contain a 'urn' field") + urn = validate_urn_format(unverified_urn) + self._urn_mapping[urn] = next(self._urn_id_generator) + self._uri_urn_bimap.put(uri, urn) + simple_extensions = build_simple_extensions(definitions) + + # Helper to register functions by type + def register_functions_by_type( + functions_list: list, func_type: FunctionType + ) -> None: + if not functions_list: + return + + for function in functions_list: + self._function_mapping[urn][function.name].extend( + [ + FunctionEntry( + urn=urn, + name=function.name, + impl=impl, + anchor=next(self._id_generator), + function_type=func_type, + ) + for impl in function.impls + ] + ) + + # Register each function type + register_functions_by_type( + simple_extensions.scalar_functions or [], FunctionType.SCALAR + ) + register_functions_by_type( + simple_extensions.aggregate_functions or [], FunctionType.AGGREGATE + ) + register_functions_by_type( + simple_extensions.window_functions or [], FunctionType.WINDOW + ) + + def _find_matching_functions( + self, + function_name: str, + signature: tuple[Type] | list[Type], + urns: list[str] | None = None, + ) -> list[tuple[FunctionEntry, Type]]: + """Helper method to find matching functions across specified URNs.""" + matches = [] + urns_to_search = ( + urns if urns is not None else list(self._function_mapping.keys()) + ) + for urn in urns_to_search: + if ( + urn not in self._function_mapping + or function_name not in self._function_mapping[urn] + ): + continue + functions = self._function_mapping[urn][function_name] + for f in functions: + rtn = f.satisfies_signature(signature) + if rtn is not None: + matches.append((f, rtn)) + return matches + + # TODO add an optional return type check + def lookup_function( + self, + urn: str, + 
function_name: str, + signature: tuple[Type] | list[Type], + ) -> Optional[tuple[FunctionEntry, Type]]: + """Look up a function within a specific URN.""" + matches = self._find_matching_functions(function_name, signature, [urn]) + return matches[0] if matches else None + + def list_functions( + self, urn: str, function_name: str, signature: tuple[Type] | list[Type] + ) -> list[tuple[FunctionEntry, Type]]: + """List all matching functions within a specific URN.""" + return self._find_matching_functions(function_name, signature, [urn]) + + def list_functions_across_urns( + self, function_name: str, signature: tuple[Type] | list[Type] + ) -> list[tuple[FunctionEntry, Type]]: + """List all matching functions across all URNs.""" + return self._find_matching_functions(function_name, signature) + + def lookup_urn(self, urn: str) -> Optional[int]: + return self._urn_mapping.get(urn, None) + + def lookup_uri_anchor(self, uri: str) -> Optional[int]: + """Look up the anchor ID for a URI. + During the migration period, URI and URN share the same anchor. + This method converts the URI to its URN and returns the URN's anchor. + Args: + uri: The extension URI to look up + Returns: + The anchor ID for the URI (same as its corresponding URN), or None if not found + """ + urn = self._uri_urn_bimap.get_urn(uri) + if urn: + return self._urn_mapping.get(urn) + return None + + +def validate_urn_format(urn: str) -> str: + """Validate that a URN follows the expected format. + Expected format: extension:: + Example: extension:io.substrait:functions_arithmetic + Args: + urn: The URN to validate + Raises: + ValueError: If the URN format is invalid + """ + if not URN_PATTERN.match(urn): + raise ValueError( + f"Invalid URN format: '{urn}'. 
" + f"Expected format: extension:: " + f"(e.g., 'extension:io.substrait:functions_arithmetic')" + ) + return urn diff --git a/src/substrait/extension_registry/signature_checker_helpers.py b/src/substrait/extension_registry/signature_checker_helpers.py new file mode 100644 index 0000000..652e2b6 --- /dev/null +++ b/src/substrait/extension_registry/signature_checker_helpers.py @@ -0,0 +1,353 @@ +"""Helper functions for extension registry.""" + +from typing import Dict + +from substrait.derivation_expression import _evaluate +from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser +from substrait.gen.proto.type_pb2 import Type +from .exceptions import UnrecognizedSubstraitTypeError, UnhandledParameterizedTypeError + +# Type aliases +TypeParameterMapping = Dict[str, object] + + +# mapping from argument types to shortened signature names: https://substrait.io/extensions/#function-signature-compound-names +_normalized_key_names = { + "i8": "i8", + "i16": "i16", + "i32": "i32", + "i64": "i64", + "fp32": "fp32", + "fp64": "fp64", + "string": "str", + "binary": "vbin", + "boolean": "bool", + "timestamp": "ts", + "timestamp_tz": "tstz", + "date": "date", + "time": "time", + "interval_year": "iyear", + "interval_day": "iday", + "interval_compound": "icompound", + "uuid": "uuid", + "fixedchar": "fchar", + "varchar": "vchar", + "fixedbinary": "fbin", + "decimal": "dec", + "precision_time": "pt", + "precision_timestamp": "pts", + "precision_timestamp_tz": "ptstz", + "struct": "struct", + "list": "list", + "map": "map", +} + + +def normalize_substrait_type_names(typ: str) -> str: + """Normalize Substrait type names to their canonical short forms. 
+ + Args: + typ: The type string to normalize (may include type specifiers and nullability markers) + + Returns: + The normalized type name + + Raises: + UnrecognizedSubstraitTypeError: If the type is not recognized + """ + # Strip type specifiers + typ = typ.split("<")[0] + # First strip nullability marker + typ = typ.strip("?").lower() + if typ.startswith("any"): + return "any" + elif typ.startswith("u!"): + return typ + elif typ in _normalized_key_names: + return _normalized_key_names[typ] + else: + raise UnrecognizedSubstraitTypeError(f"Unrecognized substrait type {typ}") + + +def _check_integer_constraint( + actual: int, constraint, parameters: TypeParameterMapping, subset: bool = False +) -> bool: + """Check if an actual integer value matches a constraint. + + Args: + actual: The actual integer value to check + constraint: An ANTLR context for either a numeric literal or parameter name + parameters: Mapping of parameter names to their resolved values + subset: If True, checks if actual < constraint (for subset relationships). + If False, checks if actual == constraint (for exact match). 
+ + Returns: + True if the constraint is satisfied, False otherwise + + Raises: + TypeError: If constraint is not NumericLiteralContext or NumericParameterNameContext + """ + constraint_numeric: int | None = None + if isinstance(constraint, SubstraitTypeParser.NumericLiteralContext): + constraint_numeric = int(str(constraint.Number())) + elif isinstance(constraint, SubstraitTypeParser.NumericParameterNameContext): + parameter_name = str(constraint.Identifier()) + if parameter_name not in parameters: + parameters[parameter_name] = actual + if isinstance(parameters[parameter_name], int): + constraint_numeric = parameters[parameter_name] # type:ignore + else: + raise TypeError( + f"Constraint must be either NumericLiteralContext or NumericParameterNameContext, " + f"got {type(constraint).__name__} instead" + ) + if constraint_numeric is None: + return False + elif subset: + return actual < constraint_numeric + else: + return actual == constraint_numeric + + +def types_equal(type1: Type, type2: Type, check_nullability=False): + if check_nullability: + return type1 == type2 + else: + x, y = Type(), Type() + x.CopyFrom(type1) + y.CopyFrom(type2) + x.__getattribute__( + x.WhichOneof("kind") # type:ignore + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED + y.__getattribute__( + y.WhichOneof("kind") # type:ignore + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED + return x == y + + +def _bind_type_parameter( + covered: Type, + parameter_name: str, + parameters: TypeParameterMapping, + check_nullability: bool, +) -> bool: + """Bind a type parameter to a concrete type or verify consistency. + + If the parameter is already bound, verify that the new type is consistent with the + previously bound type. Otherwise, bind the parameter to this concrete type. 
+ + Args: + covered: The concrete type to bind or verify against + parameter_name: The name of the type parameter (e.g., 'T', 'L') + parameters: Mapping of parameter names to their resolved types + check_nullability: Whether to consider nullability when comparing types + + Returns: + True if the parameter is successfully bound or verified, False if types are inconsistent + """ + if parameter_name in parameters: + bound_type = parameters[parameter_name] + return types_equal(bound_type, covered, check_nullability) + else: + parameters[parameter_name] = covered + return True + + +def _nullability_matches( + check_nullability: bool, parameterized_type, covered: Type, kind: str +) -> bool: + """Check if nullability constraints are satisfied. + + When check_nullability is False, any nullability is acceptable. + + When check_nullability is True, the nullability declared in the ANTLR parameterized type + (via the ``isnull`` token) must match the nullability of the covered type's protobuf enum. + + Args: + check_nullability: If False, return True immediately (no constraint checking) + parameterized_type: ANTLR context that may have an ``isnull`` token attribute + covered: The protobuf Type message to check + kind: The field name on the covered type (e.g., 'varchar', 'list') + + Returns: + True if nullability constraints are satisfied, False otherwise + """ + if not check_nullability: + return True + + # The ANTLR context stores a Token called ``isnull`` – it is + # present when the type is declared as nullable. + parameterized_nullability = ( + Type.Nullability.NULLABILITY_NULLABLE + if getattr(parameterized_type, "isnull", None) is not None + else Type.Nullability.NULLABILITY_REQUIRED + ) + + # The protobuf message stores its own enum – we compare the two. + covered_nullability = getattr( + getattr(covered, kind), # e.g. 
def check_integer_type_parameters(covered, parameterized_type, attributes, parameters):
    """Check each named integer attribute of ``covered`` against its constraint.

    For every name in ``attributes`` the value read from ``covered`` and the
    value read from ``parameterized_type`` are passed, together with
    ``parameters``, to ``_check_integer_constraint``.

    Args:
        covered: Object holding the concrete attribute values.
        parameterized_type: Object holding the matching constraint for each attribute.
        attributes: Names of the attributes to check, in order.
        parameters: Parameter mapping forwarded to ``_check_integer_constraint``.

    Returns:
        False as soon as one constraint check fails. True if every check
        passes, and also True immediately when an attribute is missing on
        both objects (later attributes are then not examined).
    """
    for name in attributes:
        present_somewhere = hasattr(covered, name) or hasattr(parameterized_type, name)
        if not present_somewhere:
            # Neither side knows this attribute: treated as a match right away.
            return True
        actual_value = getattr(covered, name)
        constraint = getattr(parameterized_type, name)
        satisfied = _check_integer_constraint(actual_value, constraint, parameters)
        if not satisfied:
            return False
    return True
def covers(
    covered: Type,
    covering: SubstraitTypeParser.TypeLiteralContext,
    parameters: TypeParameterMapping,
    check_nullability: bool = False,
) -> bool:
    """Decide whether a concrete type satisfies a parsed type signature.

    The signature is dispatched on its parsed form: a bare parameter name, an
    ``any`` type (with or without a variable), a scalar type, or a
    parameterized type.

    Args:
        covered: The concrete type being checked.
        covering: The parsed type signature to check against.
        parameters: Mapping of type parameter names to the types already bound
            to them; passed on to the binding and parameterized-type helpers.
        check_nullability: Forwarded to the helpers that compare types.

    Returns:
        The result of the helper matching the signature's form, True for an
        ``any`` type without a variable, and False when the signature is none
        of the handled forms.
    """
    # A bare parameter name: bind it, or verify it against an earlier binding.
    if isinstance(covering, SubstraitTypeParser.ParameterNameContext):
        return _bind_type_parameter(
            covered, str(covering.Identifier()), parameters, check_nullability
        )

    type_def: SubstraitTypeParser.TypeDefContext = covering.typeDef()  # type:ignore

    # "any" types: a plain any accepts everything, a named any variable is
    # bound like a type parameter.
    any_ctx: SubstraitTypeParser.AnyTypeContext = type_def.anyType()  # type:ignore
    if any_ctx:
        any_var = any_ctx.AnyVar()
        if not any_var:
            return True
        return _bind_type_parameter(
            covered,
            any_var.symbol.text,  # type:ignore
            parameters,
            check_nullability,
        )

    # Scalar types: evaluate the signature and compare it with the concrete type.
    if type_def.scalarType():
        resolved = _evaluate(type_def, {})
        return types_equal(resolved, covered, check_nullability)

    # Parameterized types are delegated to the isinstance-based handler.
    parameterized = type_def.parameterizedType()
    if not parameterized:
        return False
    return _handle_parameterized_type(
        parameterized, covered, parameters, check_nullability
    )
i8(nullable=False)], nullable=False) param_ctx = _parse("struct") assert not covers(covered, param_ctx, {}) + + + +@pytest.mark.parametrize( + "test_case", + [ + # Scalar functions + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:scalar_funcs + scalar_functions: + - name: "add" + description: "Add two numbers" + impls: + - args: + - value: i8 + - value: i8 + return: i8 + """), + "urn": "extension:test:scalar_funcs", + "func_name": "add", + "signature": [i8(nullable=False), i8(nullable=False)], + "expected_type": "scalar", + }, + id="scalar-add", + ), + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:scalar_funcs + scalar_functions: + - name: "test_fn" + description: "" + impls: + - args: + - value: i8 + variadic: + min: 2 + return: i8 + """), + "urn": "extension:test:scalar_funcs", + "func_name": "test_fn", + "signature": [i8(nullable=False), i8(nullable=False)], + "expected_type": "scalar", + }, + id="scalar-test_fn", + ), + # Aggregate functions + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:agg_funcs + aggregate_functions: + - name: "count" + description: "Count non-null values" + impls: + - args: + - value: i8 + return: i64 + """), + "urn": "extension:test:agg_funcs", + "func_name": "count", + "signature": [i8(nullable=False)], + "expected_type": "aggregate", + }, + id="aggregate-count", + ), + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:agg_funcs + aggregate_functions: + - name: "sum" + description: "Sum values" + impls: + - args: + - value: i8 + return: i64 + """), + "urn": "extension:test:agg_funcs", + "func_name": "sum", + "signature": [i8(nullable=False)], + "expected_type": "aggregate", + }, + id="aggregate-sum", + ), + # Window functions + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:window_funcs + 
window_functions: + - name: "row_number" + description: "Assign row numbers" + impls: + - args: [] + return: i64 + """), + "urn": "extension:test:window_funcs", + "func_name": "row_number", + "signature": [], + "expected_type": "window", + }, + id="window-row_number", + ), + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:window_funcs + window_functions: + - name: "rank" + description: "Assign ranks" + impls: + - args: [] + return: i64 + """), + "urn": "extension:test:window_funcs", + "func_name": "rank", + "signature": [], + "expected_type": "window", + }, + id="window-rank", + ), + ], +) +def test_all_function_types_from_yaml(test_case): + """Test that all functions in YAML are registered with correct function_type.value.""" + test_registry = ExtensionRegistry(load_default_extensions=False) + test_registry.register_extension_dict( + yaml.safe_load(test_case["yaml_content"]), + uri=f"https://test.example.com/{test_case['urn'].replace(':', '_')}.yaml", + ) + + result = test_registry.lookup_function( + urn=test_case["urn"], + function_name=test_case["func_name"], + signature=test_case["signature"], + ) + assert result is not None, ( + f"Failed to lookup {test_case['func_name']} in {test_case['urn']}" + ) + entry, _ = result + assert hasattr(entry, "function_type"), ( + f"Entry for {test_case['func_name']} missing function_type attribute" + ) + assert entry.function_type is not None, ( + f"function_type is None for {test_case['func_name']}" + ) + assert isinstance(entry.function_type.value, str), ( + f"function_type.value is not a string for {test_case['func_name']}" + ) + assert entry.function_type.value == test_case["expected_type"], ( + f"Expected function_type.value '{test_case['expected_type']}' " + f"for {test_case['func_name']}, got '{entry.function_type.value}'" + ) diff --git a/uv.lock b/uv.lock index d433f17..717676d 100644 --- a/uv.lock +++ b/uv.lock @@ -659,6 +659,7 @@ name = "substrait" source = { editable 
= "." } dependencies = [ { name = "protobuf" }, + { name = "typing-extensions" }, ] [package.optional-dependencies]