Skip to content

Commit 7d3532f

Browse files
committed
chore: add compile_join for new compiler
1 parent 0709f17 commit 7d3532f

File tree

10 files changed

+312
-9
lines changed

10 files changed

+312
-9
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,21 @@ def compile_filter(
218218
condition = scalar_compiler.compile_scalar_expression(node.predicate)
219219
return child.filter(condition)
220220

221+
@_compile_node.register
222+
def compile_join(
223+
self, node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
224+
) -> ir.SQLGlotIR:
225+
conditions = tuple(
226+
(left.id.sql, right.id.sql) for left, right in node.conditions
227+
)
228+
229+
return left.join(
230+
right,
231+
join_type=node.type,
232+
conditions=conditions,
233+
join_nulls=node.joins_nulls,
234+
)
235+
221236
@_compile_node.register
222237
def compile_concat(
223238
self, node: nodes.ConcatNode, *children: ir.SQLGlotIR

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ def select(
212212
for id, expr in selected_cols
213213
]
214214

215-
new_expr = self._encapsulate_as_cte().select(*selections, append=False)
215+
new_expr, _ = self._encapsulate_as_cte()
216+
new_expr = new_expr.select(*selections, append=False)
216217
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
217218

218219
def order_by(
@@ -247,19 +248,68 @@ def project(
247248
)
248249
for id, expr in projected_cols
249250
]
250-
new_expr = self._encapsulate_as_cte().select(*projected_cols_expr, append=True)
251+
new_expr, _ = self._encapsulate_as_cte()
252+
new_expr = new_expr.select(*projected_cols_expr, append=True)
251253
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
252254

253255
def filter(
254256
self,
255257
condition: sge.Expression,
256258
) -> SQLGlotIR:
257259
"""Filters the query with the given condition."""
258-
new_expr = self._encapsulate_as_cte()
260+
new_expr, _ = self._encapsulate_as_cte()
259261
return SQLGlotIR(
260262
expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen
261263
)
262264

265+
def join(
266+
self,
267+
right: SQLGlotIR,
268+
join_type: typing.Literal["inner", "outer", "left", "right", "cross"],
269+
conditions: tuple[tuple[str, str], ...],
270+
*,
271+
join_nulls: bool = True,
272+
) -> SQLGlotIR:
273+
"""Joins the current query with another SQLGlotIR instance."""
274+
# TODO: add join_nulls support
275+
left_select, left_table = self._encapsulate_as_cte()
276+
right_select, right_table = right._encapsulate_as_cte()
277+
278+
left_ctes = left_select.args.pop("with", [])
279+
right_ctes = right_select.args.pop("with", [])
280+
merged_ctes = [*left_ctes, *right_ctes]
281+
282+
join_on = (
283+
sge.And(
284+
expressions=[
285+
sge.EQ(
286+
this=sge.Column(
287+
this=sge.to_identifier(left_id, quoted=self.quoted),
288+
table=left_table,
289+
),
290+
expression=sge.Column(
291+
this=sge.to_identifier(right_id, quoted=self.quoted),
292+
table=right_table,
293+
),
294+
)
295+
for left_id, right_id in conditions
296+
]
297+
)
298+
if len(list(conditions)) > 0
299+
else None
300+
)
301+
302+
join_type_str = join_type if join_type != "outer" else "full outer"
303+
new_expr = (
304+
sge.Select()
305+
.select(sge.Star())
306+
.from_(left_table)
307+
.join(right_table, on=join_on, join_type=join_type_str)
308+
)
309+
new_expr.set("with", sge.With(expressions=merged_ctes))
310+
311+
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
312+
263313
def insert(
264314
self,
265315
destination: bigquery.TableReference,
@@ -292,7 +342,7 @@ def replace(
292342

293343
def _encapsulate_as_cte(
294344
self,
295-
) -> sge.Select:
345+
) -> typing.Tuple[sge.Select, sge.Table]:
296346
"""Transforms a given sge.Select query by pushing its main SELECT statement
297347
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
298348
for the new query."""
@@ -307,11 +357,10 @@ def _encapsulate_as_cte(
307357
alias=new_cte_name,
308358
)
309359
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
310-
new_select_expr = (
311-
sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
312-
)
360+
new_table_expr = sge.Table(this=new_cte_name)
361+
new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr)
313362
new_select_expr.set("with", new_with_clause)
314-
return new_select_expr
363+
return new_select_expr, new_table_expr
315364

316365

317366
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_col` AS `bfcol_4`,
14+
`int64_too` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_4` AS `bfcol_6`,
19+
`bfcol_5` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
INNER JOIN `bfcte_3`
26+
ON `bfcte_2`.`bfcol_2` = `bfcte_3`.`bfcol_6`
27+
)
28+
SELECT
29+
`bfcol_3` AS `int64_col`,
30+
`bfcol_7` AS `int64_too`
31+
FROM `bfcte_4`
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_col` AS `bfcol_4`,
14+
`int64_too` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_4` AS `bfcol_6`,
19+
`bfcol_5` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
LEFT JOIN `bfcte_3`
26+
ON `bfcte_2`.`bfcol_2` = `bfcte_3`.`bfcol_6`
27+
)
28+
SELECT
29+
`bfcol_3` AS `int64_col`,
30+
`bfcol_7` AS `int64_too`
31+
FROM `bfcte_4`
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_col` AS `bfcol_4`,
14+
`int64_too` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_4` AS `bfcol_6`,
19+
`bfcol_5` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
FULL OUTER JOIN `bfcte_3`
26+
ON `bfcte_2`.`bfcol_2` = `bfcte_3`.`bfcol_6`
27+
)
28+
SELECT
29+
`bfcol_3` AS `int64_col`,
30+
`bfcol_7` AS `int64_too`
31+
FROM `bfcte_4`
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_col` AS `bfcol_4`,
14+
`int64_too` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_4` AS `bfcol_6`,
19+
`bfcol_5` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
RIGHT JOIN `bfcte_3`
26+
ON `bfcte_2`.`bfcol_2` = `bfcte_3`.`bfcol_6`
27+
)
28+
SELECT
29+
`bfcol_3` AS `int64_col`,
30+
`bfcol_7` AS `int64_too`
31+
FROM `bfcte_4`
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
*
4+
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` STRING, `bfcol_1` INT64>>[STRUCT('foo', 1), STRUCT('bar', 2), STRUCT('baz', 3), STRUCT('foo', 5)])
5+
), `bfcte_0` AS (
6+
SELECT
7+
*
8+
FROM UNNEST(ARRAY<STRUCT<`bfcol_2` STRING, `bfcol_3` INT64>>[STRUCT('foo', 5), STRUCT('bar', 6), STRUCT('baz', 7), STRUCT('foo', 8)])
9+
), `bfcte_2` AS (
10+
SELECT
11+
`bfcol_2` AS `bfcol_4`,
12+
`bfcol_3` AS `bfcol_5`
13+
FROM `bfcte_0`
14+
), `bfcte_3` AS (
15+
SELECT
16+
*
17+
FROM `bfcte_1`
18+
CROSS JOIN `bfcte_2`
19+
)
20+
SELECT
21+
`bfcol_0` AS `lkey`,
22+
`bfcol_1` AS `value_x`,
23+
`bfcol_4` AS `rkey`,
24+
`bfcol_5` AS `value_y`
25+
FROM `bfcte_3`
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`int64_too` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_2` AS (
8+
SELECT
9+
`bfcol_1` AS `bfcol_3`,
10+
`bfcol_2` AS `bfcol_4`,
11+
`bfcol_0` AS `bfcol_5`
12+
FROM `bfcte_1`
13+
), `bfcte_0` AS (
14+
SELECT
15+
`int64_col` AS `bfcol_6`,
16+
`int64_too` AS `bfcol_7`
17+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
18+
), `bfcte_3` AS (
19+
SELECT
20+
`bfcol_7` AS `bfcol_8`,
21+
`bfcol_6` AS `bfcol_9`
22+
FROM `bfcte_0`
23+
), `bfcte_4` AS (
24+
SELECT
25+
*
26+
FROM `bfcte_2`
27+
LEFT JOIN `bfcte_3`
28+
ON `bfcte_2`.`bfcol_3` = `bfcte_3`.`bfcol_8`
29+
)
30+
SELECT
31+
`bfcol_4` AS `rowindex`,
32+
`bfcol_5` AS `int64_col`,
33+
`bfcol_3` AS `int64_too`,
34+
`bfcol_9` AS `col1`
35+
FROM `bfcte_4`
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import bigframes.pandas as bpd
18+
19+
pytest.importorskip("pytest_snapshot")
20+
21+
22+
@pytest.mark.parametrize(
23+
("how"),
24+
["left", "right", "outer", "inner"],
25+
)
26+
def test_compile_join(scalars_types_df: bpd.DataFrame, how, snapshot):
27+
left = scalars_types_df[["int64_col"]]
28+
right = scalars_types_df.set_index("int64_col")[["int64_too"]]
29+
join = left.join(right, how=how)
30+
snapshot.assert_match(join.sql, "out.sql")
31+
32+
33+
def test_compile_join_w_on(scalars_types_df: bpd.DataFrame, snapshot):
34+
selected_cols = ["int64_col", "int64_too"]
35+
left = scalars_types_df[selected_cols]
36+
right = (
37+
scalars_types_df[selected_cols]
38+
.rename(columns={"int64_col": "col1", "int64_too": "col2"})
39+
.set_index("col2")
40+
)
41+
join = left.join(right, on="int64_too")
42+
snapshot.assert_match(join.sql, "out.sql")
43+
44+
45+
def test_compile_join_by_cross(compiler_session, snapshot):
46+
df1 = bpd.DataFrame(
47+
{"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]},
48+
session=compiler_session,
49+
)
50+
df2 = bpd.DataFrame(
51+
{"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]},
52+
session=compiler_session,
53+
)
54+
merge = df1.merge(df2, left_on="lkey", right_on="rkey", how="cross")
55+
snapshot.assert_match(merge.sql, "out.sql")

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4747,7 +4747,7 @@ def merge(
47474747
right:
47484748
Object to merge with.
47494749
how:
4750-
``{'left', 'right', 'outer', 'inner'}, default 'inner'``
4750+
``{'left', 'right', 'outer', 'inner', 'cross'}, default 'inner'``
47514751
Type of merge to be performed.
47524752
``left``: use only keys from left frame, similar to a SQL left outer join;
47534753
preserve key order.

0 commit comments

Comments
 (0)