Skip to content

Commit 8fb35ff

Browse files
committed
refactor: fix remap variable error in InNode
1 parent 4c98c95 commit 8fb35ff

File tree

6 files changed

+52
-5
lines changed

6 files changed

+52
-5
lines changed

bigframes/core/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,7 @@ class ResultNode(UnaryNode):
16811681
# TODO: CTE definitions
16821682

16831683
def _validate(self):
1684-
for ref, name in self.output_cols:
1684+
for ref, _ in self.output_cols:
16851685
assert ref.id in self.child.ids
16861686

16871687
@property

bigframes/core/rewrite/identifiers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@ def remap_variables(
5757
new_root = root.transform_children(lambda node: remapped_children[node])
5858

5959
# Step 3: Transform the current node using the mappings from its children.
60+
# "reversed" is required for InNode so that in case of a duplicate column ID,
61+
# the left child's mapping is the one that's kept.
6062
downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
61-
k: v for mapping in new_child_mappings for k, v in mapping.items()
63+
k: v for mapping in reversed(new_child_mappings) for k, v in mapping.items()
6264
}
6365
if isinstance(new_root, nodes.InNode):
6466
new_root = typing.cast(nodes.InNode, new_root)

tests/system/small/engines/test_join.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_engines_join_on_coerced_key(
5555
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
5656

5757

58-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
58+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5959
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
6060
def test_engines_join_multi_key(
6161
scalars_array_value: array_value.ArrayValue,
@@ -90,7 +90,7 @@ def test_engines_cross_join(
9090
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
9191

9292

93-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
93+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
9494
@pytest.mark.parametrize(
9595
("left_key", "right_key"),
9696
[

tests/system/small/engines/test_slicing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
2525

2626

27-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
27+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
2828
@pytest.mark.parametrize(
2929
("start", "stop", "step"),
3030
[
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
`bfcol_2` AS `bfcol_3`,
10+
`bfcol_0` AS `bfcol_4`,
11+
`bfcol_1` AS `bfcol_5`
12+
FROM `bfcte_0`
13+
), `bfcte_2` AS (
14+
SELECT
15+
`bfcte_1`.*,
16+
EXISTS(
17+
SELECT
18+
1
19+
FROM (
20+
SELECT
21+
`float64_col` AS `bfcol_6`
22+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
23+
) AS `bft_0`
24+
WHERE
25+
COALESCE(`bfcte_1`.`bfcol_4`, 0) = COALESCE(`bft_0`.`bfcol_6`, 0)
26+
AND COALESCE(`bfcte_1`.`bfcol_4`, 1) = COALESCE(`bft_0`.`bfcol_6`, 1)
27+
) AS `bfcol_7`
28+
FROM `bfcte_1`
29+
)
30+
SELECT
31+
`bfcol_3` AS `bfuid_col_1`,
32+
`bfcol_4` AS `int64_col`,
33+
`bfcol_5` AS `float64_col`,
34+
`bfcol_7` AS `bfuid_col_2`
35+
FROM `bfcte_2`

tests/unit/core/compile/sqlglot/test_compile_isin.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,13 @@ def test_compile_isin_not_nullable(scalar_types_df: bpd.DataFrame, snapshot):
3737
scalar_types_df["rowindex_2"].isin(scalar_types_df["rowindex_2"]).to_frame()
3838
)
3939
snapshot.assert_match(bf_isin.sql, "out.sql")
40+
41+
42+
def test_compile_isin_for_array_value(scalar_types_df: bpd.DataFrame, snapshot):
43+
scalars_array_value = scalar_types_df[["int64_col", "float64_col"]]._block.expr
44+
result, _ = scalars_array_value.isin(
45+
scalars_array_value, lcol="int64_col", rcol="float64_col"
46+
)
47+
sql = result.session._executor.to_sql(result, enable_cache=False)
48+
49+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)