Commit a6f87a0

refactor: add ops.clip_op and where_op to the sqlglot compiler (#2168)
1 parent 6353d6e commit a6f87a0

5 files changed: +83 −1 lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 20 additions & 0 deletions
@@ -23,6 +23,7 @@
 import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
 
 register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
+register_ternary_op = scalar_compiler.scalar_op_compiler.register_ternary_op
 
 
 @register_unary_op(ops.AsTypeOp, pass_op=True)
@@ -66,6 +67,18 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
     return _cast(sg_expr, sg_to_type, op.safe)
 
 
+@register_ternary_op(ops.clip_op)
+def _(
+    original: TypedExpr,
+    lower: TypedExpr,
+    upper: TypedExpr,
+) -> sge.Expression:
+    return sge.Greatest(
+        this=sge.Least(this=original.expr, expressions=[upper.expr]),
+        expressions=[lower.expr],
+    )
+
+
 @register_unary_op(ops.hash_op)
 def _(expr: TypedExpr) -> sge.Expression:
     return sge.func("FARM_FINGERPRINT", expr.expr)
@@ -94,6 +107,13 @@ def _(expr: TypedExpr) -> sge.Expression:
     return sge.Not(this=sge.Is(this=expr.expr, expression=sge.Null()))
 
 
+@register_ternary_op(ops.where_op)
+def _(
+    original: TypedExpr, condition: TypedExpr, replacement: TypedExpr
+) -> sge.Expression:
+    return sge.If(this=condition.expr, true=original.expr, false=replacement.expr)
+
+
 # Helper functions
 def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
     from_type = expr.dtype
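For readers unfamiliar with sqlglot expression trees, here is a minimal standalone sketch (not part of the commit) of how these two translations serialize under the BigQuery dialect. Column names mirror the unit tests below; the exact output text is an assumption and may vary by sqlglot version.

# Standalone sketch: build the same sqlglot expressions directly and render
# them as BigQuery SQL. Column names mirror the test_clip / test_where tests.
import sqlglot.expressions as sge

# clip_op translation: GREATEST(LEAST(original, upper), lower)
clip = sge.Greatest(
    this=sge.Least(this=sge.column("rowindex"), expressions=[sge.column("int64_too")]),
    expressions=[sge.column("int64_col")],
)
print(clip.sql(dialect="bigquery"))
# Expected output (roughly): GREATEST(LEAST(rowindex, int64_too), int64_col)

# where_op translation: keep `original` where the condition holds, else the replacement
where = sge.If(
    this=sge.column("bool_col"),
    true=sge.column("int64_col"),
    false=sge.column("float64_col"),
)
print(where.sql(dialect="bigquery"))
# Expected output (roughly): IF(bool_col, int64_col, float64_col), or an
# equivalent CASE form depending on the sqlglot version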

tests/system/small/engines/test_generic_ops.py

Lines changed: 1 addition & 1 deletion
@@ -314,7 +314,7 @@ def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, e
     assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_where_op(scalars_array_value: array_value.ArrayValue, engine):
     arr, _ = scalars_array_value.compute_values(
         [
New golden SQL snapshot for test_clip: Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`,
+    `int64_too` AS `bfcol_1`,
+    `rowindex` AS `bfcol_2`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    GREATEST(LEAST(`bfcol_2`, `bfcol_1`), `bfcol_0`) AS `bfcol_3`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_3` AS `result_col`
+FROM `bfcte_1`
New golden SQL snapshot for test_where: Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `bool_col` AS `bfcol_0`,
+    `int64_col` AS `bfcol_1`,
+    `float64_col` AS `bfcol_2`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    IF(`bfcol_0`, `bfcol_1`, `bfcol_2`) AS `bfcol_3`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_3` AS `result_col`
+FROM `bfcte_1`
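A pure-Python sanity sketch (not part of the commit) of what the generated SQL computes: GREATEST(LEAST(x, upper), lower) is the usual clamp, and IF(cond, a, b) is the usual conditional selection. NULL handling is ignored here.

# Sanity sketch of the semantics behind the snapshots above (ignores NULLs).
def clip(x, lower, upper):
    # GREATEST(LEAST(x, upper), lower)
    return max(min(x, upper), lower)

def where(original, condition, replacement):
    # IF(condition, original, replacement)
    return original if condition else replacement

assert clip(10, 0, 5) == 5       # clamped to the upper bound
assert clip(-3, 0, 5) == 0       # clamped to the lower bound
assert where(1, False, 2.5) == 2.5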

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 32 additions & 0 deletions
@@ -168,6 +168,22 @@ def test_astype_json_invalid(
     )
 
 
+def test_clip(scalar_types_df: bpd.DataFrame, snapshot):
+    op_expr = ops.clip_op.as_expr("rowindex", "int64_col", "int64_too")
+
+    array_value = scalar_types_df._block.expr
+    result, col_ids = array_value.compute_values([op_expr])
+
+    # Rename columns for deterministic golden SQL results.
+    assert len(col_ids) == 1
+    result = result.rename_columns({col_ids[0]: "result_col"}).select_columns(
+        ["result_col"]
+    )
+
+    sql = result.session._executor.to_sql(result, enable_cache=False)
+    snapshot.assert_match(sql, "out.sql")
+
+
 def test_hash(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "string_col"
     bf_df = scalar_types_df[[col_name]]
@@ -202,3 +218,19 @@ def test_map(scalar_types_df: bpd.DataFrame, snapshot):
     )
 
     snapshot.assert_match(sql, "out.sql")
+
+
+def test_where(scalar_types_df: bpd.DataFrame, snapshot):
+    op_expr = ops.where_op.as_expr("int64_col", "bool_col", "float64_col")
+
+    array_value = scalar_types_df._block.expr
+    result, col_ids = array_value.compute_values([op_expr])
+
+    # Rename columns for deterministic golden SQL results.
+    assert len(col_ids) == 1
+    result = result.rename_columns({col_ids[0]: "result_col"}).select_columns(
+        ["result_col"]
+    )
+
+    sql = result.session._executor.to_sql(result, enable_cache=False)
+    snapshot.assert_match(sql, "out.sql")
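For context, a hedged sketch of the user-facing calls that would typically exercise these ops end to end; the Series.clip / Series.where entry points and the table reference are assumptions, since the commit itself only touches the compiler and tests.

# Hedged usage sketch: pandas-style entry points assumed to lower to
# ops.clip_op and ops.where_op (assumption; not shown in this commit).
import bigframes.pandas as bpd

df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")

# Expected to compile to GREATEST(LEAST(`rowindex`, 100), 0) in the generated SQL.
clipped = df["rowindex"].clip(0, 100)

# Expected to compile to IF(`bool_col`, `int64_col`, `float64_col`).
chosen = df["int64_col"].where(df["bool_col"], df["float64_col"])

print(clipped.head())
print(chosen.head())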
