
Commit c9963c1

refactor: add ops.clip_op to the sqlglot compiler
Parent: 9bb5d6f

File tree

3 files changed (+43 -0 lines changed)

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 12 additions & 0 deletions
@@ -67,6 +67,18 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
     return _cast(sg_expr, sg_to_type, op.safe)
 
 
+@register_ternary_op(ops.clip_op)
+def _(
+    original: TypedExpr,
+    lower: TypedExpr,
+    upper: TypedExpr,
+) -> sge.Expression:
+    return sge.Greatest(
+        this=sge.Least(this=original.expr, expressions=[upper.expr]),
+        expressions=[lower.expr],
+    )
+
+
 @register_unary_op(ops.hash_op)
 def _(expr: TypedExpr) -> sge.Expression:
     return sge.func("FARM_FINGERPRINT", expr.expr)
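
For context, the new handler lowers clip to the standard max(min(...)) nesting: GREATEST(LEAST(original, upper), lower). A minimal standalone sketch, using plain sqlglot outside the bigframes compiler and hypothetical column names x/lo/hi, of how that expression tree renders to BigQuery SQL:

# Minimal sketch, assuming plain sqlglot and hypothetical columns x/lo/hi.
import sqlglot.expressions as sge

original = sge.column("x")
lower = sge.column("lo")
upper = sge.column("hi")

clipped = sge.Greatest(
    # LEAST(x, hi) caps the value at the upper bound ...
    this=sge.Least(this=original, expressions=[upper]),
    # ... then GREATEST(..., lo) floors it at the lower bound.
    expressions=[lower],
)

print(clipped.sql(dialect="bigquery"))  # expected: GREATEST(LEAST(x, hi), lo)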
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`,
+    `int64_too` AS `bfcol_1`,
+    `rowindex` AS `bfcol_2`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    GREATEST(LEAST(`bfcol_2`, `bfcol_1`), `bfcol_0`) AS `bfcol_3`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_3` AS `result_col`
+FROM `bfcte_1`
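
The GREATEST(LEAST(...)) shape in this golden SQL is the ordinary max(min(x, upper), lower) clip formula applied to `rowindex`, with `int64_col` as the lower bound and `int64_too` as the upper bound. A quick Python illustration with made-up values:

# Worked example of the clip formula used above (values are made up).
def clip(x, lower, upper):
    return max(min(x, upper), lower)

assert clip(5, 1, 3) == 3  # above the upper bound: clipped down to 3
assert clip(0, 1, 3) == 1  # below the lower bound: clipped up to 1
assert clip(2, 1, 3) == 2  # inside the bounds: unchanged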

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 16 additions & 0 deletions
@@ -168,6 +168,22 @@ def test_astype_json_invalid(
     )
 
 
+def test_clip(scalar_types_df: bpd.DataFrame, snapshot):
+    op_expr = ops.clip_op.as_expr("rowindex", "int64_col", "int64_too")
+
+    array_value = scalar_types_df._block.expr
+    result, col_ids = array_value.compute_values([op_expr])
+
+    # Rename columns for deterministic golden SQL results.
+    assert len(col_ids) == 1
+    result = result.rename_columns({col_ids[0]: "result_col"}).select_columns(
+        ["result_col"]
+    )
+
+    sql = result.session._executor.to_sql(result, enable_cache=False)
+    snapshot.assert_match(sql, "out.sql")
+
+
 def test_hash(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "string_col"
     bf_df = scalar_types_df[[col_name]]
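
The test drives the op through the compiler internals; at the user level the same lowering would presumably back a pandas-style clip call. A hedged sketch, assuming bigframes.pandas exposes Series.clip(lower, upper) like pandas (the scalar bounds are illustrative, and the table is the internal test dataset from the golden SQL):

# Hedged sketch, not part of the commit: assumes a pandas-style Series.clip and
# access to the internal test table referenced by the golden SQL above.
import bigframes.pandas as bpd

df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")
clipped = df["rowindex"].clip(0, 100)  # illustrative scalar bounds
print(clipped.to_pandas().head())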
