Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["protobuf >=3.19.1,<6"]
dependencies = ["protobuf >=3.19.1,<6", "typing_extensions"]
dynamic = ["version"]

[tool.setuptools_scm]
Expand Down
13 changes: 10 additions & 3 deletions src/substrait/builders/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
See `examples/builder_example.py` for usage.
"""

from typing import Iterable, Optional, Union, Callable

from typing import Callable, Iterable, Optional, TypedDict, Union
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
import substrait.gen.proto.plan_pb2 as stp
Expand All @@ -20,16 +20,23 @@
from substrait.type_inference import infer_plan_schema
from substrait.utils import (
merge_extension_declarations,
merge_extension_urns,
merge_extension_uris,
merge_extension_urns,
)

# A callable that, given an ExtensionRegistry, produces a concrete stp.Plan.
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]

# Either an already-built stp.Plan or an UnboundPlan callable.
PlanOrUnbound = Union[stp.Plan, UnboundPlan]

# Dictionary shape used as the return annotation of `_merge_extensions`:
# three lists stored under the keys "extension_uris", "extension_urns" and
# "extensions". The element types of the lists are not constrained here.
_ExtensionDict = TypedDict(
"_ExtensionDict",
{"extension_uris": list, "extension_urns": list, "extensions": list},
)


def _merge_extensions(*objs):
def _merge_extensions(
*objs,
) -> _ExtensionDict:
"""Merge extension URIs, URNs, and declarations from multiple plan/expression objects.

During the URI -> URN migration period, we maintain both URI and URN references
Expand Down
8 changes: 8 additions & 0 deletions src/substrait/builders/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type:
)
)

def timestamp(nullable=True) -> stt.Type:
    """Build a Substrait ``timestamp`` type.

    Args:
        nullable: When truthy (the default) the type is marked
            ``NULLABILITY_NULLABLE``; otherwise ``NULLABILITY_REQUIRED``.

    Returns:
        An ``stt.Type`` whose ``timestamp`` field holds an
        ``stt.Type.Timestamp`` carrying the selected nullability.
    """
    if nullable:
        nullability = stt.Type.NULLABILITY_NULLABLE
    else:
        nullability = stt.Type.NULLABILITY_REQUIRED

    timestamp_type = stt.Type.Timestamp(nullability=nullability)
    return stt.Type(timestamp=timestamp_type)

def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type:
return stt.Type(
Expand Down
119 changes: 111 additions & 8 deletions src/substrait/derivation_expression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional
from antlr4 import InputStream, CommonTokenStream

from antlr4 import CommonTokenStream, InputStream

from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.proto.type_pb2 import Type
Expand All @@ -9,8 +11,9 @@ def _evaluate(x, values: dict):
if isinstance(x, SubstraitTypeParser.BinaryExprContext):
left = _evaluate(x.left, values)
right = _evaluate(x.right, values)

if x.op.text == "+":
if x.op is None:
raise Exception("Undefined operator op")
elif x.op.text == "+":
return left + right
elif x.op.text == "-":
return left - right
Expand Down Expand Up @@ -65,22 +68,122 @@ def _evaluate(x, values: dict):
return Type(fp64=Type.FP64(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
return Type(bool=Type.Boolean(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.StringContext):
return Type(string=Type.String(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext):
return Type(timestamp=Type.Timestamp(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.DateContext):
return Type(date=Type.Date(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext):
return Type(interval_year=Type.IntervalYear(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.UuidContext):
return Type(uuid=Type.UUID(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext):
return Type(binary=Type.Binary(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimeContext):
return Type(time=Type.Time(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext):
return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability))
else:
raise Exception(f"Unknown scalar type {type(scalar_type)}")
elif parametrized_type:
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
precision = _evaluate(parametrized_type.precision, values)
scale = _evaluate(parametrized_type.scale, values)
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
return Type(
decimal=Type.Decimal(
precision=precision, scale=scale, nullability=nullability
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
varchar=Type.VarChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_char=Type.FixedChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_binary=Type.FixedBinary(
length=length,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp=Type.PrecisionTimestamp(
precision=precision,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp_tz=Type.PrecisionTimestampTZ(
precision=precision,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext):
return Type(
interval_year=Type.IntervalYear(
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.StructContext):
types = list(
map(lambda x: _evaluate(x, values), parametrized_type.expr())
)
return Type(
struct=Type.Struct(
types=types,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.ListContext):
child_type = _evaluate(parametrized_type.expr(), values)
return Type(
list=Type.List(
type=child_type,
nullability=nullability,
)
)

elif isinstance(parametrized_type, SubstraitTypeParser.MapContext):
return Type(
map=Type.Map(
key=_evaluate(parametrized_type.key, values),
value=_evaluate(parametrized_type.value, values),
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext):
# TODO: parsing the documented example `evaluate("NSTRUCT<longitude: i32, latitude: i32>")`
# (from https://substrait.io/types/type_classes/) currently fails with the parser error
# "line 1:17 extraneous input ':'"; the parser may need to be updated before this can be supported.
raise NotImplementedError("Named structure type not implemented yet")
# elif isinstance(parametrized_type, SubstraitTypeParser.UserDefinedContext):

raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
elif any_type:
any_var = any_type.AnyVar()
Expand Down
Loading
Loading