Skip to content

Commit 858df9c

Browse files
committed
refactor: add ops.where_op to the sqlglot compiler
1 parent 6353d6e commit 858df9c

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2424

2525
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
26+
register_ternary_op = scalar_compiler.scalar_op_compiler.register_ternary_op
2627

2728

2829
@register_unary_op(ops.AsTypeOp, pass_op=True)
@@ -94,6 +95,13 @@ def _(expr: TypedExpr) -> sge.Expression:
9495
return sge.Not(this=sge.Is(this=expr.expr, expression=sge.Null()))
9596

9697

98+
@register_ternary_op(ops.where_op)
99+
def _(
100+
original: TypedExpr, condition: TypedExpr, replacement: TypedExpr
101+
) -> sge.Expression:
102+
return sge.If(this=condition.expr, true=original.expr, false=replacement.expr)
103+
104+
97105
# Helper functions
98106
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
99107
from_type = expr.dtype

tests/system/small/engines/test_generic_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, e
314314
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
315315

316316

317-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
317+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
318318
def test_engines_where_op(scalars_array_value: array_value.ArrayValue, engine):
319319
arr, _ = scalars_array_value.compute_values(
320320
[
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`float64_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
IF(`bfcol_0`, `bfcol_1`, `bfcol_2`) AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_3` AS `result_col`
15+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,19 @@ def test_map(scalar_types_df: bpd.DataFrame, snapshot):
202202
)
203203

204204
snapshot.assert_match(sql, "out.sql")
205+
206+
207+
def test_where(scalar_types_df: bpd.DataFrame, snapshot):
208+
op_expr = ops.where_op.as_expr("int64_col", "bool_col", "float64_col")
209+
210+
array_value = scalar_types_df._block.expr
211+
result, col_ids = array_value.compute_values([op_expr])
212+
213+
# Rename columns for deterministic golden SQL results.
214+
assert len(col_ids) == 1
215+
result = result.rename_columns({col_ids[0]: "result_col"}).select_columns(
216+
["result_col"]
217+
)
218+
219+
sql = result.session._executor.to_sql(result, enable_cache=False)
220+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)