diff --git a/src/substrait/extension_registry/signature_checker_helpers.py b/src/substrait/extension_registry/signature_checker_helpers.py index 652e2b6..2e95085 100644 --- a/src/substrait/extension_registry/signature_checker_helpers.py +++ b/src/substrait/extension_registry/signature_checker_helpers.py @@ -235,6 +235,14 @@ def _handle_parameterized_type( covered.decimal, parameterized_type, ["scale", "precision"], parameters ) + if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimeContext): + return kind == "precision_time" and check_integer_type_parameters( + covered.precision_time, + parameterized_type, + ["precision"], + parameters, + ) + if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext): return kind == "precision_timestamp" and check_integer_type_parameters( covered.precision_timestamp, @@ -251,6 +259,14 @@ def _handle_parameterized_type( parameters, ) + if isinstance(parameterized_type, SubstraitTypeParser.PrecisionIntervalDayContext): + return kind == "interval_day" and check_integer_type_parameters( + covered.interval_day, + parameterized_type, + ["precision"], + parameters, + ) + if isinstance(parameterized_type, SubstraitTypeParser.ListContext): return kind == "list" and covers( covered.list.type, diff --git a/tests/extension_registry/__init__.py b/tests/extension_registry/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/extension_registry/conftest.py b/tests/extension_registry/conftest.py new file mode 100644 index 0000000..b6df3cc --- /dev/null +++ b/tests/extension_registry/conftest.py @@ -0,0 +1,128 @@ +"""Tests for parsing of a registry yaml and basic registry operations (lookup, registration).""" + +import pytest +import yaml + +from substrait.extension_registry import ExtensionRegistry + +# Common test YAML content for testing basic functions +CONTENT = """%YAML 1.2 +--- +urn: extension:test:functions +scalar_functions: + - name: "test_fn" + description: "" + impls: + - args: + 
- value: i8 + variadic: + min: 2 + return: i8 + - name: "test_fn_variadic_any" + description: "" + impls: + - args: + - value: any1 + variadic: + min: 2 + return: any1 + - name: "add" + description: "Add two values." + impls: + - args: + - name: x + value: i8 + - name: y + value: i8 + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: i8 + - args: + - name: x + value: i8 + - name: y + value: i8 + - name: z + value: any + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: i16 + - args: + - name: x + value: any1 + - name: y + value: any1 + - name: z + value: any2 + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: any2 + - name: "test_decimal" + impls: + - args: + - name: x + value: decimal + - name: y + value: decimal + return: decimal + - name: "test_enum" + impls: + - args: + - name: op + options: [ INTACT, FLIP ] + - name: x + value: i8 + return: i8 + - name: "add_declared" + description: "Add two values." + impls: + - args: + - name: x + value: i8 + - name: y + value: i8 + nullability: DECLARED_OUTPUT + return: i8? + - name: "add_discrete" + description: "Add two values." + impls: + - args: + - name: x + value: i8? + - name: y + value: i8 + nullability: DISCRETE + return: i8? + - name: "test_decimal_discrete" + impls: + - args: + - name: x + value: decimal? + - name: y + value: decimal + nullability: DISCRETE + return: decimal? 
+ - name: "equal_test" + impls: + - args: + - name: x + value: any + - name: y + value: any + nullability: DISCRETE + return: any +""" + + +@pytest.fixture(scope="session") +def registry(): + """Create a registry with test functions loaded.""" + reg = ExtensionRegistry(load_default_extensions=True) + reg.register_extension_dict( + yaml.safe_load(CONTENT), + uri="https://test.example.com/extension_test_functions.yaml", + ) + return reg diff --git a/tests/extension_registry/test_function_types.py b/tests/extension_registry/test_function_types.py new file mode 100644 index 0000000..d4c3f3a --- /dev/null +++ b/tests/extension_registry/test_function_types.py @@ -0,0 +1,176 @@ +"""Tests for function types (scalar, aggregate, window).""" + +import textwrap + +import pytest +import yaml + +from substrait.builders.type import i8 +from substrait.extension_registry import ExtensionRegistry + + +@pytest.mark.parametrize( + "test_case", + [ + # Scalar functions + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:scalar_funcs + scalar_functions: + - name: "add" + description: "Add two numbers" + impls: + - args: + - value: i8 + - value: i8 + return: i8 + """), + "urn": "extension:test:scalar_funcs", + "func_name": "add", + "signature": [i8(nullable=False), i8(nullable=False)], + "expected_type": "scalar", + }, + id="scalar-add", + ), + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:scalar_funcs + scalar_functions: + - name: "test_fn" + description: "" + impls: + - args: + - value: i8 + variadic: + min: 2 + return: i8 + """), + "urn": "extension:test:scalar_funcs", + "func_name": "test_fn", + "signature": [i8(nullable=False), i8(nullable=False)], + "expected_type": "scalar", + }, + id="scalar-test_fn", + ), + # Aggregate functions + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:agg_funcs + aggregate_functions: + - name: "count" + 
description: "Count non-null values" + impls: + - args: + - value: i8 + return: i64 + """), + "urn": "extension:test:agg_funcs", + "func_name": "count", + "signature": [i8(nullable=False)], + "expected_type": "aggregate", + }, + id="aggregate-count", + ), + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:agg_funcs + aggregate_functions: + - name: "sum" + description: "Sum values" + impls: + - args: + - value: i8 + return: i64 + """), + "urn": "extension:test:agg_funcs", + "func_name": "sum", + "signature": [i8(nullable=False)], + "expected_type": "aggregate", + }, + id="aggregate-sum", + ), + # Window functions + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:window_funcs + window_functions: + - name: "row_number" + description: "Assign row numbers" + impls: + - args: [] + return: i64 + """), + "urn": "extension:test:window_funcs", + "func_name": "row_number", + "signature": [], + "expected_type": "window", + }, + id="window-row_number", + ), + pytest.param( + { + "yaml_content": textwrap.dedent("""\ + %YAML 1.2 + --- + urn: extension:test:window_funcs + window_functions: + - name: "rank" + description: "Assign ranks" + impls: + - args: [] + return: i64 + """), + "urn": "extension:test:window_funcs", + "func_name": "rank", + "signature": [], + "expected_type": "window", + }, + id="window-rank", + ), + ], +) +def test_all_function_types_from_yaml(test_case): + """Test that all functions in YAML are registered with correct function_type.value.""" + test_registry = ExtensionRegistry(load_default_extensions=False) + test_registry.register_extension_dict( + yaml.safe_load(test_case["yaml_content"]), + uri=f"https://test.example.com/{test_case['urn'].replace(':', '_')}.yaml", + ) + + result = test_registry.lookup_function( + urn=test_case["urn"], + function_name=test_case["func_name"], + signature=test_case["signature"], + ) + assert result is not None, ( + f"Failed to 
lookup {test_case['func_name']} in {test_case['urn']}" + ) + entry, _ = result + assert hasattr(entry, "function_type"), ( + f"Entry for {test_case['func_name']} missing function_type attribute" + ) + assert entry.function_type is not None, ( + f"function_type is None for {test_case['func_name']}" + ) + assert isinstance(entry.function_type.value, str), ( + f"function_type.value is not a string for {test_case['func_name']}" + ) + assert entry.function_type.value == test_case["expected_type"], ( + f"Expected function_type.value '{test_case['expected_type']}' " + f"for {test_case['func_name']}, got '{entry.function_type.value}'" + ) diff --git a/tests/extension_registry/test_registry_lookup.py b/tests/extension_registry/test_registry_lookup.py new file mode 100644 index 0000000..904d19f --- /dev/null +++ b/tests/extension_registry/test_registry_lookup.py @@ -0,0 +1,178 @@ +from substrait.builders.type import i8, decimal, i16 +from substrait.gen.proto.type_pb2 import Type + + +def test_non_existing_urn(registry): + assert ( + registry.lookup_function( + urn="non_existent", + function_name="add", + signature=[i8(nullable=False), i8(nullable=False)], + ) + is None + ) + + +def test_non_existing_function(registry): + assert ( + registry.lookup_function( + urn="extension:test:functions", + function_name="sub", + signature=[i8(nullable=False), i8(nullable=False)], + ) + is None + ) + + +def test_non_existing_function_signature(registry): + assert ( + registry.lookup_function( + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=False)], + ) + is None + ) + + +def test_exact_match(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=False), i8(nullable=False)], + )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) + + +def test_wildcard_match(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add", + 
signature=[i8(nullable=False), i8(nullable=False), bool()], + )[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) + + +def test_wildcard_match_fails_with_constraits(registry): + assert ( + registry.lookup_function( + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=False), i16(nullable=False), i16(nullable=False)], + ) + is None + ) + + +def test_wildcard_match_with_constraits(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add", + signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)], + )[1] == i8(nullable=False) + + +def test_variadic(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_fn", + signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)], + )[1] == i8(nullable=False) + + +def test_variadic_any(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_fn_variadic_any", + signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)], + )[1] == i16(nullable=False) + + +def test_variadic_fails_min_constraint(registry): + assert ( + registry.lookup_function( + urn="extension:test:functions", + function_name="test_fn", + signature=[i8(nullable=False)], + ) + is None + ) + + +def test_decimal_happy_path(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_decimal", + signature=[decimal(8, 10, nullable=False), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=False) + + +def test_decimal_violates_constraint(registry): + assert ( + registry.lookup_function( + urn="extension:test:functions", + function_name="test_decimal", + signature=[decimal(8, 10, nullable=False), decimal(10, 12, nullable=False)], + ) + is None + ) + + +def test_decimal_happy_path_discrete(registry): + assert registry.lookup_function( + urn="extension:test:functions", + 
function_name="test_decimal_discrete", + signature=[decimal(8, 10, nullable=True), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=True) + + +def test_enum_with_valid_option(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_enum", + signature=["FLIP", i8(nullable=False)], + )[1] == i8(nullable=False) + + +def test_enum_with_nonexistent_option(registry): + assert ( + registry.lookup_function( + urn="extension:test:functions", + function_name="test_enum", + signature=["NONEXISTENT", i8(nullable=False)], + ) + is None + ) + + +def test_function_with_nullable_args(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=True), i8(nullable=False)], + )[1] == i8(nullable=True) + + +def test_function_with_declared_output_nullability(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add_declared", + signature=[i8(nullable=False), i8(nullable=False)], + )[1] == i8(nullable=True) + + +def test_function_with_discrete_nullability(registry): + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add_discrete", + signature=[i8(nullable=True), i8(nullable=False)], + )[1] == i8(nullable=True) + + +def test_function_with_discrete_nullability_nonexisting(registry): + assert ( + registry.lookup_function( + urn="extension:test:functions", + function_name="add_discrete", + signature=[i8(nullable=False), i8(nullable=False)], + ) + is None + ) diff --git a/tests/extension_registry/test_type_coverage.py b/tests/extension_registry/test_type_coverage.py new file mode 100644 index 0000000..e82f2c0 --- /dev/null +++ b/tests/extension_registry/test_type_coverage.py @@ -0,0 +1,538 @@ +"""Tests for the covers() function - type coverage and matching.""" + +from substrait.builders.type import decimal, i8, i16, i32, struct +from substrait.builders.type import list as list_ +from 
substrait.builders.type import map as map_ +from substrait.derivation_expression import _parse +from substrait.extension_registry import covers +from substrait.gen.proto.type_pb2 import Type + + +def test_covers(): + """Basic covers test for i8 type.""" + covered = i8(nullable=False) + param_ctx = _parse("i8") + assert covers(covered, param_ctx, {}) + + +def test_covers_nullability(): + """Test nullable type coverage with check_nullability flag.""" + covered = i8(nullable=True) + param_ctx = _parse("i8?") + assert covers(covered, param_ctx, {}, check_nullability=True) + covered = i8(nullable=True) + param_ctx = _parse("i8") + assert not covers(covered, param_ctx, {}, check_nullability=True) + + +def test_covers_decimal(): + """Test decimal precision/scale coverage with multiple assertions.""" + assert not covers(decimal(8, 10), _parse("decimal<11, A>"), {}) + assert covers(decimal(8, 10), _parse("decimal<10, A>"), {}) + assert covers(decimal(8, 10), _parse("decimal<10, 8>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<10, 9>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<11, 8>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<11, 9>"), {}) + + +def test_covers_decimal_happy_path(): + """Test decimal coverage with parameter binding.""" + covered = decimal(precision=10, scale=2, nullable=False) + param_ctx = _parse("decimal<P, S>") + params = {} + assert covers(covered, param_ctx, params) + assert params["P"] == 10 and params["S"] == 2 + + +def test_covers_decimal_happy_path_2(): + """Test decimal coverage with parameter binding.""" + params = {} + assert covers(decimal(8, 10), _parse("decimal<10, A>"), params) + assert params == {"A": 8} + + +def test_covers_any(): + """Test that any type can be covered by any concrete type.""" + covered = decimal(precision=10, scale=2, nullable=False) + param_ctx = _parse("any") + assert covers(covered, param_ctx, {}) + + +def test_covers_any_2(): + """Test that any type can be covered by any concrete 
type.""" + assert covers(decimal(8, 10), _parse("any"), {}) + + +def test_covers_varchar_length_ok(): + """Test varchar length coverage (success case).""" + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("varchar<10>") + assert covers(covered, param_ctx, {}) + + +def test_covers_varchar_length_fail(): + """Test varchar length coverage (failure case).""" + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("varchar<20>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_varchar_nullability(): + """Test varchar with nullability checks.""" + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_NULLABLE, length=10) + ) + param_ctx = _parse("varchar?<10>") + assert covers(covered, param_ctx, {}) + assert covers(covered, param_ctx, {}, check_nullability=True) + + +def test_covers_fixed_char_length_ok(): + """Test fixed char length coverage (success).""" + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("fixedchar<10>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_char_length_fail(): + """Test fixed char length coverage (failure).""" + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("fixedchar<20>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_ok(): + """Test fixed binary length coverage (success).""" + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("fixedbinary<10>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_fail(): + """Test fixed binary length coverage (failure).""" + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = 
_parse("fixedbinary<20>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_decimal_precision_scale_fail(): + """Test decimal coverage fails with mismatched precision/scale.""" + covered = decimal(precision=10, scale=2, nullable=False) + param_ctx = _parse("decimal<11, 2>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_ok(): + """Test precision timestamp coverage (success).""" + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=6 + ) + ) + param_ctx = _parse("precision_timestamp<6>") + assert covers(covered, param_ctx, {}) + # Test with parameter binding + params = {} + param_ctx_with_param = _parse("precision_timestamp
<P>
") + assert covers(covered, param_ctx_with_param, params) + assert params["P"] == 6 + + +def test_covers_precision_timestamp_fail(): + """Test precision timestamp coverage (failure).""" + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=6 + ) + ) + param_ctx = _parse("precision_timestamp<3>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_tz_ok(): + """Test precision timestamp with timezone (success).""" + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=6 + ) + ) + param_ctx = _parse("precision_timestamp_tz<6>") + assert covers(covered, param_ctx, {}) + # Test with parameter binding + params = {} + param_ctx_with_param = _parse("precision_timestamp_tz
<P>
") + assert covers(covered, param_ctx_with_param, params) + assert params["P"] == 6 + + +def test_covers_precision_timestamp_tz_fail(): + """Test precision timestamp with timezone (failure).""" + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=4 + ) + ) + param_ctx = _parse("precision_timestamp_tz<3>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_time_ok(): + """Test precision time coverage (success).""" + covered = Type( + precision_time=Type.PrecisionTime( + nullability=Type.NULLABILITY_REQUIRED, precision=6 + ) + ) + param_ctx = _parse("precision_time<6>") + assert covers(covered, param_ctx, {}) + # Test with parameter binding + params = {} + param_ctx_with_param = _parse("precision_time
<P>
") + assert covers(covered, param_ctx_with_param, params) + assert params["P"] == 6 + + +def test_covers_precision_time_fail(): + """Test precision time coverage (failure).""" + covered = Type( + precision_time=Type.PrecisionTime( + nullability=Type.NULLABILITY_REQUIRED, precision=9 + ) + ) + param_ctx = _parse("precision_time<6>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_interval_day_ok(): + """Test interval_day coverage (success).""" + covered = Type( + interval_day=Type.IntervalDay( + nullability=Type.NULLABILITY_REQUIRED, precision=6 + ) + ) + param_ctx = _parse("interval_day<6>") + assert covers(covered, param_ctx, {}) + # Test with parameter binding + params = {} + param_ctx_with_param = _parse("interval_day
<P>
") + assert covers(covered, param_ctx_with_param, params) + assert params["P"] == 6 + + +def test_covers_interval_day_fail(): + """Test interval_day coverage (failure).""" + covered = Type( + interval_day=Type.IntervalDay( + nullability=Type.NULLABILITY_REQUIRED, precision=3 + ) + ) + param_ctx = _parse("interval_day<6>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_map_string_to_i8(): + """Test map type coverage with string keys and i8 values.""" + covered = map_(key=Type(string=Type.String()), value=i8(nullable=False)) + param_ctx = _parse("map<string, i8>") + assert covers(covered, param_ctx, {}) + + +def test_covers_struct_with_two_fields(): + """Test struct type coverage with two i8 fields.""" + covered = struct([i8(nullable=False), i8(nullable=False)]) + param_ctx = _parse("struct<i8, i8>") + assert covers(covered, param_ctx, {}) + + +def test_covers_list_of_i16_fails_i8(): + """Test list type coverage failure (i16 vs i8).""" + covered = list_(i16(nullable=False)) + param_ctx = _parse("list<i8>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_map_i8_to_i16_fails(): + """Test map type coverage failure (value type mismatch).""" + covered = map_(key=i8(nullable=False), value=i16(nullable=False)) + param_ctx = _parse("map<i8, i8>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_struct_mismatched_types_fails(): + """Test struct coverage failure (field type mismatch).""" + covered = struct([i8(nullable=False), i16(nullable=False)]) + param_ctx = _parse("struct<i8, i8>") + assert not covers(covered, param_ctx, {}) + + +# Tests for basic Substrait types (non-parameterized) + + +def test_covers_boolean(): + """Test boolean type coverage.""" + covered = Type(bool=Type.Boolean(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("boolean") + assert covers(covered, param_ctx, {}) + + +def test_covers_i16(): + """Test i16 type coverage.""" + covered = i16(nullable=False) + param_ctx = _parse("i16") + assert covers(covered, param_ctx, {}) + + +def 
test_covers_i32(): + """Test i32 type coverage.""" + covered = i32(nullable=False) + param_ctx = _parse("i32") + assert covers(covered, param_ctx, {}) + + +def test_covers_i64(): + """Test i64 type coverage.""" + covered = Type(i64=Type.I64(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("i64") + assert covers(covered, param_ctx, {}) + + +def test_covers_fp32(): + """Test fp32 type coverage.""" + covered = Type(fp32=Type.FP32(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("fp32") + assert covers(covered, param_ctx, {}) + + +def test_covers_fp64(): + """Test fp64 type coverage.""" + covered = Type(fp64=Type.FP64(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("fp64") + assert covers(covered, param_ctx, {}) + + +def test_covers_string(): + """Test string type coverage.""" + covered = Type(string=Type.String(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("string") + assert covers(covered, param_ctx, {}) + + +def test_covers_binary(): + """Test binary type coverage.""" + covered = Type(binary=Type.Binary(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("binary") + assert covers(covered, param_ctx, {}) + + +def test_covers_timestamp(): + """Test timestamp type coverage.""" + covered = Type(timestamp=Type.Timestamp(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("timestamp") + assert covers(covered, param_ctx, {}) + + +def test_covers_timestamp_tz(): + """Test timestamp_tz type coverage.""" + covered = Type(timestamp_tz=Type.TimestampTZ(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("timestamp_tz") + assert covers(covered, param_ctx, {}) + + +def test_covers_date(): + """Test date type coverage.""" + covered = Type(date=Type.Date(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("date") + assert covers(covered, param_ctx, {}) + + +def test_covers_time(): + """Test time type coverage.""" + covered = Type(time=Type.Time(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx 
= _parse("time") + assert covers(covered, param_ctx, {}) + + +def test_covers_interval_year(): + """Test interval_year type coverage.""" + covered = Type( + interval_year=Type.IntervalYear(nullability=Type.NULLABILITY_REQUIRED) + ) + param_ctx = _parse("interval_year") + assert covers(covered, param_ctx, {}) + + +def test_covers_interval_compound(): + """Test interval_compound type coverage.""" + covered = Type( + interval_compound=Type.IntervalCompound( + nullability=Type.NULLABILITY_REQUIRED, precision=6 + ) + ) + param_ctx = _parse("interval_compound") + assert covers(covered, param_ctx, {}) + + +def test_covers_uuid(): + """Test uuid type coverage.""" + covered = Type(uuid=Type.UUID(nullability=Type.NULLABILITY_REQUIRED)) + param_ctx = _parse("uuid") + assert covers(covered, param_ctx, {}) + + +# Additional comprehensive tests for parameterized types + + +def test_covers_fixedchar_with_parameter(): + """Test fixedchar with length parameter binding.""" + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=20) + ) + params = {} + param_ctx = _parse("fixedchar<L>") + assert covers(covered, param_ctx, params) + assert params["L"] == 20 + + +def test_covers_varchar_with_parameter(): + """Test varchar with length parameter binding.""" + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=100) + ) + params = {} + param_ctx = _parse("varchar<L>") + assert covers(covered, param_ctx, params) + assert params["L"] == 100 + + +def test_covers_fixedbinary_with_parameter(): + """Test fixedbinary with length parameter binding.""" + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) + ) + params = {} + param_ctx = _parse("fixedbinary<L>") + assert covers(covered, param_ctx, params) + assert params["L"] == 16 + + +def test_covers_decimal_with_both_parameters(): + """Test decimal with precision and scale parameter binding.""" + covered = decimal(precision=38, scale=10, 
nullable=False) + params = {} + param_ctx = _parse("decimal<P, S>") + assert covers(covered, param_ctx, params) + assert params["P"] == 38 and params["S"] == 10 + + +def test_covers_list_with_type_parameter(): + """Test list with type parameter.""" + covered = list_(i32(nullable=False)) + param_ctx = _parse("list<i32>") + assert covers(covered, param_ctx, {}) + + +def test_covers_map_with_both_types(): + """Test map with key and value types.""" + covered = map_(key=i8(nullable=False), value=i32(nullable=False)) + param_ctx = _parse("map<i8, i32>") + assert covers(covered, param_ctx, {}) + + +def test_covers_struct_with_three_fields(): + """Test struct with three fields of different types.""" + covered = struct([i8(nullable=False), i16(nullable=False), i32(nullable=False)]) + param_ctx = _parse("struct<i8, i16, i32>") + assert covers(covered, param_ctx, {}) + + +def test_covers_precision_time_with_parameter(): + """Test precision_time with parameter binding.""" + covered = Type( + precision_time=Type.PrecisionTime( + nullability=Type.NULLABILITY_REQUIRED, precision=9 + ) + ) + params = {} + param_ctx = _parse("precision_time
<P>
") + assert covers(covered, param_ctx, params) + assert params["P"] == 9 + + +def test_covers_precision_timestamp_with_parameter(): + """Test precision_timestamp with parameter binding.""" + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=12 + ) + ) + params = {} + param_ctx = _parse("precision_timestamp
<P>
") + assert covers(covered, param_ctx, params) + assert params["P"] == 12 + + +def test_covers_precision_timestamp_tz_with_parameter(): + """Test precision_timestamp_tz with parameter binding.""" + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=9 + ) + ) + params = {} + param_ctx = _parse("precision_timestamp_tz
<P>
") + assert covers(covered, param_ctx, params) + assert params["P"] == 9 + + +def test_covers_interval_day_with_parameter(): + """Test interval_day with parameter binding.""" + covered = Type( + interval_day=Type.IntervalDay( + nullability=Type.NULLABILITY_REQUIRED, precision=9 + ) + ) + params = {} + param_ctx = _parse("interval_day
<P>
") + assert covers(covered, param_ctx, params) + assert params["P"] == 9 + + +def test_covers_interval_compound_with_precision(): + """Test interval_compound with precision parameter.""" + covered = Type( + interval_compound=Type.IntervalCompound( + nullability=Type.NULLABILITY_REQUIRED, precision=9 + ) + ) + # Note: interval_compound doesn't have a parameterized syntax in the grammar + # so we just test the basic type coverage + param_ctx = _parse("interval_compound") + assert covers(covered, param_ctx, {}) + + +def test_covers_nested_list(): + """Test nested list type (list of lists).""" + inner_list = list_(i8(nullable=False)) + covered = list_(inner_list) + param_ctx = _parse("list<list<i8>>") + assert covers(covered, param_ctx, {}) + + +def test_covers_map_with_struct_value(): + """Test map with struct as value type.""" + struct_type = struct([i8(nullable=False), i16(nullable=False)]) + covered = map_(key=Type(string=Type.String()), value=struct_type) + param_ctx = _parse("map<string, struct<i8, i16>>") + assert covers(covered, param_ctx, {}) diff --git a/tests/extension_registry/test_urn_uri_mapping.py b/tests/extension_registry/test_urn_uri_mapping.py new file mode 100644 index 0000000..3eb5be6 --- /dev/null +++ b/tests/extension_registry/test_urn_uri_mapping.py @@ -0,0 +1,176 @@ +"""Tests for URN/URI mapping and default extensions.""" + +import yaml + +from substrait.builders.type import i8 +from substrait.extension_registry import ExtensionRegistry + + +def test_registry_uri_urn(): + """Test that URI to URN conversion works via the bimap.""" + urn = "extension:test:bimap" + content_with_urn = f"""%YAML 1.2 +--- +urn: {urn} +scalar_functions: + - name: "test_func" + description: "" + impls: + - args: + - value: i8 + return: i8 +""" + uri = "https://test.example.com/bimap.yaml" + registry = ExtensionRegistry(load_default_extensions=False) + registry.register_extension_dict(yaml.safe_load(content_with_urn), uri=uri) + + assert registry._uri_urn_bimap.get_urn(uri) == urn + assert 
registry._uri_urn_bimap.get_uri(urn) == uri + + +def test_registry_uri_anchor_lookup(): + """Test that URI anchor lookup works.""" + content_with_urn = """%YAML 1.2 +--- +urn: extension:test:anchor +scalar_functions: [] +""" + uri = "https://test.example.com/anchor.yaml" + registry = ExtensionRegistry(load_default_extensions=False) + registry.register_extension_dict(yaml.safe_load(content_with_urn), uri=uri) + + anchor = registry.lookup_uri_anchor(uri) + assert anchor is not None + assert anchor > 0 + + +def test_registry_default_extensions_have_uri_mappings(): + """Test that default extensions have URI mappings.""" + registry = ExtensionRegistry(load_default_extensions=True) + + # Check that at least one default extension has a URI mapping + urn = "extension:io.substrait:functions_comparison" + uri = registry._uri_urn_bimap.get_uri(urn) + + assert uri is not None + assert "https://github.com/substrait-io/substrait/blob/main/extensions" in uri + assert "functions_comparison.yaml" in uri + + assert registry._uri_urn_bimap.get_urn(uri) == urn + + +def test_registry_default_extensions_lookup_function_multiply(): + """Test that default extensions are loaded and functions can be looked up.""" + registry = ExtensionRegistry(load_default_extensions=True) + + # Test looking up a function from the arithmetic extensions + urn = "extension:io.substrait:functions_arithmetic" + + # Look up a common arithmetic function (e.g., "multiply") + result = registry.lookup_function( + urn=urn, + function_name="multiply", + signature=[i8(nullable=False), i8(nullable=False)], + ) + + assert result is not None, ( + "Failed to lookup 'multiply' function from default extensions" + ) + entry, return_type = result + + # Verify the function entry + assert entry.name == "multiply" + assert entry.urn == urn + assert entry.function_type is not None + assert entry.function_type.value == "scalar" + assert isinstance(entry.anchor, int) + + # Verify the URI-URN mapping exists + uri = 
registry._uri_urn_bimap.get_uri(urn) + assert uri is not None + assert "https://github.com/substrait-io/substrait/blob/main/extensions" in uri + assert "functions_arithmetic.yaml" in uri + + # Test looking up a function across all URNs without specifying URN + results = registry.list_functions_across_urns( + function_name="multiply", + signature=[i8(nullable=False), i8(nullable=False)], + ) + + assert len(results) > 0, "Failed to find 'multiply' function across all URNs" + + # Verify we found the same function + found_entry = None + for entry, return_type in results: + if entry.urn == urn and entry.name == "multiply": + found_entry = entry + break + + assert found_entry is not None, "multiply function not found in cross-URN search" + assert found_entry.function_type.value == "scalar" + + +def test_registry_default_extensions_lookup_function(): + """Test that default extensions are loaded and functions can be looked up.""" + registry = ExtensionRegistry(load_default_extensions=True) + + # Test looking up a function from the comparison extensions + urn = "extension:io.substrait:functions_comparison" + + # Look up a common comparison function (e.g., "equal") + result = registry.lookup_function( + urn=urn, + function_name="equal", + signature=[i8(nullable=False), i8(nullable=False)], + ) + + assert result is not None, ( + "Failed to lookup 'equal' function from default extensions" + ) + entry, return_type = result + + # Verify the function entry + assert entry.name == "equal" + assert entry.urn == urn + assert entry.function_type is not None + assert entry.function_type.value == "scalar" + assert isinstance(entry.anchor, int) + + # Verify the URI-URN mapping exists + uri = registry._uri_urn_bimap.get_uri(urn) + assert uri is not None + assert "https://github.com/substrait-io/substrait/blob/main/extensions" in uri + assert "functions_comparison.yaml" in uri + + # Test looking up a function across all URNs without specifying URN + results = 
registry.list_functions_across_urns( + function_name="equal", + signature=[i8(nullable=False), i8(nullable=False)], + ) + + assert len(results) > 0, "Failed to find 'equal' function across all URNs" + + # Verify we found the same function + found_entry = None + for entry, return_type in results: + if entry.urn == urn and entry.name == "equal": + found_entry = entry + break + + assert found_entry is not None, "Equal function not found in cross-URN search" + assert found_entry.function_type.value == "scalar" + + +def test_register_requires_uri(): + """Test that registering requires URI parameter (migration requirement).""" + content = """%YAML 1.2 +--- +urn: extension:test:no_uri +scalar_functions: [] +""" + registry = ExtensionRegistry(load_default_extensions=False) + # This should work fine - URI is required + registry.register_extension_dict( + yaml.safe_load(content), uri="https://test.example.com/test.yaml" + ) + assert registry.lookup_urn("extension:test:no_uri") is not None diff --git a/tests/extension_registry/test_validation.py b/tests/extension_registry/test_validation.py new file mode 100644 index 0000000..1cdd88c --- /dev/null +++ b/tests/extension_registry/test_validation.py @@ -0,0 +1,70 @@ +"""Tests for URN format validation.""" + +import pytest +import yaml + +from substrait.extension_registry import ExtensionRegistry + + +def test_valid_urn_format(): + """Test that valid URN formats are accepted.""" + content = """%YAML 1.2 +--- +urn: extension:io.substrait:functions_test +scalar_functions: + - name: "test_func" + description: "Test function" + impls: + - args: + - value: i8 + return: i8 +""" + registry = ExtensionRegistry(load_default_extensions=False) + + registry.register_extension_dict( + yaml.safe_load(content), uri="https://test.example.com/functions_test.yaml" + ) # Should not raise + + +def test_invalid_urn_no_prefix(): + """Test that URN without 'extension:' prefix is rejected.""" + content = """%YAML 1.2 +--- +urn: 
io.substrait:functions_test +scalar_functions: [] +""" + registry = ExtensionRegistry(load_default_extensions=False) + + with pytest.raises(ValueError, match="Invalid URN format"): + registry.register_extension_dict( + yaml.safe_load(content), uri="https://test.example.com/invalid.yaml" + ) + + +def test_invalid_urn_too_short(): + """Test that URN with insufficient parts is rejected.""" + content = """%YAML 1.2 +--- +urn: extension:test +scalar_functions: [] +""" + registry = ExtensionRegistry(load_default_extensions=False) + + with pytest.raises(ValueError, match="Invalid URN format"): + registry.register_extension_dict( + yaml.safe_load(content), uri="https://test.example.com/invalid.yaml" + ) + + +def test_missing_urn(): + """Test that missing URN field raises ValueError.""" + content = """%YAML 1.2 +--- +scalar_functions: [] +""" + registry = ExtensionRegistry(load_default_extensions=False) + + with pytest.raises(ValueError, match="must contain a 'urn' field"): + registry.register_extension_dict( + yaml.safe_load(content), uri="https://test.example.com/missing_urn.yaml" + ) diff --git a/tests/test_extension_registry.py b/tests/test_extension_registry.py deleted file mode 100644 index 4f7b1da..0000000 --- a/tests/test_extension_registry.py +++ /dev/null @@ -1,795 +0,0 @@ -import textwrap - -import pytest -import yaml - -from substrait.builders.type import ( - decimal, - i8, - i16, - i32, - struct, -) -from substrait.builders.type import ( - list as list_, -) -from substrait.builders.type import ( - map as map_, -) -from substrait.derivation_expression import _parse -from substrait.extension_registry import ExtensionRegistry, covers -from substrait.gen.proto.type_pb2 import Type - -content = """%YAML 1.2 ---- -urn: extension:test:functions -scalar_functions: - - name: "test_fn" - description: "" - impls: - - args: - - value: i8 - variadic: - min: 2 - return: i8 - - name: "test_fn_variadic_any" - description: "" - impls: - - args: - - value: any1 - variadic: - 
min: 2 - return: any1 - - name: "add" - description: "Add two values." - impls: - - args: - - name: x - value: i8 - - name: y - value: i8 - options: - overflow: - values: [ SILENT, SATURATE, ERROR ] - return: i8 - - args: - - name: x - value: i8 - - name: y - value: i8 - - name: z - value: any - options: - overflow: - values: [ SILENT, SATURATE, ERROR ] - return: i16 - - args: - - name: x - value: any1 - - name: y - value: any1 - - name: z - value: any2 - options: - overflow: - values: [ SILENT, SATURATE, ERROR ] - return: any2 - - name: "test_decimal" - impls: - - args: - - name: x - value: decimal - - name: y - value: decimal - return: decimal - - name: "test_enum" - impls: - - args: - - name: op - options: [ INTACT, FLIP ] - - name: x - value: i8 - return: i8 - - name: "add_declared" - description: "Add two values." - impls: - - args: - - name: x - value: i8 - - name: y - value: i8 - nullability: DECLARED_OUTPUT - return: i8? - - name: "add_discrete" - description: "Add two values." - impls: - - args: - - name: x - value: i8? - - name: y - value: i8 - nullability: DISCRETE - return: i8? - - name: "test_decimal_discrete" - impls: - - args: - - name: x - value: decimal? - - name: y - value: decimal - nullability: DISCRETE - return: decimal? 
- - name: "equal_test" - impls: - - args: - - name: x - value: any - - name: y - value: any - nullability: DISCRETE - return: any -""" - - -registry = ExtensionRegistry(load_default_extensions=True) - -registry.register_extension_dict( - yaml.safe_load(content), - uri="https://test.example.com/extension_test_functions.yaml", -) - - -def test_non_existing_urn(): - assert ( - registry.lookup_function( - urn="non_existent", - function_name="add", - signature=[i8(nullable=False), i8(nullable=False)], - ) - is None - ) - - -def test_non_existing_function(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="sub", - signature=[i8(nullable=False), i8(nullable=False)], - ) - is None - ) - - -def test_non_existing_function_signature(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i8(nullable=False)], - ) - is None - ) - - -def test_exact_match(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i8(nullable=False), i8(nullable=False)], - )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) - - -def test_wildcard_match(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i8(nullable=False), i8(nullable=False), bool()], - )[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) - - -def test_wildcard_match_fails_with_constraits(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i8(nullable=False), i16(nullable=False), i16(nullable=False)], - ) - is None - ) - - -def test_wildcard_match_with_constraits(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)], - )[1] == i8(nullable=False) - - -def test_variadic(): - assert registry.lookup_function( - 
urn="extension:test:functions", - function_name="test_fn", - signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)], - )[1] == i8(nullable=False) - - -def test_variadic_any(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="test_fn_variadic_any", - signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)], - )[1] == i16(nullable=False) - - -def test_variadic_fails_min_constraint(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_fn", - signature=[i8(nullable=False)], - ) - is None - ) - - -def test_decimal_happy_path(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="test_decimal", - signature=[decimal(8, 10, nullable=False), decimal(6, 8, nullable=False)], - )[1] == decimal(7, 11, nullable=False) - - -def test_decimal_violates_constraint(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_decimal", - signature=[decimal(8, 10, nullable=False), decimal(10, 12, nullable=False)], - ) - is None - ) - - -def test_decimal_happy_path_discrete(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="test_decimal_discrete", - signature=[decimal(8, 10, nullable=True), decimal(6, 8, nullable=False)], - )[1] == decimal(7, 11, nullable=True) - - -def test_enum_with_valid_option(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="test_enum", - signature=["FLIP", i8(nullable=False)], - )[1] == i8(nullable=False) - - -def test_enum_with_nonexistent_option(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_enum", - signature=["NONEXISTENT", i8(nullable=False)], - ) - is None - ) - - -def test_function_with_nullable_args(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i8(nullable=True), i8(nullable=False)], - 
)[1] == i8(nullable=True) - - -def test_function_with_declared_output_nullability(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="add_declared", - signature=[i8(nullable=False), i8(nullable=False)], - )[1] == i8(nullable=True) - - -def test_function_with_discrete_nullability(): - assert registry.lookup_function( - urn="extension:test:functions", - function_name="add_discrete", - signature=[i8(nullable=True), i8(nullable=False)], - )[1] == i8(nullable=True) - - -def test_function_with_discrete_nullability_nonexisting(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="add_discrete", - signature=[i8(nullable=False), i8(nullable=False)], - ) - is None - ) - - -def test_covers(): - params = {} - assert covers(i8(nullable=False), _parse("i8"), params) - assert params == {} - - -def test_covers_nullability(): - assert not covers(i8(nullable=True), _parse("i8"), {}, check_nullability=True) - assert covers(i8(nullable=True), _parse("i8?"), {}, check_nullability=True) - - -def test_covers_decimal(nullable=False): - assert not covers(decimal(8, 10), _parse("decimal<11, A>"), {}) - assert covers(decimal(8, 10), _parse("decimal<10, A>"), {}) - assert covers(decimal(8, 10), _parse("decimal<10, 8>"), {}) - assert not covers(decimal(8, 10), _parse("decimal<10, 9>"), {}) - assert not covers(decimal(8, 10), _parse("decimal<11, 8>"), {}) - assert not covers(decimal(8, 10), _parse("decimal<11, 9>"), {}) - - -def test_covers_decimal_happy_path(): - params = {} - assert covers(decimal(8, 10), _parse("decimal<10, A>"), params) - assert params == {"A": 8} - - -def test_covers_any(): - assert covers(decimal(8, 10), _parse("any"), {}) - - -def test_covers_varchar_length_ok(): - covered = Type( - varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=15) - ) - param_ctx = _parse("varchar<15>") - assert covers(covered, param_ctx, {}, check_nullability=True) - - -def 
test_covers_varchar_length_fail(): - covered = Type( - varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) - ) - param_ctx = _parse("varchar<5>") - assert not covers(covered, param_ctx, {}) - - -def test_covers_varchar_nullability(): - covered = Type( - varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) - ) - param_tx = _parse("varchar?<10>") - assert covers(covered, param_tx, {}) - assert not covers(covered, param_tx, {}, True) - param_ctx2 = _parse("varchar<10>") - assert covers(covered, param_ctx2, {}, True) - - -def test_covers_fixed_char_length_ok(): - covered = Type( - fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) - ) - param_ctx = _parse("fixedchar<8>") - assert covers(covered, param_ctx, {}) - - -def test_covers_fixed_char_length_fail(): - covered = Type( - fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) - ) - param_ctx = _parse("fixedchar<4>") - assert not covers(covered, param_ctx, {}) - - -def test_covers_fixed_binary_length_ok(): - covered = Type( - fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) - ) - param_ctx = _parse("fixedbinary<16>") - assert covers(covered, param_ctx, {}) - - -def test_covers_fixed_binary_length_fail(): - covered = Type( - fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) - ) - param_ctx = _parse("fixedbinary<10>") - assert not covers(covered, param_ctx, {}) - - -def test_covers_decimal_precision_scale_fail(): - covered = decimal(8, 10, nullable=False) - param_ctx = _parse("decimal<6, 5>") - assert not covers(covered, param_ctx, {}) - - -def test_covers_precision_timestamp_ok(): - covered = Type( - precision_timestamp=Type.PrecisionTimestamp( - nullability=Type.NULLABILITY_REQUIRED, precision=5 - ) - ) - param_ctx = _parse("precision_timestamp<5>") - assert covers(covered, param_ctx, {}) - param_ctx = _parse("precision_timestamp") - assert covers(covered, param_ctx, {}) - - 
-def test_covers_precision_timestamp_fail(): - covered = Type( - precision_timestamp=Type.PrecisionTimestamp( - nullability=Type.NULLABILITY_REQUIRED, precision=3 - ) - ) - param_ctx = _parse("precision_timestamp<2>") - assert not covers(covered, param_ctx, {}) - - -def test_covers_precision_timestamp_tz_ok(): - covered = Type( - precision_timestamp_tz=Type.PrecisionTimestampTZ( - nullability=Type.NULLABILITY_REQUIRED, precision=4 - ) - ) - param_ctx = _parse("precision_timestamp_tz<4>") - assert covers(covered, param_ctx, {}) - param_ctx = _parse("precision_timestamp_tz") - assert covers(covered, param_ctx, {}) - - -def test_covers_precision_timestamp_tz_fail(): - covered = Type( - precision_timestamp_tz=Type.PrecisionTimestampTZ( - nullability=Type.NULLABILITY_REQUIRED, precision=4 - ) - ) - param_ctx = _parse("precision_timestamp_tz<3>") - assert not covers(covered, param_ctx, {}) - - -def test_registry_uri_urn(): - """Test that URI to URN conversion works via the bimap.""" - urn = "extension:test:bimap" - content_with_urn = f"""%YAML 1.2 ---- -urn: {urn} -scalar_functions: - - name: "test_func" - description: "" - impls: - - args: - - value: i8 - return: i8 -""" - uri = "https://test.example.com/bimap.yaml" - registry = ExtensionRegistry(load_default_extensions=False) - registry.register_extension_dict(yaml.safe_load(content_with_urn), uri=uri) - - assert registry._uri_urn_bimap.get_urn(uri) == urn - assert registry._uri_urn_bimap.get_uri(urn) == uri - - -def test_registry_uri_anchor_lookup(): - """Test that URI anchor lookup works.""" - content_with_urn = """%YAML 1.2 ---- -urn: extension:test:anchor -scalar_functions: [] -""" - uri = "https://test.example.com/anchor.yaml" - registry = ExtensionRegistry(load_default_extensions=False) - registry.register_extension_dict(yaml.safe_load(content_with_urn), uri=uri) - - anchor = registry.lookup_uri_anchor(uri) - assert anchor is not None - assert anchor > 0 - - -def 
test_registry_default_extensions_have_uri_mappings(): - """Test that default extensions have URI mappings.""" - registry = ExtensionRegistry(load_default_extensions=True) - - # Check that at least one default extension has a URI mapping - urn = "extension:io.substrait:functions_comparison" - uri = registry._uri_urn_bimap.get_uri(urn) - - assert uri is not None - assert "https://github.com/substrait-io/substrait/blob/main/extensions" in uri - assert "functions_comparison.yaml" in uri - - assert registry._uri_urn_bimap.get_urn(uri) == urn - - -def test_valid_urn_format(): - """Test that valid URN formats are accepted.""" - content = """%YAML 1.2 ---- -urn: extension:io.substrait:functions_test -scalar_functions: - - name: "test_func" - description: "Test function" - impls: - - args: - - value: i8 - return: i8 -""" - registry = ExtensionRegistry(load_default_extensions=False) - - registry.register_extension_dict( - yaml.safe_load(content), uri="https://test.example.com/functions_test.yaml" - ) # Should not raise - - -def test_invalid_urn_no_prefix(): - """Test that URN without 'extension:' prefix is rejected.""" - content = """%YAML 1.2 ---- -urn: io.substrait:functions_test -scalar_functions: [] -""" - registry = ExtensionRegistry(load_default_extensions=False) - - with pytest.raises(ValueError, match="Invalid URN format"): - registry.register_extension_dict( - yaml.safe_load(content), uri="https://test.example.com/invalid.yaml" - ) - - -def test_invalid_urn_too_short(): - """Test that URN with insufficient parts is rejected.""" - content = """%YAML 1.2 ---- -urn: extension:test -scalar_functions: [] -""" - registry = ExtensionRegistry(load_default_extensions=False) - - with pytest.raises(ValueError, match="Invalid URN format"): - registry.register_extension_dict( - yaml.safe_load(content), uri="https://test.example.com/invalid.yaml" - ) - - -def test_missing_urn(): - """Test that missing URN field raises ValueError.""" - content = """%YAML 1.2 ---- 
-scalar_functions: [] -""" - registry = ExtensionRegistry(load_default_extensions=False) - - with pytest.raises(ValueError, match="must contain a 'urn' field"): - registry.register_extension_dict( - yaml.safe_load(content), uri="https://test.example.com/missing_urn.yaml" - ) - - -def test_register_requires_uri(): - """Test that registering an extension requires a URI during migration.""" - content = """%YAML 1.2 ---- -urn: extension:example:test -scalar_functions: [] -""" - registry = ExtensionRegistry(load_default_extensions=False) - - # During migration, URI is required - this should fail with TypeError - with pytest.raises(TypeError): - registry.register_extension_dict(yaml.safe_load(content)) - - -def test_covers_map_string_to_i8(): - """Test that a map with string keys and i8 values covers map.""" - covered = map_( - key=Type(string=Type.String(nullability=Type.NULLABILITY_REQUIRED)), - value=i8(nullable=False), - nullable=False, - ) - param_ctx = _parse("map") - assert covers(covered, param_ctx, {}) - - -def test_covers_struct_with_two_fields(): - """Test that a struct with two i8 fields covers struct.""" - covered = struct([i8(nullable=False), i8(nullable=False)], nullable=False) - param_ctx = _parse("struct") - assert covers(covered, param_ctx, {}) - - -def test_covers_list_of_i16_fails_i8(): - """Test that a list of i16 does not cover list.""" - covered = list_(i16(nullable=False), nullable=False) - param_ctx = _parse("list") - assert not covers(covered, param_ctx, {}) - - -def test_covers_map_i8_to_i16_fails(): - """Test that a map with i8 keys and i16 values does not cover map.""" - covered = map_( - key=i8(nullable=False), - value=i16(nullable=False), - nullable=False, - ) - param_ctx = _parse("map") - assert not covers(covered, param_ctx, {}) - - -def test_covers_struct_mismatched_types_fails(): - """Test that a struct with mismatched field types does not cover struct.""" - covered = struct([i32(nullable=False), i8(nullable=False)], nullable=False) - 
param_ctx = _parse("struct") - assert not covers(covered, param_ctx, {}) - - -@pytest.mark.parametrize( - "test_case", - [ - # Scalar functions - pytest.param( - { - "yaml_content": textwrap.dedent("""\ - %YAML 1.2 - --- - urn: extension:test:scalar_funcs - scalar_functions: - - name: "add" - description: "Add two numbers" - impls: - - args: - - value: i8 - - value: i8 - return: i8 - """), - "urn": "extension:test:scalar_funcs", - "func_name": "add", - "signature": [i8(nullable=False), i8(nullable=False)], - "expected_type": "scalar", - }, - id="scalar-add", - ), - pytest.param( - { - "yaml_content": textwrap.dedent("""\ - %YAML 1.2 - --- - urn: extension:test:scalar_funcs - scalar_functions: - - name: "test_fn" - description: "" - impls: - - args: - - value: i8 - variadic: - min: 2 - return: i8 - """), - "urn": "extension:test:scalar_funcs", - "func_name": "test_fn", - "signature": [i8(nullable=False), i8(nullable=False)], - "expected_type": "scalar", - }, - id="scalar-test_fn", - ), - # Aggregate functions - pytest.param( - { - "yaml_content": textwrap.dedent("""\ - %YAML 1.2 - --- - urn: extension:test:agg_funcs - aggregate_functions: - - name: "count" - description: "Count non-null values" - impls: - - args: - - value: i8 - return: i64 - """), - "urn": "extension:test:agg_funcs", - "func_name": "count", - "signature": [i8(nullable=False)], - "expected_type": "aggregate", - }, - id="aggregate-count", - ), - pytest.param( - { - "yaml_content": textwrap.dedent("""\ - %YAML 1.2 - --- - urn: extension:test:agg_funcs - aggregate_functions: - - name: "sum" - description: "Sum values" - impls: - - args: - - value: i8 - return: i64 - """), - "urn": "extension:test:agg_funcs", - "func_name": "sum", - "signature": [i8(nullable=False)], - "expected_type": "aggregate", - }, - id="aggregate-sum", - ), - # Window functions - pytest.param( - { - "yaml_content": textwrap.dedent("""\ - %YAML 1.2 - --- - urn: extension:test:window_funcs - window_functions: - - name: "row_number" 
- description: "Assign row numbers" - impls: - - args: [] - return: i64 - """), - "urn": "extension:test:window_funcs", - "func_name": "row_number", - "signature": [], - "expected_type": "window", - }, - id="window-row_number", - ), - pytest.param( - { - "yaml_content": textwrap.dedent("""\ - %YAML 1.2 - --- - urn: extension:test:window_funcs - window_functions: - - name: "rank" - description: "Assign ranks" - impls: - - args: [] - return: i64 - """), - "urn": "extension:test:window_funcs", - "func_name": "rank", - "signature": [], - "expected_type": "window", - }, - id="window-rank", - ), - ], -) -def test_all_function_types_from_yaml(test_case): - """Test that all functions in YAML are registered with correct function_type.value.""" - test_registry = ExtensionRegistry(load_default_extensions=False) - test_registry.register_extension_dict( - yaml.safe_load(test_case["yaml_content"]), - uri=f"https://test.example.com/{test_case['urn'].replace(':', '_')}.yaml", - ) - - result = test_registry.lookup_function( - urn=test_case["urn"], - function_name=test_case["func_name"], - signature=test_case["signature"], - ) - assert result is not None, ( - f"Failed to lookup {test_case['func_name']} in {test_case['urn']}" - ) - entry, _ = result - assert hasattr(entry, "function_type"), ( - f"Entry for {test_case['func_name']} missing function_type attribute" - ) - assert entry.function_type is not None, ( - f"function_type is None for {test_case['func_name']}" - ) - assert isinstance(entry.function_type.value, str), ( - f"function_type.value is not a string for {test_case['func_name']}" - ) - assert entry.function_type.value == test_case["expected_type"], ( - f"Expected function_type.value '{test_case['expected_type']}' " - f"for {test_case['func_name']}, got '{entry.function_type.value}'" - )