Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/substrait/extension_registry/signature_checker_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def _handle_parameterized_type(
covered.decimal, parameterized_type, ["scale", "precision"], parameters
)

if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimeContext):
return kind == "precision_time" and check_integer_type_parameters(
covered.precision_time,
parameterized_type,
["precision"],
parameters,
)

if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):
return kind == "precision_timestamp" and check_integer_type_parameters(
covered.precision_timestamp,
Expand All @@ -251,6 +259,14 @@ def _handle_parameterized_type(
parameters,
)

if isinstance(parameterized_type, SubstraitTypeParser.PrecisionIntervalDayContext):
return kind == "interval_day" and check_integer_type_parameters(
covered.interval_day,
parameterized_type,
["precision"],
parameters,
)

if isinstance(parameterized_type, SubstraitTypeParser.ListContext):
return kind == "list" and covers(
covered.list.type,
Expand Down
Empty file.
128 changes: 128 additions & 0 deletions tests/extension_registry/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for parsing of a registry yaml and basic registry operations (lookup, registration)."""

import pytest
import yaml

from substrait.extension_registry import ExtensionRegistry

# Common test YAML content for testing basic functions
CONTENT = """%YAML 1.2
---
urn: extension:test:functions
scalar_functions:
- name: "test_fn"
description: ""
impls:
- args:
- value: i8
variadic:
min: 2
return: i8
- name: "test_fn_variadic_any"
description: ""
impls:
- args:
- value: any1
variadic:
min: 2
return: any1
- name: "add"
description: "Add two values."
impls:
- args:
- name: x
value: i8
- name: y
value: i8
options:
overflow:
values: [ SILENT, SATURATE, ERROR ]
return: i8
- args:
- name: x
value: i8
- name: y
value: i8
- name: z
value: any
options:
overflow:
values: [ SILENT, SATURATE, ERROR ]
return: i16
- args:
- name: x
value: any1
- name: y
value: any1
- name: z
value: any2
options:
overflow:
values: [ SILENT, SATURATE, ERROR ]
return: any2
- name: "test_decimal"
impls:
- args:
- name: x
value: decimal<P1,S1>
- name: y
value: decimal<S1,S2>
return: decimal<P1 + 1,S2 + 1>
- name: "test_enum"
impls:
- args:
- name: op
options: [ INTACT, FLIP ]
- name: x
value: i8
return: i8
- name: "add_declared"
description: "Add two values."
impls:
- args:
- name: x
value: i8
- name: y
value: i8
nullability: DECLARED_OUTPUT
return: i8?
- name: "add_discrete"
description: "Add two values."
impls:
- args:
- name: x
value: i8?
- name: y
value: i8
nullability: DISCRETE
return: i8?
- name: "test_decimal_discrete"
impls:
- args:
- name: x
value: decimal?<P1,S1>
- name: y
value: decimal<S1,S2>
nullability: DISCRETE
return: decimal?<P1 + 1,S2 + 1>
- name: "equal_test"
impls:
- args:
- name: x
value: any
- name: y
value: any
nullability: DISCRETE
return: any
"""


@pytest.fixture(scope="session")
def registry():
"""Create a registry with test functions loaded."""
reg = ExtensionRegistry(load_default_extensions=True)
reg.register_extension_dict(
yaml.safe_load(CONTENT),
uri="https://test.example.com/extension_test_functions.yaml",
)
return reg
176 changes: 176 additions & 0 deletions tests/extension_registry/test_function_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Tests for function types (scalar, aggregate, window)."""

import textwrap

import pytest
import yaml

from substrait.builders.type import i8
from substrait.extension_registry import ExtensionRegistry


@pytest.mark.parametrize(
"test_case",
[
# Scalar functions
pytest.param(
{
"yaml_content": textwrap.dedent("""\
%YAML 1.2
---
urn: extension:test:scalar_funcs
scalar_functions:
- name: "add"
description: "Add two numbers"
impls:
- args:
- value: i8
- value: i8
return: i8
"""),
"urn": "extension:test:scalar_funcs",
"func_name": "add",
"signature": [i8(nullable=False), i8(nullable=False)],
"expected_type": "scalar",
},
id="scalar-add",
),
pytest.param(
{
"yaml_content": textwrap.dedent("""\
%YAML 1.2
---
urn: extension:test:scalar_funcs
scalar_functions:
- name: "test_fn"
description: ""
impls:
- args:
- value: i8
variadic:
min: 2
return: i8
"""),
"urn": "extension:test:scalar_funcs",
"func_name": "test_fn",
"signature": [i8(nullable=False), i8(nullable=False)],
"expected_type": "scalar",
},
id="scalar-test_fn",
),
# Aggregate functions
pytest.param(
{
"yaml_content": textwrap.dedent("""\
%YAML 1.2
---
urn: extension:test:agg_funcs
aggregate_functions:
- name: "count"
description: "Count non-null values"
impls:
- args:
- value: i8
return: i64
"""),
"urn": "extension:test:agg_funcs",
"func_name": "count",
"signature": [i8(nullable=False)],
"expected_type": "aggregate",
},
id="aggregate-count",
),
pytest.param(
{
"yaml_content": textwrap.dedent("""\
%YAML 1.2
---
urn: extension:test:agg_funcs
aggregate_functions:
- name: "sum"
description: "Sum values"
impls:
- args:
- value: i8
return: i64
"""),
"urn": "extension:test:agg_funcs",
"func_name": "sum",
"signature": [i8(nullable=False)],
"expected_type": "aggregate",
},
id="aggregate-sum",
),
# Window functions
pytest.param(
{
"yaml_content": textwrap.dedent("""\
%YAML 1.2
---
urn: extension:test:window_funcs
window_functions:
- name: "row_number"
description: "Assign row numbers"
impls:
- args: []
return: i64
"""),
"urn": "extension:test:window_funcs",
"func_name": "row_number",
"signature": [],
"expected_type": "window",
},
id="window-row_number",
),
pytest.param(
{
"yaml_content": textwrap.dedent("""\
%YAML 1.2
---
urn: extension:test:window_funcs
window_functions:
- name: "rank"
description: "Assign ranks"
impls:
- args: []
return: i64
"""),
"urn": "extension:test:window_funcs",
"func_name": "rank",
"signature": [],
"expected_type": "window",
},
id="window-rank",
),
],
)
def test_all_function_types_from_yaml(test_case):
"""Test that all functions in YAML are registered with correct function_type.value."""
test_registry = ExtensionRegistry(load_default_extensions=False)
test_registry.register_extension_dict(
yaml.safe_load(test_case["yaml_content"]),
uri=f"https://test.example.com/{test_case['urn'].replace(':', '_')}.yaml",
)

result = test_registry.lookup_function(
urn=test_case["urn"],
function_name=test_case["func_name"],
signature=test_case["signature"],
)
assert result is not None, (
f"Failed to lookup {test_case['func_name']} in {test_case['urn']}"
)
entry, _ = result
assert hasattr(entry, "function_type"), (
f"Entry for {test_case['func_name']} missing function_type attribute"
)
assert entry.function_type is not None, (
f"function_type is None for {test_case['func_name']}"
)
assert isinstance(entry.function_type.value, str), (
f"function_type.value is not a string for {test_case['func_name']}"
)
assert entry.function_type.value == test_case["expected_type"], (
f"Expected function_type.value '{test_case['expected_type']}' "
f"for {test_case['func_name']}, got '{entry.function_type.value}'"
)
Loading