Skip to content

Commit 6f83b54

Browse files
committed
implement eq_null_match_op
1 parent 26f8a77 commit 6f83b54

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,26 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
8484
return sge.EQ(this=left_expr, expression=right_expr)
8585

8686

87+
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
88+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
89+
left_expr = left.expr
90+
if left.dtype == dtypes.BOOL_DTYPE and right.dtype != dtypes.BOOL_DTYPE:
91+
left_expr = sge.Cast(this=left_expr, to="INT64")
92+
93+
right_expr = right.expr
94+
if right.dtype == dtypes.BOOL_DTYPE and left.dtype != dtypes.BOOL_DTYPE:
95+
right_expr = sge.Cast(this=right_expr, to="INT64")
96+
97+
sentinel = sge.convert("$NULL_SENTINEL$")
98+
left_coalesce = sge.Coalesce(
99+
this=sge.Cast(this=left_expr, to="STRING"), expressions=[sentinel]
100+
)
101+
right_coalesce = sge.Coalesce(
102+
this=sge.Cast(this=right_expr, to="STRING"), expressions=[sentinel]
103+
)
104+
return sge.EQ(this=left_coalesce, expression=right_coalesce)
105+
106+
87107
@BINARY_OP_REGISTRATION.register(ops.div_op)
88108
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
89109
left_expr = left.expr
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
COALESCE(CAST(`bfcol_1` AS STRING), '$NULL_SENTINEL$') = COALESCE(CAST(CAST(`bfcol_0` AS INT64) AS STRING), '$NULL_SENTINEL$') AS `bfcol_4`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_4` AS `int64_col`
14+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
102102
snapshot.assert_match(bf_df.sql, "out.sql")
103103

104104

105+
def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot):
106+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
107+
sql = _apply_binary_op(bf_df, ops.eq_null_match_op, "int64_col", "bool_col")
108+
snapshot.assert_match(sql, "out.sql")
109+
110+
105111
def test_json_set(json_types_df: bpd.DataFrame, snapshot):
106112
bf_df = json_types_df[["json_col"]]
107113
sql = _apply_binary_op(

0 commit comments

Comments
 (0)