diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index b28c5ede91..3473968450 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -174,7 +174,7 @@ def from_query_string( alias=cte_name, ) select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name)) - select_expr.set("with", sge.With(expressions=[cte])) + select_expr = _set_query_ctes(select_expr, [cte]) return cls(expr=select_expr, uid_gen=uid_gen) @classmethod @@ -197,7 +197,8 @@ def from_union( ), f"All provided expressions must be of type sge.Select, but got {type(select)}" select_expr = select.copy() - existing_ctes = [*existing_ctes, *select_expr.args.pop("with", [])] + select_expr, select_ctes = _pop_query_ctes(select_expr) + existing_ctes = [*existing_ctes, *select_ctes] new_cte_name = sge.to_identifier( next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted @@ -229,7 +230,7 @@ def from_union( ), ) final_select_expr = sge.Select().select(sge.Star()).from_(union_expr.subquery()) - final_select_expr.set("with", sge.With(expressions=existing_ctes)) + final_select_expr = _set_query_ctes(final_select_expr, existing_ctes) return cls(expr=final_select_expr, uid_gen=uid_gen) def select( @@ -336,8 +337,8 @@ def join( left_select = _select_to_cte(self.expr, left_cte_name) right_select = _select_to_cte(right.expr, right_cte_name) - left_ctes = left_select.args.pop("with", []) - right_ctes = right_select.args.pop("with", []) + left_select, left_ctes = _pop_query_ctes(left_select) + right_select, right_ctes = _pop_query_ctes(right_select) merged_ctes = [*left_ctes, *right_ctes] join_on = _and( @@ -353,7 +354,7 @@ def join( .from_(sge.Table(this=left_cte_name)) .join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str) ) - new_expr.set("with", sge.With(expressions=merged_ctes)) + new_expr = _set_query_ctes(new_expr, merged_ctes) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) @@ -373,8 +374,8 @@ def isin_join( # Prefer subquery over CTE for the IN clause's right side to improve SQL readability. right_select = right.expr - left_ctes = left_select.args.pop("with", []) - right_ctes = right_select.args.pop("with", []) + left_select, left_ctes = _pop_query_ctes(left_select) + right_select, right_ctes = _pop_query_ctes(right_select) merged_ctes = [*left_ctes, *right_ctes] left_condition = typed_expr.TypedExpr( @@ -415,7 +416,7 @@ def isin_join( .select(sge.Column(this=sge.Star(), table=left_cte_name), new_column) .from_(sge.Table(this=left_cte_name)) ) - new_expr.set("with", sge.With(expressions=merged_ctes)) + new_expr = _set_query_ctes(new_expr, merged_ctes) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) @@ -625,14 +626,13 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: into a new CTE and then generates a 'SELECT * FROM new_cte_name' for the new query.""" select_expr = expr.copy() - existing_ctes = select_expr.args.pop("with", []) + select_expr, existing_ctes = _pop_query_ctes(select_expr) new_cte = sge.CTE( this=select_expr, alias=cte_name, ) - new_with_clause = sge.With(expressions=[*existing_ctes, new_cte]) new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name)) - new_select_expr.set("with", new_with_clause) + new_select_expr = _set_query_ctes(new_select_expr, [*existing_ctes, new_cte]) return new_select_expr @@ -788,3 +788,34 @@ def _join_condition_for_numeric( this=sge.EQ(this=left_2, expression=right_2), expression=sge.EQ(this=left_3, expression=right_3), ) + + +def _set_query_ctes( + expr: sge.Select, + ctes: list[sge.CTE], +) -> sge.Select: + """Sets the CTEs of a given sge.Select expression.""" + new_expr = expr.copy() + with_expr = sge.With(expressions=ctes) if len(ctes) > 0 else None + + if "with" in new_expr.arg_types.keys(): + new_expr.set("with", with_expr) + elif "with_" in new_expr.arg_types.keys(): + new_expr.set("with_", with_expr) + else: + raise ValueError("The expression does not support CTEs.") + return new_expr + + +def _pop_query_ctes( + expr: sge.Select, +) -> tuple[sge.Select, list[sge.CTE]]: + """Pops the CTEs of a given sge.Select expression.""" + if "with" in expr.arg_types.keys(): + expr_ctes = expr.args.pop("with", []) + return expr, expr_ctes + elif "with_" in expr.arg_types.keys(): + expr_ctes = expr.args.pop("with_", []) + return expr, expr_ctes + else: + raise ValueError("The expression does not support CTEs.") diff --git a/setup.py b/setup.py index 28ae99a50d..fa663f66d5 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ "requests >=2.27.1", "shapely >=1.8.5", # 25.20.0 introduces this fix https://github.com/TobikoData/sqlmesh/issues/3095 for rtrim/ltrim. - "sqlglot >=25.20.0, <28.0.0", + "sqlglot >=25.20.0", "tabulate >=0.9", "ipywidgets >=7.7.1", "humanize >=4.6.0",