Skip to content

Commit 1be0f8c

Browse files
committed
refactor: add agg_ops.QcutOp to the sqlglot compiler
1 parent 0e3f2a4 commit 1be0f8c

File tree

4 files changed

+94
-0
lines changed

4 files changed

+94
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,39 @@ def _(
253253
return apply_window_if_present(expr, window)
254254

255255

256+
@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
def _(
    op: agg_ops.QcutOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``QcutOp`` into a zero-based quantile-bucket expression.

    The bucket is derived from ``PERCENT_RANK()`` applied over ``window``
    (framing clauses are excluded).

    Args:
        op: The qcut operation. ``op.quantiles`` is either an ``int`` (the
            number of buckets) or an indexable sequence of quantile edges.
        column: The input column. It is not referenced in this function; the
            rank comes solely from the window handed to ``PERCENT_RANK()``.
        window: Optional window specification applied to ``PERCENT_RANK()``.

    Returns:
        For an ``int``: ``GREATEST(LEAST(CEIL(rank * n), n), 1) - 1``, i.e. a
        bucket index in ``[0, n - 1]``.
        For a sequence of edges: a ``CASE`` expression yielding the index of
        the first interval whose upper edge is ``>= rank``, and ``NULL`` when
        the rank is below the first edge or above the last one.
    """
    # NOTE(review): the recorded snapshots render this as
    # `PERCENT_RANK() OVER ()` with no ORDER BY even though the test passes an
    # ordered window -- confirm the window ordering actually reaches the rank.
    percent_ranks = apply_window_if_present(
        sge.func("PERCENT_RANK"), window, include_framing_clauses=False
    )

    if isinstance(op.quantiles, int):
        quantiles_sql = ir._literal(op.quantiles, dtypes.INT_DTYPE)
        ceil_val = sge.func("CEIL", percent_ranks * quantiles_sql)
        # Clip the one-based bucket to [1, quantiles]. The lower bound is the
        # one that matters: a rank of 0 gives CEIL(0) = 0, which would
        # otherwise become bucket -1 after the subtraction below.
        upper_clipped = sge.func("LEAST", ceil_val, quantiles_sql)
        clipped = sge.func("GREATEST", upper_clipped, sge.convert(1))
        # Shift from one-based to zero-based bucket labels.
        return sge.Sub(this=clipped, expression=sge.convert(1))

    # Explicit edges: ranks below the first edge belong to no bucket.
    first_quantile = ir._literal(
        op.quantiles[0], dtypes.infer_literal_type(op.quantiles[0])
    )
    case = sge.Case().when(
        sge.LT(this=percent_ranks, expression=first_quantile), sge.Null()
    )
    # Each remaining edge closes one bucket; the first matching WHEN wins, so
    # bucket i covers ranks up to and including its upper edge.
    for bucket_index, upper_edge in enumerate(op.quantiles[1:]):
        quantile = ir._literal(upper_edge, dtypes.infer_literal_type(upper_edge))
        bucket = ir._literal(bucket_index, dtypes.INT_DTYPE)
        case = case.when(sge.LTE(this=percent_ranks, expression=quantile), bucket)
    # Ranks above the last edge fall outside every bucket.
    return case.else_(sge.Null())
287+
288+
256289
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
257290
def _(
258291
op: agg_ops.QuantileOp,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LEAST(CEIL(PERCENT_RANK() OVER () * 4), 4) - 1 AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `int_quantiles`
13+
FROM `bfcte_1`
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN PERCENT_RANK() OVER () < 0
10+
THEN NULL
11+
WHEN PERCENT_RANK() OVER () <= 0.25
12+
THEN 0
13+
WHEN PERCENT_RANK() OVER () <= 0.5
14+
THEN 1
15+
WHEN PERCENT_RANK() OVER () <= 0.75
16+
THEN 2
17+
WHEN PERCENT_RANK() OVER () <= 1
18+
THEN 3
19+
ELSE NULL
20+
END AS `bfcol_1`
21+
FROM `bfcte_0`
22+
)
23+
SELECT
24+
`bfcol_1` AS `list_quantiles`
25+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,29 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
392392
snapshot.assert_match(sql_window, "window_out.sql")
393393

394394

395+
def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the windowed SQL generated for ``QcutOp``.

    Two forms of ``quantiles`` are exercised: an integer bucket count and an
    explicit tuple of quantile edges. Each is compared against its own
    snapshot file.
    """
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]

    # (output column name, quantiles argument, snapshot file)
    cases = (
        ("int_quantiles", 4, "int_quantiles.sql"),
        ("list_quantiles", (0, 0.25, 0.5, 0.75, 1), "list_quantiles.sql"),
    )
    # Build every aggregation up front, before any SQL is compiled.
    aggregations = [
        agg_exprs.UnaryAggregation(
            agg_ops.QcutOp(quantiles=quantiles), expression.deref(col_name)
        )
        for _, quantiles, _ in cases
    ]
    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))

    for (result_name, _, snapshot_file), aggregation in zip(cases, aggregations):
        sql = _apply_unary_window_op(bf_df, aggregation, window, result_name)
        snapshot.assert_match(sql, snapshot_file)
416+
417+
395418
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
396419
col_name = "int64_col"
397420
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)