Skip to content

Commit 012a04b

Browse files
committed
refactor: add agg_ops.std_op for the sqlglot compiler
1 parent a0e1e50 commit 012a04b

File tree

5 files changed

+91
-1
lines changed

5 files changed

+91
-1
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ def arg_checker(*args, **kwargs):
5252
def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
5353
key = op if isinstance(op, type) else type(op)
5454
if str(key) not in self._registered_ops:
55-
raise ValueError(f"{key} is already not registered")
55+
raise ValueError(f"{key} is not registered")
5656
return self._registered_ops[str(key)]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,22 @@ def _(
278278
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
279279

280280

281+
@UNARY_OP_REGISTRATION.register(agg_ops.StdOp)
282+
def _(
283+
op: agg_ops.StdOp,
284+
column: typed_expr.TypedExpr,
285+
window: typing.Optional[window_spec.WindowSpec] = None,
286+
) -> sge.Expression:
287+
expr = column.expr
288+
if column.dtype == dtypes.BOOL_DTYPE:
289+
expr = sge.Cast(this=expr, to="INT64")
290+
291+
expr = sge.func("STDDEV", expr)
292+
if op.should_floor_result or column.dtype == dtypes.TIMEDELTA_DTYPE:
293+
expr = sge.Cast(this=sge.func("FLOOR", expr), to="INT64")
294+
return apply_window_if_present(expr, window)
295+
296+
281297
@UNARY_OP_REGISTRATION.register(agg_ops.ShiftOp)
282298
def _(
283299
op: agg_ops.ShiftOp,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`duration_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_1` AS `bfcol_6`,
11+
`bfcol_0` AS `bfcol_7`,
12+
`bfcol_2` AS `bfcol_8`
13+
FROM `bfcte_0`
14+
), `bfcte_2` AS (
15+
SELECT
16+
STDDEV(`bfcol_6`) AS `bfcol_12`,
17+
STDDEV(CAST(`bfcol_7` AS INT64)) AS `bfcol_13`,
18+
CAST(FLOOR(STDDEV(`bfcol_8`)) AS INT64) AS `bfcol_14`,
19+
CAST(FLOOR(STDDEV(`bfcol_6`)) AS INT64) AS `bfcol_15`
20+
FROM `bfcte_1`
21+
)
22+
SELECT
23+
`bfcol_12` AS `int64_col`,
24+
`bfcol_13` AS `bool_col`,
25+
`bfcol_14` AS `duration_col`,
26+
`bfcol_15` AS `int64_col_w_floor`
27+
FROM `bfcte_2`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE STDDEV(`bfcol_0`) OVER () END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `agg_int64`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,40 @@ def test_shift(scalar_types_df: bpd.DataFrame, snapshot):
428428
snapshot.assert_match(noop_sql, "noop.sql")
429429

430430

431+
def test_std(scalar_types_df: bpd.DataFrame, snapshot):
432+
col_names = ["int64_col", "bool_col", "duration_col"]
433+
bf_df = scalar_types_df[col_names]
434+
bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us")
435+
436+
# The `to_timedelta` creates a new mapping for the column id.
437+
col_names.insert(0, "rowindex")
438+
name2id = {
439+
col_name: col_id
440+
for col_name, col_id in zip(col_names, bf_df._block.expr.column_ids)
441+
}
442+
443+
agg_ops_map = {
444+
"int64_col": agg_ops.StdOp().as_expr(name2id["int64_col"]),
445+
"bool_col": agg_ops.StdOp().as_expr(name2id["bool_col"]),
446+
"duration_col": agg_ops.StdOp().as_expr(name2id["duration_col"]),
447+
"int64_col_w_floor": agg_ops.StdOp(should_floor_result=True).as_expr(
448+
name2id["int64_col"]
449+
),
450+
}
451+
sql = _apply_unary_agg_ops(
452+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
453+
)
454+
snapshot.assert_match(sql, "out.sql")
455+
456+
# Window tests
457+
col_name = "int64_col"
458+
bf_df_int = scalar_types_df[[col_name]]
459+
agg_expr = agg_ops.StdOp().as_expr(col_name)
460+
window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),))
461+
sql_window = _apply_unary_window_op(bf_df_int, agg_expr, window, "agg_int64")
462+
snapshot.assert_match(sql_window, "window_out.sql")
463+
464+
431465
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
432466
bf_df = scalar_types_df[["int64_col", "bool_col"]]
433467
agg_ops_map = {

0 commit comments

Comments
 (0)