Skip to content

Commit add4b23

Browse files
committed
fix a compatible issue where the sg.union cannot support multiple expressions in the older version of sqlglot
1 parent 67f1833 commit add4b23

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def from_union(
175175
), f"At least two select expressions must be provided, but got {selects}."
176176

177177
existing_ctes: list[sge.CTE] = []
178-
union_selects: list[sge.Select] = []
178+
union_selects: list[sge.Expression] = []
179179
for select in selects:
180180
assert isinstance(
181181
select, sge.Select
@@ -204,10 +204,14 @@ def from_union(
204204
sge.Select().select(*selections).from_(sge.Table(this=new_cte_name))
205205
)
206206

207-
union_expr = sg.union(
208-
*union_selects,
209-
distinct=False,
210-
copy=False,
207+
union_expr = typing.cast(
208+
sge.Select,
209+
functools.reduce(
210+
lambda x, y: sge.Union(
211+
this=x, expression=y, distinct=False, copy=False
212+
),
213+
union_selects,
214+
),
211215
)
212216
final_select_expr = sge.Select().select(sge.Star()).from_(union_expr.subquery())
213217
final_select_expr.set("with", sge.With(expressions=existing_ctes))

tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat_filter_sorted/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ WITH `bfcte_3` AS (
134134
)
135135
)
136136
SELECT
137-
`bfcol_42` AS `bfuid_col_12`,
138-
`bfcol_43` AS `bfuid_col_13`
137+
`bfcol_42` AS `float64_col`,
138+
`bfcol_43` AS `int64_col`
139139
FROM `bfcte_18`
140140
ORDER BY
141141
`bfcol_44` ASC NULLS LAST,

tests/unit/core/compile/sqlglot/test_compile_concat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_compile_concat(scalar_types_df: bpd.DataFrame, snapshot):
2828

2929

3030
def test_compile_concat_filter_sorted(scalar_types_df: bpd.DataFrame, snapshot):
31+
3132
scalars_array_value = scalar_types_df._block.expr
3233
input_1 = scalars_array_value.select_columns(["float64_col", "int64_col"]).order_by(
3334
[ordering.ascending_over("int64_col")]
@@ -37,5 +38,12 @@ def test_compile_concat_filter_sorted(scalar_types_df: bpd.DataFrame, snapshot):
3738
)
3839

3940
result = input_1.concat([input_2, input_1, input_2])
41+
42+
new_names = ["float64_col", "int64_col"]
43+
col_ids = {
44+
old_name: new_name for old_name, new_name in zip(result.column_ids, new_names)
45+
}
46+
result = result.rename_columns(col_ids).select_columns(new_names)
47+
4048
sql = result.session._executor.to_sql(result, enable_cache=False)
4149
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)