Skip to content

Commit 73a47eb

Browse files
committed
refactor: add agg_ops.var_op for the sqlglot compiler
1 parent 012a04b commit 73a47eb

File tree

5 files changed

+65
-1
lines changed

5 files changed

+65
-1
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,17 @@ def _(
347347
expression=shifted,
348348
unit=sge.Identifier(this="MICROSECOND"),
349349
)
350+
351+
352+
@UNARY_OP_REGISTRATION.register(agg_ops.VarOp)
def _(
    op: agg_ops.VarOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``VarOp`` into a ``VAR_SAMP`` call over ``column``.

    Boolean inputs are cast to INT64 before being aggregated. The resulting
    expression is handed to ``apply_window_if_present`` together with
    ``window``.
    """
    is_bool = column.dtype == dtypes.BOOL_DTYPE
    # Booleans are aggregated as integers, so cast them before VAR_SAMP.
    operand = sge.Cast(this=column.expr, to="INT64") if is_bool else column.expr
    variance = sge.func("VAR_SAMP", operand)
    return apply_window_if_present(variance, window)

tests/system/small/engines/test_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_engines_unary_aggregates(
111111
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
112112

113113

114-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
114+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
115115
@pytest.mark.parametrize(
116116
"op",
117117
[agg_ops.std_op, agg_ops.var_op, agg_ops.PopVarOp()],
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
VARIANCE(`bfcol_1`) AS `bfcol_4`,
9+
VARIANCE(CAST(`bfcol_0` AS INT64)) AS `bfcol_5`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_4` AS `int64_col`,
14+
`bfcol_5` AS `bool_col`
15+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE VARIANCE(`bfcol_0`) OVER () END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `agg_int64`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,3 +502,25 @@ def test_time_series_diff(scalar_types_df: bpd.DataFrame, snapshot):
502502
)
503503
sql = _apply_unary_window_op(bf_df, op, window, "diff_time")
504504
snapshot.assert_match(sql, "out.sql")
505+
506+
507+
def test_var(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``VarOp`` as an aggregate and as a window op."""
    # Aggregation case: one integer column and one boolean column.
    agg_columns = ["int64_col", "bool_col"]
    agg_df = scalar_types_df[agg_columns]
    agg_exprs = [agg_ops.VarOp().as_expr(name) for name in agg_columns]
    aggregated_sql = _apply_unary_agg_ops(agg_df, agg_exprs, list(agg_columns))
    snapshot.assert_match(aggregated_sql, "out.sql")

    # Window case: the same op applied with a window spec whose ordering is
    # descending on the input column.
    target = "int64_col"
    windowed_df = scalar_types_df[[target]]
    var_expr = agg_ops.VarOp().as_expr(target)
    descending_window = window_spec.WindowSpec(
        ordering=(ordering.descending_over(target),)
    )
    snapshot.assert_match(
        _apply_unary_window_op(windowed_df, var_expr, descending_window, "agg_int64"),
        "window_out.sql",
    )

0 commit comments

Comments
 (0)