Skip to content

Commit e729a25

Browse files
committed
refactor: support ops.case_when_op for the sqlglot compiler
1 parent f9e28fe commit e729a25

File tree

4 files changed

+111
-14
lines changed

4 files changed

+111
-14
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2424

2525
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
26+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
2627
register_ternary_op = scalar_compiler.scalar_op_compiler.register_ternary_op
2728

2829

@@ -67,18 +68,6 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
6768
return _cast(sg_expr, sg_to_type, op.safe)
6869

6970

70-
@register_ternary_op(ops.clip_op)
71-
def _(
72-
original: TypedExpr,
73-
lower: TypedExpr,
74-
upper: TypedExpr,
75-
) -> sge.Expression:
76-
return sge.Greatest(
77-
this=sge.Least(this=original.expr, expressions=[upper.expr]),
78-
expressions=[lower.expr],
79-
)
80-
81-
8271
@register_unary_op(ops.hash_op)
8372
def _(expr: TypedExpr) -> sge.Expression:
8473
return sge.func("FARM_FINGERPRINT", expr.expr)
@@ -114,6 +103,44 @@ def _(
114103
return sge.If(this=condition.expr, true=original.expr, false=replacement.expr)
115104

116105

106+
@register_ternary_op(ops.clip_op)
107+
def _(
108+
original: TypedExpr,
109+
lower: TypedExpr,
110+
upper: TypedExpr,
111+
) -> sge.Expression:
112+
return sge.Greatest(
113+
this=sge.Least(this=original.expr, expressions=[upper.expr]),
114+
expressions=[lower.expr],
115+
)
116+
117+
118+
@register_nary_op(ops.case_when_op)
119+
def _(*cases_and_outputs: TypedExpr) -> sge.Expression:
120+
# Need to upcast BOOL to INT if any output is numeric
121+
result_values = cases_and_outputs[1::2]
122+
do_upcast_bool = any(
123+
dtypes.is_numeric(t.dtype, include_bool=False) for t in result_values
124+
)
125+
if do_upcast_bool:
126+
result_values = tuple(
127+
TypedExpr(
128+
sge.Cast(this=val.expr, to="INT64"),
129+
dtypes.INT_DTYPE,
130+
)
131+
if val.dtype == dtypes.BOOL_DTYPE
132+
else val
133+
for val in result_values
134+
)
135+
136+
return sge.Case(
137+
ifs=[
138+
sge.If(this=predicate.expr, true=output.expr)
139+
for predicate, output in zip(cases_and_outputs[::2], result_values)
140+
],
141+
)
142+
143+
117144
# Helper functions
118145
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
119146
from_type = expr.dtype

tests/system/small/engines/test_generic_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine):
357357
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
358358

359359

360-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
360+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
361361
def test_engines_casewhen_op_single_case(
362362
scalars_array_value: array_value.ArrayValue, engine
363363
):
@@ -373,7 +373,7 @@ def test_engines_casewhen_op_single_case(
373373
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
374374

375375

376-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
376+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
377377
def test_engines_casewhen_op_double_case(
378378
scalars_array_value: array_value.ArrayValue, engine
379379
):
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`int64_too` AS `bfcol_2`,
6+
`float64_col` AS `bfcol_3`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
8+
), `bfcte_1` AS (
9+
SELECT
10+
*,
11+
CASE WHEN `bfcol_0` THEN `bfcol_1` END AS `bfcol_4`,
12+
CASE WHEN `bfcol_0` THEN `bfcol_1` WHEN `bfcol_0` THEN `bfcol_2` END AS `bfcol_5`,
13+
CASE WHEN `bfcol_0` THEN `bfcol_0` WHEN `bfcol_0` THEN `bfcol_0` END AS `bfcol_6`,
14+
CASE
15+
WHEN `bfcol_0`
16+
THEN `bfcol_1`
17+
WHEN `bfcol_0`
18+
THEN CAST(`bfcol_0` AS INT64)
19+
WHEN `bfcol_0`
20+
THEN `bfcol_3`
21+
END AS `bfcol_7`
22+
FROM `bfcte_0`
23+
)
24+
SELECT
25+
`bfcol_4` AS `single_case`,
26+
`bfcol_5` AS `double_case`,
27+
`bfcol_6` AS `bool_types_case`,
28+
`bfcol_7` AS `mixed_types_cast`
29+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,47 @@ def test_astype_json_invalid(
168168
)
169169

170170

171+
def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot):
172+
ops_map = {
173+
"single_case": ops.case_when_op.as_expr(
174+
"bool_col",
175+
"int64_col",
176+
),
177+
"double_case": ops.case_when_op.as_expr(
178+
"bool_col",
179+
"int64_col",
180+
"bool_col",
181+
"int64_too",
182+
),
183+
"bool_types_case": ops.case_when_op.as_expr(
184+
"bool_col",
185+
"bool_col",
186+
"bool_col",
187+
"bool_col",
188+
),
189+
"mixed_types_cast": ops.case_when_op.as_expr(
190+
"bool_col",
191+
"int64_col",
192+
"bool_col",
193+
"bool_col",
194+
"bool_col",
195+
"float64_col",
196+
),
197+
}
198+
199+
array_value = scalar_types_df._block.expr
200+
result, col_ids = array_value.compute_values(list(ops_map.values()))
201+
202+
# Rename columns for deterministic golden SQL results.
203+
assert len(col_ids) == len(ops_map.keys())
204+
result = result.rename_columns(
205+
{col_id: key for col_id, key in zip(col_ids, ops_map.keys())}
206+
).select_columns(list(ops_map.keys()))
207+
208+
sql = result.session._executor.to_sql(result, enable_cache=False)
209+
snapshot.assert_match(sql, "out.sql")
210+
211+
171212
def test_clip(scalar_types_df: bpd.DataFrame, snapshot):
172213
op_expr = ops.clip_op.as_expr("rowindex", "int64_col", "int64_too")
173214

0 commit comments

Comments
 (0)