From b0162294884eebdec2cfa311b31ce6cb98263992 Mon Sep 17 00:00:00 2001
From: AGV
Date: Wed, 10 Sep 2025 09:20:23 +0200
Subject: [PATCH 1/2] feat(write_table): added the write table builder and test

Signed-off-by: AGV
---
 src/substrait/builders/plan.py    | 58 +++++++++++++++++++++++++++++--
 tests/builders/plan/test_write.py | 48 +++++++++++++++++++++++++
 2 files changed, 103 insertions(+), 3 deletions(-)
 create mode 100644 tests/builders/plan/test_write.py

diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py
index a4a2180..6b2ac6a 100644
--- a/src/substrait/builders/plan.py
+++ b/src/substrait/builders/plan.py
@@ -7,16 +7,16 @@
 
 from typing import Iterable, Optional, Union, Callable
 
-import substrait.gen.proto.algebra_pb2 as stalg
 from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
+import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.extended_expression_pb2 as stee
 import substrait.gen.proto.plan_pb2 as stp
 import substrait.gen.proto.type_pb2 as stt
-import substrait.gen.proto.extended_expression_pb2 as stee
-from substrait.extension_registry import ExtensionRegistry
 from substrait.builders.extended_expression import (
     ExtendedExpressionOrUnbound,
     resolve_expression,
 )
+from substrait.extension_registry import ExtensionRegistry
 from substrait.type_inference import infer_plan_schema
 from substrait.utils import (
     merge_extension_declarations,
@@ -379,3 +379,55 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
         )
 
     return resolve
+
+
+def write_table(
+    table_names: Union[str, Iterable[str]],
+    input: PlanOrUnbound,
+    create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None,
+) -> UnboundPlan:
+    """Build an unbound plan that writes ``input`` to a named table (CTAS).
+
+    Args:
+        table_names: A single table name, or an iterable of names; stored in
+            ``NamedObjectWrite.names``.
+        input: A plan, or an unbound plan that is bound with the registry; the
+            root input of its last relation is what gets written.
+        create_mode: ``WriteRel`` create mode. ``None`` selects
+            ``CREATE_MODE_ERROR_IF_EXISTS``.
+
+    Returns:
+        A callable that takes an ``ExtensionRegistry`` and returns the plan.
+    """
+    # Materialize the names once: ``resolve`` can be called more than once,
+    # and a one-shot iterable would come back empty after the first call.
+    names = [table_names] if isinstance(table_names, str) else list(table_names)
+    # Compare against None instead of using ``or``: an enum value of 0 is
+    # falsy, so an explicitly passed mode could be silently replaced.
+    mode = (
+        stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS
+        if create_mode is None
+        else create_mode
+    )
+
+    def resolve(registry: ExtensionRegistry) -> stp.Plan:
+        bound_input = input if isinstance(input, stp.Plan) else input(registry)
+        ns = infer_plan_schema(bound_input)
+
+        write_rel = stalg.Rel(
+            write=stalg.WriteRel(
+                input=bound_input.relations[-1].root.input,
+                table_schema=ns,
+                op=stalg.WriteRel.WRITE_OP_CTAS,
+                create_mode=mode,
+                named_table=stalg.NamedObjectWrite(names=names),
+            )
+        )
+        return stp.Plan(
+            relations=[
+                stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names))
+            ],
+            **_merge_extensions(bound_input),
+        )
+
+    return resolve
diff --git a/tests/builders/plan/test_write.py b/tests/builders/plan/test_write.py
new file mode 100644
index 0000000..dff830e
--- /dev/null
+++ b/tests/builders/plan/test_write.py
@@ -0,0 +1,48 @@
+import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.plan_pb2 as stp
+import substrait.gen.proto.type_pb2 as stt
+from substrait.builders.plan import read_named_table, write_table
+from substrait.builders.type import boolean, i64
+
+struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
+
+named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
+
+
+def test_write_rel():
+    actual = write_table(
+        "example_table_write_test",
+        read_named_table("example_table", named_struct),
+    )(None)
+
+    expected = stp.Plan(
+        relations=[
+            stp.PlanRel(
+                root=stalg.RelRoot(
+                    input=stalg.Rel(
+                        write=stalg.WriteRel(
+                            input=stalg.Rel(
+                                read=stalg.ReadRel(
+                                    common=stalg.RelCommon(
+                                        direct=stalg.RelCommon.Direct()
+                                    ),
+                                    base_schema=named_struct,
+                                    named_table=stalg.ReadRel.NamedTable(
+                                        names=["example_table"]
+                                    ),
+                                )
+                            ),
+                            op=stalg.WriteRel.WRITE_OP_CTAS,
+                            table_schema=named_struct,
+                            create_mode=stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS,
+                            named_table=stalg.NamedObjectWrite(
+                                names=["example_table_write_test"]
+                            ),
+                        )
+                    ),
+                    names=["id", "is_applicable"],
+                )
+            )
+        ]
+    )
+    assert actual == expected

