Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions vortex-python/python/vortex/arrow/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pyarrow as pa
import pyarrow.compute as pc
from substrait.proto import ( # pyright: ignore[reportMissingTypeStubs]
ExtendedExpression, # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType]
ExtendedExpression,
)

from vortex._lib.expr import Expr # pyright: ignore[reportMissingModuleSource]
Expand Down Expand Up @@ -46,10 +46,10 @@ def _schema_for_substrait(schema: pa.Schema) -> pa.Schema:

def arrow_to_vortex(arrow_expression: pc.Expression, schema: pa.Schema) -> Expr:
compat_schema = _schema_for_substrait(schema)
substrait_object = ExtendedExpression() # pyright: ignore[reportUnknownVariableType]
substrait_object.ParseFromString(arrow_expression.to_substrait(compat_schema)) # pyright: ignore[reportUnknownMemberType]
substrait_object = ExtendedExpression()
substrait_object.ParseFromString(bytes(arrow_expression.to_substrait(compat_schema))) # pyright: ignore[reportUnusedCallResult]

expressions = extended_expression(substrait_object) # pyright: ignore[reportUnknownArgumentType]
expressions = extended_expression(substrait_object)

if len(expressions) < 0 or len(expressions) > 1:
raise ValueError("arrow_to_vortex: extended expression must have exactly one child")
Expand Down
29 changes: 24 additions & 5 deletions vortex-python/python/vortex/substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,31 @@

import operator
from collections.abc import Callable
from typing import TYPE_CHECKING

from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from substrait.gen.proto.algebra_pb2 import Expression, FunctionArgument
from substrait.gen.proto.extended_expression_pb2 import ExpressionReference, ExtendedExpression
from substrait.gen.proto.extensions.extensions_pb2 import SimpleExtensionDeclaration, SimpleExtensionURI
from substrait.gen.proto.type_pb2 import NamedStruct

if TYPE_CHECKING:
from substrait.algebra_pb2 import Expression, FunctionArgument
from substrait.extended_expression_pb2 import ExpressionReference, ExtendedExpression
from substrait.extensions.extensions_pb2 import (
SimpleExtensionDeclaration,
SimpleExtensionURI, # pyright: ignore[reportDeprecated]
)
from substrait.type_pb2 import NamedStruct
else:
try:
# substrait >= 0.27
from substrait.algebra_pb2 import Expression, FunctionArgument
from substrait.extended_expression_pb2 import ExpressionReference, ExtendedExpression
from substrait.extensions.extensions_pb2 import SimpleExtensionDeclaration, SimpleExtensionURI
from substrait.type_pb2 import NamedStruct
except ImportError:
# substrait < 0.27
from substrait.gen.proto.algebra_pb2 import Expression, FunctionArgument
from substrait.gen.proto.extended_expression_pb2 import ExpressionReference, ExtendedExpression
from substrait.gen.proto.extensions.extensions_pb2 import SimpleExtensionDeclaration, SimpleExtensionURI
from substrait.gen.proto.type_pb2 import NamedStruct

from ._lib import dtype as _dtype # pyright: ignore[reportMissingModuleSource]
from ._lib import expr as _expr # pyright: ignore[reportMissingModuleSource]
Expand Down Expand Up @@ -150,7 +169,7 @@ def function_argument(

def extension_function(
substrait_object: SimpleExtensionDeclaration.ExtensionFunction,
extension_uris: RepeatedCompositeFieldContainer[SimpleExtensionURI],
extension_uris: RepeatedCompositeFieldContainer[SimpleExtensionURI], # pyright: ignore[reportDeprecated]
) -> Callable[..., _expr.Expr]:
# https://github.com/substrait-io/substrait/blob/main/proto/substrait/extensions/extensions.proto#L57
match extension_uris[substrait_object.extension_uri_reference].uri:
Expand Down
Loading