Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ write_to = "src/substrait/_version.py"
[project.optional-dependencies]
extensions = ["antlr4-python3-runtime", "pyyaml"]
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"]
sql = ["sqloxide", "deepdiff"]
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml", "sqloxide", "deepdiff", "duckdb<=1.2.2", "datafusion"]

[tool.pytest.ini_options]
pythonpath = "src"
Expand Down
10 changes: 6 additions & 4 deletions src/substrait/builders/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

def _merge_extensions(*objs):
return {
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs]),
"extensions": merge_extension_declarations(*[b.extensions for b in objs]),
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs if b]),
"extensions": merge_extension_declarations(*[b.extensions for b in objs if b]),
}


Expand Down Expand Up @@ -193,13 +193,15 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
ns = infer_plan_schema(bound_plan)

bound_offset = resolve_expression(offset, ns, registry)
bound_offset = resolve_expression(offset, ns, registry) if offset else None
bound_count = resolve_expression(count, ns, registry)

rel = stalg.Rel(
fetch=stalg.FetchRel(
input=bound_plan.relations[-1].root.input,
offset_expr=bound_offset.referred_expr[0].expression,
offset_expr=bound_offset.referred_expr[0].expression
if bound_offset
else None,
count_expr=bound_count.referred_expr[0].expression,
)
)
Expand Down
339 changes: 339 additions & 0 deletions src/substrait/sql/sql_to_substrait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
import random
import string
from sqloxide import parse_sql
from substrait.builders.extended_expression import (
UnboundExtendedExpression,
column,
scalar_function,
literal,
aggregate_function,
window_function,
)
from substrait.builders.plan import (
read_named_table,
project,
filter,
sort,
fetch,
set,
join,
aggregate,
)
from substrait.gen.proto import type_pb2 as stt
from substrait.gen.proto import algebra_pb2 as stalg
from substrait.extension_registry import ExtensionRegistry
from typing import Callable
from deepdiff import DeepDiff

# Callback used to look up a table's schema: it receives the table name and
# returns the stt.NamedStruct that is handed to read_named_table.
SchemaResolver = Callable[[str], stt.NamedStruct]

# Lookup tables from SQL AST names to Substrait functions. Every value is an
# (extension YAML file, function name) pair that is passed straight to the
# corresponding expression builder.

# Scalar functions, keyed by the "op" value of a BinaryOp AST node.
function_mapping = {
    "Plus": ("functions_arithmetic.yaml", "add"),
    "Minus": ("functions_arithmetic.yaml", "subtract"),
    "Multiply": ("functions_arithmetic.yaml", "multiply"),
    "Divide": ("functions_arithmetic.yaml", "divide"),
    "Gt": ("functions_comparison.yaml", "gt"),
    "GtEq": ("functions_comparison.yaml", "gte"),
    "Lt": ("functions_comparison.yaml", "lt"),
    "LtEq": ("functions_comparison.yaml", "lte"),
    "Eq": ("functions_comparison.yaml", "equal"),
    "NotEq": ("functions_comparison.yaml", "not_equal"),
}

# Aggregate functions, keyed by the SQL function name.
aggregate_function_mapping = {"SUM": ("functions_arithmetic.yaml", "sum")}

# Window functions, keyed by the SQL function name.
window_function_mapping = {
    "row_number": ("functions_arithmetic.yaml", "row_number"),
}


def compare_dicts(dict1, dict2):
    """Return True when DeepDiff reports no difference between the inputs.

    Paths matching the regex ``span`` are excluded from the comparison, so
    two values that differ only there are considered equal.
    """
    return len(DeepDiff(dict1, dict2, exclude_regex_paths=["span"])) == 0