From 1482701b88dda84d9168d064e1418ab8d0ac1960 Mon Sep 17 00:00:00 2001
From: Giovanni Spadaccini
Date: Mon, 1 Dec 2025 15:49:28 +0100
Subject: [PATCH 2/2] feat: update proto codegen
 and add typing_extensions dependency

---
 pyproject.toml                              |  2 +-
 src/substrait/gen/json/simple_extensions.py | 26 +++++++++++----------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 677ef57..c1ddd51 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com
 license = {text = "Apache-2.0"}
 readme = "README.md"
 requires-python = ">=3.10"
-dependencies = ["protobuf >=3.19.1,<6"]
+dependencies = ["protobuf >=3.19.1,<6", "typing_extensions"]
 dynamic = ["version"]
 
 [tool.setuptools_scm]
diff --git a/src/substrait/gen/json/simple_extensions.py b/src/substrait/gen/json/simple_extensions.py
index 2885bb4..e323ac5 100644
--- a/src/substrait/gen/json/simple_extensions.py
+++ b/src/substrait/gen/json/simple_extensions.py
@@ -7,13 +7,15 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
+from typing_extensions import TypeAlias
+
 
 class Functions(Enum):
     INHERITS = 'INHERITS'
     SEPARATE = 'SEPARATE'
 
 
-Type = Union[str, Dict[str, Any]]
+Type: TypeAlias = Union[str, Dict[str, Any]]
 
 
 class Type1(Enum):
@@ -24,7 +26,7 @@ class Type1(Enum):
     string = 'string'
 
 
-EnumOptions = List[str]
+EnumOptions: TypeAlias = List[str]
 
 
 @dataclass
@@ -49,7 +51,7 @@ class TypeArg:
     description: Optional[str] = None
 
 
-Arguments = List[Union[EnumerationArg, ValueArg, TypeArg]]
+Arguments: TypeAlias = List[Union[EnumerationArg, ValueArg, TypeArg]]
 
 
 @dataclass
@@ -58,7 +60,7 @@ class Options1:
     description: Optional[str] = None
 
 
-Options = Dict[str, Options1]
+Options: TypeAlias = Dict[str, Options1]
 
 
 class ParameterConsistency(Enum):
@@ -73,10 +75,10 @@ class VariadicBehavior:
     parameterConsistency: Optional[ParameterConsistency] = None
 
 
-Deterministic = bool
+Deterministic: TypeAlias = bool
 
 
-SessionDependent = bool
+SessionDependent: TypeAlias = bool
 
 
 class NullabilityHandling(Enum):
@@ -85,13 +87,13 @@ class NullabilityHandling(Enum):
     DISCRETE = 'DISCRETE'
 
 
-ReturnValue = Type
+ReturnValue: TypeAlias = Type
 
 
-Implementation = Dict[str, str]
+Implementation: TypeAlias = Dict[str, str]
 
 
-Intermediate = Type
+Intermediate: TypeAlias = Type
 
 
 class Decomposable(Enum):
@@ -100,10 +102,10 @@ class Decomposable(Enum):
     MANY = 'MANY'
 
 
-Maxset = float
+Maxset: TypeAlias = float
 
 
-Ordered = bool
+Ordered: TypeAlias = bool
 
 
 @dataclass
@@ -196,7 +198,7 @@ class TypeParamDef:
     optional: Optional[bool] = None
 
 
-TypeParamDefs = List[TypeParamDef]
+TypeParamDefs: TypeAlias = List[TypeParamDef]
 
 
 @dataclass