diff --git a/pyproject.toml b/pyproject.toml index 677ef57..c1ddd51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.10" -dependencies = ["protobuf >=3.19.1,<6"] +dependencies = ["protobuf >=3.19.1,<6", "typing_extensions"] dynamic = ["version"] [tool.setuptools_scm] diff --git a/src/substrait/gen/json/simple_extensions.py b/src/substrait/gen/json/simple_extensions.py index 2885bb4..e323ac5 100644 --- a/src/substrait/gen/json/simple_extensions.py +++ b/src/substrait/gen/json/simple_extensions.py @@ -7,13 +7,15 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union +from typing_extensions import TypeAlias + class Functions(Enum): INHERITS = 'INHERITS' SEPARATE = 'SEPARATE' -Type = Union[str, Dict[str, Any]] +Type: TypeAlias = Union[str, Dict[str, Any]] class Type1(Enum): @@ -24,7 +26,7 @@ class Type1(Enum): string = 'string' -EnumOptions = List[str] +EnumOptions: TypeAlias = List[str] @dataclass @@ -49,7 +51,7 @@ class TypeArg: description: Optional[str] = None -Arguments = List[Union[EnumerationArg, ValueArg, TypeArg]] +Arguments: TypeAlias = List[Union[EnumerationArg, ValueArg, TypeArg]] @dataclass @@ -58,7 +60,7 @@ class Options1: description: Optional[str] = None -Options = Dict[str, Options1] +Options: TypeAlias = Dict[str, Options1] class ParameterConsistency(Enum): @@ -73,10 +75,10 @@ class VariadicBehavior: parameterConsistency: Optional[ParameterConsistency] = None -Deterministic = bool +Deterministic: TypeAlias = bool -SessionDependent = bool +SessionDependent: TypeAlias = bool class NullabilityHandling(Enum): @@ -85,13 +87,13 @@ class NullabilityHandling(Enum): DISCRETE = 'DISCRETE' -ReturnValue = Type +ReturnValue: TypeAlias = Type -Implementation = Dict[str, str] +Implementation: TypeAlias = Dict[str, str] -Intermediate = Type +Intermediate: TypeAlias = Type class Decomposable(Enum): @@ -100,10 +102,10 @@ class Decomposable(Enum): MANY = 'MANY' -Maxset = float +Maxset: TypeAlias = float -Ordered = bool +Ordered: TypeAlias = bool @dataclass @@ -196,7 +198,7 @@ class TypeParamDef: optional: Optional[bool] = None -TypeParamDefs = List[TypeParamDef] +TypeParamDefs: TypeAlias = List[TypeParamDef] @dataclass diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py index d6a68e8..1d01b1c 100644 --- a/src/substrait/type_inference.py +++ b/src/substrait/type_inference.py @@ -191,9 +191,11 @@ def infer_expression_type( elif rex_type == "window_function": return expression.window_function.output_type elif rex_type == "if_then": - return infer_expression_type(expression.if_then.ifs[0].then) + return infer_expression_type(expression.if_then.ifs[0].then, parent_schema) elif rex_type == "switch_expression": - return infer_expression_type(expression.switch_expression.ifs[0].then) + return infer_expression_type( + expression.switch_expression.ifs[0].then, parent_schema + ) elif rex_type == "cast": return expression.cast.type elif rex_type == "singular_or_list" or rex_type == "multi_or_list":