def translate_expression(
    ast: dict,
    schema_resolver: SchemaResolver,
    registry: ExtensionRegistry,
    measures: list[tuple],
    groupings: list[dict],
    alias: str = None,
) -> UnboundExtendedExpression:
    """Translate one SQL expression AST node into an unbound expression.

    Args:
        ast: Single-key dict whose key names the node kind (``Identifier``,
            ``BinaryOp``, ``Function``, ...) and whose value is the payload.
        schema_resolver: Passed through unchanged to recursive calls.
        registry: Passed through unchanged to recursive calls.
        measures: Accumulator for aggregate calls found while translating.
            Each entry is an ``(aggregate expression, function ast, name)``
            tuple; new aggregates are appended in place. Pass ``None`` where
            aggregates are not allowed.
        groupings: AST nodes of the GROUP BY expressions. When non-empty, a
            node equal to the i-th grouping is translated to ``column(i)``
            instead of being recomputed.
        alias: Optional alias for the resulting expression.

    Returns:
        The unbound expression produced by the matching builder.

    Raises:
        Exception: For an unknown node kind, an unknown function name, or an
            aggregate function used where ``measures`` is ``None``.
        KeyError: For a binary operator missing from ``function_mapping``.
    """
    assert len(ast) == 1
    op = list(ast.keys())[0]

    if groupings:
        # We are translating a projection that follows a grouping: reuse the
        # grouping column when this node is identical to a GROUP BY expression.
        for i, f in enumerate(groupings):
            if compare_dicts(ast, f):
                return column(i, alias=alias)

    ast = ast[op]

    def translate_child(node, child_alias=None):
        # Recurse with the same resolver, registry and accumulators.
        return translate_expression(
            node,
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=groupings,
            alias=child_alias,
        )

    if op == "Identifier":
        return column(ast["value"], alias=alias)
    elif op in ("UnnamedExpr", "expr", "Unnamed", "Expr"):
        # Pure wrapper nodes: unwrap, keeping the alias requested by the caller.
        return translate_child(ast, alias)
    elif op == "ExprWithAlias":
        return translate_child(ast["expr"], ast["alias"]["value"])
    elif op == "BinaryOp":
        expressions = [
            translate_child(ast["left"]),
            translate_child(ast["right"]),
        ]
        func = function_mapping[ast["op"]]
        return scalar_function(func[0], func[1], expressions=expressions, alias=alias)
    elif op == "Value":
        return literal(
            int(ast["value"]["Number"][0]), stt.Type(i64=stt.Type.I64()), alias=alias
        )  # TODO infer type
    elif op == "Function":
        # Arguments are translated first so that nested aggregates are
        # registered in `measures` before the enclosing call is handled.
        expressions = [translate_child(e) for e in ast["args"]["List"]["args"]]
        name = ast["name"][0]["Identifier"]["value"]

        if name in function_mapping:
            func = function_mapping[name]
            # Pass the argument list the same way as every other builder call.
            return scalar_function(
                func[0], func[1], expressions=expressions, alias=alias
            )

        # SQL function names are case-insensitive; the tables use upper-case
        # keys for aggregates and lower-case keys for window functions.
        aggregate_key = name.upper()
        window_key = name.lower()

        if aggregate_key in aggregate_function_mapping:
            if measures is None:
                raise Exception(
                    f"Aggregate function {name} is not allowed in this context"
                )

            # Substrait computes measures in a separate rel, so every aggregate
            # is extracted and replaced by a reference to a generated name.
            # Identical aggregates reuse the previously generated name.
            for m in measures:
                if compare_dicts(ast, m[1]):
                    return column(m[2], alias=alias)

            func = aggregate_function_mapping[aggregate_key]
            random_name = "".join(
                random.choices(string.ascii_uppercase + string.digits, k=5)
            )  # TODO make this deterministic
            aggr = aggregate_function(func[0], func[1], expressions, alias=random_name)
            measures.append((aggr, ast, random_name))
            return column(random_name, alias=alias)
        elif window_key in window_function_mapping:
            func = window_function_mapping[window_key]
            partitions = [
                translate_child(e)
                for e in ast["over"]["WindowSpec"]["partition_by"]
            ]
            return window_function(
                func[0], func[1], expressions, partitions=partitions, alias=alias
            )
        else:
            raise Exception(f"Unknown function {name}")
    else:
        # TODO: Wildcard projections are not supported yet.
        raise Exception(f"Unknown op {op}")


