Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ RUN cd ~ && curl -LO https://github.com/protocolbuffers/protobuf/releases/downlo
unzip protoc-25.1-linux-x86_64.zip -d ~/.local && \
rm protoc-25.1-linux-x86_64.zip
RUN curl -sSL "https://github.com/bufbuild/buf/releases/download/v1.50.0/buf-$(uname -s)-$(uname -m)" -o ~/.local/bin/buf && chmod +x ~/.local/bin/buf
RUN curl -LsSf https://astral.sh/uv/0.7.11/install.sh | sh
USER root
13 changes: 13 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,16 @@ antlr:
&& java -jar ${ANTLR_JAR} -o ../../../src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 \
&& rm ../../../src/substrait/gen/antlr/*.tokens \
&& rm ../../../src/substrait/gen/antlr/*.interp

codegen-extensions:
uv run --with datamodel-code-generator datamodel-codegen \
--input-file-type jsonschema \
--input third_party/substrait/text/simple_extensions_schema.yaml \
--output src/substrait/gen/json/simple_extensions.py \
--output-model-type dataclasses.dataclass

lint:
uvx ruff@0.11.11 check

format:
uvx ruff@0.11.11 format
73 changes: 45 additions & 28 deletions src/substrait/extension_registry.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import yaml
import itertools
from substrait.gen.proto.type_pb2 import Type
from importlib.resources import files as importlib_files
import itertools
from collections import defaultdict
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional, Union
from typing import Optional, Union
from .derivation_expression import evaluate, _evaluate, _parse

import yaml
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.json import simple_extensions as se
from substrait.simple_extension_utils import build_simple_extensions


DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions"

Expand Down Expand Up @@ -166,31 +167,35 @@ def covers(

class FunctionEntry:
def __init__(
self, uri: str, name: str, impl: Mapping[str, Any], anchor: int
self, uri: str, name: str, impl: Union[se.Impl, se.Impl1, se.Impl2], anchor: int
) -> None:
self.name = name
self.impl = impl
self.normalized_inputs: list = []
self.uri: str = uri
self.anchor = anchor
self.arguments = []
self.rtn = impl["return"]
self.nullability = impl.get("nullability", "MIRROR")
self.variadic = impl.get("variadic", False)
if input_args := impl.get("args", []):
for val in input_args:
if typ := val.get("value"):
self.arguments.append(_parse(typ))
self.normalized_inputs.append(normalize_substrait_type_names(typ))
elif _ := val.get("name", None):
self.arguments.append(val.get("options"))
self.nullability = (
impl.nullability if impl.nullability else se.NullabilityHandling.MIRROR
)

if impl.args:
for arg in impl.args:
if isinstance(arg, se.ValueArg):
self.arguments.append(_parse(arg.value))
self.normalized_inputs.append(
normalize_substrait_type_names(arg.value)
)
elif isinstance(arg, se.EnumerationArg):
self.arguments.append(arg.options)
self.normalized_inputs.append("req")

def __repr__(self) -> str:
return f"{self.name}:{'_'.join(self.normalized_inputs)}"

def satisfies_signature(self, signature: tuple) -> Optional[str]:
if self.variadic:
min_args_allowed = self.variadic.get("min", 0)
if self.impl.variadic:
min_args_allowed = self.impl.variadic.min or 0
if len(signature) < min_args_allowed:
return None
inputs = [self.arguments[0]] * len(signature)
Expand All @@ -209,13 +214,17 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
return None
else:
if not covers(
y, x, parameters, check_nullability=self.nullability == "DISCRETE"
y,
x,
parameters,
check_nullability=self.nullability
== se.NullabilityHandling.DISCRETE,
):
return None

output_type = evaluate(self.rtn, parameters)
output_type = evaluate(self.impl.return_, parameters)

if self.nullability == "MIRROR":
if self.nullability == se.NullabilityHandling.MIRROR:
sig_contains_nullable = any(
[
p.__getattribute__(p.WhichOneof("kind")).nullability
Expand Down Expand Up @@ -265,19 +274,27 @@ def register_extension_yaml(
def register_extension_dict(self, definitions: dict, uri: str) -> None:
self._uri_mapping[uri] = next(self._uri_id_generator)

for named_functions in definitions.values():
for function in named_functions:
for impl in function.get("impls", []):
simple_extensions = build_simple_extensions(definitions)

functions = (
(simple_extensions.scalar_functions or [])
+ (simple_extensions.aggregate_functions or [])
+ (simple_extensions.window_functions or [])
)

if functions:
for function in functions:
for impl in function.impls:
func = FunctionEntry(
uri, function["name"], impl, next(self._id_generator)
uri, function.name, impl, next(self._id_generator)
)
if (
func.uri in self._function_mapping
and function["name"] in self._function_mapping[func.uri]
and function.name in self._function_mapping[func.uri]
):
self._function_mapping[func.uri][function["name"]].append(func)
self._function_mapping[func.uri][function.name].append(func)
else:
self._function_mapping[func.uri][function["name"]] = [func]
self._function_mapping[func.uri][function.name] = [func]

# TODO add an optional return type check
def lookup_function(
Expand Down
218 changes: 218 additions & 0 deletions src/substrait/gen/json/simple_extensions.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading