Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def from_query_string(
alias=cte_name,
)
select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
select_expr.set("with", sge.With(expressions=[cte]))
select_expr = _set_query_ctes(select_expr, [cte])
return cls(expr=select_expr, uid_gen=uid_gen)

@classmethod
Expand All @@ -197,7 +197,8 @@ def from_union(
), f"All provided expressions must be of type sge.Select, but got {type(select)}"

select_expr = select.copy()
existing_ctes = [*existing_ctes, *select_expr.args.pop("with", [])]
select_expr, select_ctes = _pop_query_ctes(select_expr)
existing_ctes = [*existing_ctes, *select_ctes]

new_cte_name = sge.to_identifier(
next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
Expand Down Expand Up @@ -229,7 +230,7 @@ def from_union(
),
)
final_select_expr = sge.Select().select(sge.Star()).from_(union_expr.subquery())
final_select_expr.set("with", sge.With(expressions=existing_ctes))
final_select_expr = _set_query_ctes(final_select_expr, existing_ctes)
return cls(expr=final_select_expr, uid_gen=uid_gen)

def select(
Expand Down Expand Up @@ -336,8 +337,8 @@ def join(
left_select = _select_to_cte(self.expr, left_cte_name)
right_select = _select_to_cte(right.expr, right_cte_name)

left_ctes = left_select.args.pop("with", [])
right_ctes = right_select.args.pop("with", [])
left_select, left_ctes = _pop_query_ctes(left_select)
right_select, right_ctes = _pop_query_ctes(right_select)
merged_ctes = [*left_ctes, *right_ctes]

join_on = _and(
Expand All @@ -353,7 +354,7 @@ def join(
.from_(sge.Table(this=left_cte_name))
.join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
)
new_expr.set("with", sge.With(expressions=merged_ctes))
new_expr = _set_query_ctes(new_expr, merged_ctes)

return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

Expand All @@ -373,8 +374,8 @@ def isin_join(
# Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
right_select = right.expr

left_ctes = left_select.args.pop("with", [])
right_ctes = right_select.args.pop("with", [])
left_select, left_ctes = _pop_query_ctes(left_select)
right_select, right_ctes = _pop_query_ctes(right_select)
merged_ctes = [*left_ctes, *right_ctes]

left_condition = typed_expr.TypedExpr(
Expand Down Expand Up @@ -415,7 +416,7 @@ def isin_join(
.select(sge.Column(this=sge.Star(), table=left_cte_name), new_column)
.from_(sge.Table(this=left_cte_name))
)
new_expr.set("with", sge.With(expressions=merged_ctes))
new_expr = _set_query_ctes(new_expr, merged_ctes)

return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

Expand Down Expand Up @@ -625,14 +626,13 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
for the new query."""
select_expr = expr.copy()
existing_ctes = select_expr.args.pop("with", [])
select_expr, existing_ctes = _pop_query_ctes(select_expr)
new_cte = sge.CTE(
this=select_expr,
alias=cte_name,
)
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
new_select_expr.set("with", new_with_clause)
new_select_expr = _set_query_ctes(new_select_expr, [*existing_ctes, new_cte])
return new_select_expr


Expand Down Expand Up @@ -788,3 +788,34 @@ def _join_condition_for_numeric(
this=sge.EQ(this=left_2, expression=right_2),
expression=sge.EQ(this=left_3, expression=right_3),
)


def _set_query_ctes(
expr: sge.Select,
ctes: list[sge.CTE],
) -> sge.Select:
"""Sets the CTEs of a given sge.Select expression."""
new_expr = expr.copy()
with_expr = sge.With(expressions=ctes) if len(ctes) > 0 else None

if "with" in new_expr.arg_types.keys():
new_expr.set("with", with_expr)
elif "with_" in new_expr.arg_types.keys():
new_expr.set("with_", with_expr)
else:
raise ValueError("The expression does not support CTEs.")
return new_expr


def _pop_query_ctes(
expr: sge.Select,
) -> tuple[sge.Select, list[sge.CTE]]:
"""Pops the CTEs of a given sge.Select expression."""
if "with" in expr.arg_types.keys():
expr_ctes = expr.args.pop("with", [])
return expr, expr_ctes
elif "with_" in expr.arg_types.keys():
expr_ctes = expr.args.pop("with_", [])
return expr, expr_ctes
else:
raise ValueError("The expression does not support CTEs.")
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"requests >=2.27.1",
"shapely >=1.8.5",
# 25.20.0 introduces this fix https://github.com/TobikoData/sqlmesh/issues/3095 for rtrim/ltrim.
"sqlglot >=25.20.0, <28.0.0",
"sqlglot >=25.20.0",
"tabulate >=0.9",
"ipywidgets >=7.7.1",
"humanize >=4.6.0",
Expand Down