def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionRegistry):
    """Translate one SQL relational AST node into a Substrait plan.

    Args:
        ast: Single-key dict whose key names the node kind. Supported kinds
            are ``Query``, ``Select``, ``Table`` and ``SetOperation``.
        schema_resolver: Called with a table name to obtain its schema.
        registry: Extension registry used to resolve every builder result.

    Returns:
        The plan for the node. The ``Table`` branch returns the result of
        ``read_named_table`` as-is; every other branch returns a builder
        result already called with ``registry``.

    Raises:
        Exception: For an unknown node kind, an unsupported join type, an
            unsupported set operation, or a limit clause without a LIMIT value.
    """
    assert len(ast) == 1
    op = list(ast.keys())[0]
    ast = ast[op]

    def translate_scalar(node, measures=None, groupings=None):
        # Translate an expression with this call's resolver and registry.
        return translate_expression(
            node,
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=groupings,
        )

    if op == "Query":
        relation = translate(
            ast["body"], schema_resolver=schema_resolver, registry=registry
        )

        if ast["order_by"]:
            # NOTE(review): only e["expr"] is read here, so any ordering
            # options carried by the entry (e.g. ASC/DESC) are not forwarded
            # to sort() - confirm whether that is intended.
            expressions = [
                translate_scalar(e["expr"])
                for e in ast["order_by"]["kind"]["Expressions"]
            ]
            relation = sort(relation, expressions)(registry)

        if ast["limit_clause"]:
            limit_offset = ast["limit_clause"]["LimitOffset"]

            if not limit_offset["limit"]:
                # Without this guard the missing limit node would be passed to
                # translate_expression and fail on len(None).
                raise Exception("OFFSET without LIMIT is not supported")

            limit_expression = translate_scalar(limit_offset["limit"])

            if limit_offset["offset"]:
                offset_expression = translate_scalar(limit_offset["offset"]["value"])
            else:
                offset_expression = None

            relation = fetch(relation, offset_expression, limit_expression)(registry)

        return relation
    elif op == "Select":
        # Only the first FROM item (and its joins) is translated.
        from_item = ast["from"][0]
        relation = translate(
            from_item["relation"],
            schema_resolver=schema_resolver,
            registry=registry,
        )

        join_type_mapping = {
            "Inner": stalg.JoinRel.JOIN_TYPE_INNER,
            "Left": stalg.JoinRel.JOIN_TYPE_LEFT,
            "LeftOuter": stalg.JoinRel.JOIN_TYPE_LEFT,
            "RightOuter": stalg.JoinRel.JOIN_TYPE_RIGHT,
            "Right": stalg.JoinRel.JOIN_TYPE_RIGHT,
        }

        for _join in from_item["joins"] or []:
            right = translate(
                _join["relation"],
                schema_resolver=schema_resolver,
                registry=registry,
            )

            join_type = list(_join["join_operator"].keys())[0]
            if join_type not in join_type_mapping:
                raise Exception(f"Unsupported join type {join_type}")

            expression = translate_scalar(_join["join_operator"][join_type]["On"])

            relation = join(
                relation, right, expression, join_type_mapping[join_type]
            )(registry)

        if ast.get("selection"):
            where_expression = translate_scalar(ast["selection"])
            relation = filter(relation, where_expression)(registry)

        if ast["group_by"] and ast["group_by"]["Expressions"][0]:
            groupings = ast["group_by"]["Expressions"][0]
            grouping_expressions = [translate_scalar(e) for e in groupings]
        else:
            groupings = []
            grouping_expressions = []

        # Filled in place by translate_expression with every aggregate found
        # in the projection and the HAVING predicate.
        measures = []

        projection = [
            translate_scalar(p, measures=measures, groupings=groupings)
            for p in ast["projection"]
        ]

        if ast["having"]:
            having_predicate = translate_scalar(
                ast["having"], measures=measures, groupings=[]
            )
        else:
            having_predicate = None

        if measures or groupings:
            relation = aggregate(
                relation, grouping_expressions, [e[0] for e in measures]
            )(registry)

        if having_predicate:
            relation = filter(relation, having_predicate)(registry)

        return project(relation, expressions=projection)(registry)
    elif op == "Table":
        name = ast["name"][0]["Identifier"]["value"]
        return read_named_table(name, schema_resolver(name))
    elif op == "SetOperation":
        # TODO more than 2 inputs to a set operation
        left = translate(
            ast["left"], schema_resolver=schema_resolver, registry=registry
        )
        right = translate(
            ast["right"], schema_resolver=schema_resolver, registry=registry
        )
        if ast["op"] == "Union":
            set_op = (
                stalg.SetRel.SET_OP_UNION_ALL
                if ast["set_quantifier"] == "All"
                else stalg.SetRel.SET_OP_UNION_DISTINCT
            )
        else:
            raise Exception(f"Unsupported set operation {ast['op']}")

        return set([left, right], set_op)(registry)
    else:
        raise Exception(f"Unknown op {op}")


def convert(query: str, dialect: str, schema_resolver: SchemaResolver):
    """Parse a SQL string and translate its first statement.

    Args:
        query: SQL text handed to ``parse_sql``.
        dialect: Dialect name handed to ``parse_sql``.
        schema_resolver: Called with a table name to obtain its schema.

    Returns:
        Whatever ``translate`` returns for the first parsed statement, using
        an ``ExtensionRegistry`` created with the default extensions loaded.
    """
    statements = parse_sql(sql=query, dialect=dialect)
    first_statement = statements[0]
    default_registry = ExtensionRegistry(load_default_extensions=True)
    return translate(
        first_statement,
        schema_resolver=schema_resolver,
        registry=default_registry,
    )
Loading