Skip to content

Commit 441aa05

Browse files
committed
refactor: add agg_ops.MeanOp for sqlglot compiler
1 parent 17b5d3e commit 441aa05

File tree

4 files changed

+60
-1
lines changed

4 files changed

+60
-1
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ def _(
5555
return apply_window_if_present(sge.func("MAX", column.expr), window)
5656

5757

58+
@UNARY_OP_REGISTRATION.register(agg_ops.MeanOp)
59+
def _(
60+
op: agg_ops.MeanOp,
61+
column: typed_expr.TypedExpr,
62+
window: typing.Optional[window_spec.WindowSpec] = None,
63+
) -> sge.Expression:
64+
expr = column.expr
65+
if column.dtype == dtypes.BOOL_DTYPE:
66+
expr = sge.Cast(this=expr, to="INT64")
67+
68+
expr = sge.func("AVG", expr)
69+
70+
should_floor_result = (
71+
op.should_floor_result or column.dtype == dtypes.TIMEDELTA_DTYPE
72+
)
73+
if should_floor_result:
74+
expr = sge.func("FLOOR", expr)
75+
return apply_window_if_present(expr, window)
76+
77+
5878
@UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
5979
def _(
6080
op: agg_ops.MinOp,

tests/system/small/engines/test_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_engines_aggregate_size(
7070
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
7171

7272

73-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
73+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
7474
@pytest.mark.parametrize(
7575
"op",
7676
[agg_ops.min_op, agg_ops.max_op, agg_ops.mean_op, agg_ops.sum_op, agg_ops.count_op],
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`duration_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
AVG(`bfcol_1`) AS `bfcol_6`,
10+
AVG(CAST(`bfcol_0` AS INT64)) AS `bfcol_7`,
11+
AVG(`bfcol_2`) AS `bfcol_8`,
12+
FLOOR(AVG(`bfcol_1`)) AS `bfcol_9`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_6` AS `int64_col`,
17+
`bfcol_7` AS `bool_col`,
18+
`bfcol_8` AS `duration_col`,
19+
`bfcol_9` AS `int64_col_w_floor`
20+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,25 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot):
5656
snapshot.assert_match(sql, "out.sql")
5757

5858

59+
def test_mean(scalar_types_df: bpd.DataFrame, snapshot):
60+
bf_df = scalar_types_df[["int64_col", "bool_col", "duration_col"]]
61+
# bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us")
62+
63+
agg_ops_map = {
64+
"int64_col": agg_ops.MeanOp().as_expr("int64_col"),
65+
"bool_col": agg_ops.MeanOp().as_expr("bool_col"),
66+
"duration_col": agg_ops.MeanOp().as_expr("duration_col"),
67+
"int64_col_w_floor": agg_ops.MeanOp(should_floor_result=True).as_expr(
68+
"int64_col"
69+
),
70+
}
71+
sql = _apply_unary_agg_ops(
72+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
73+
)
74+
75+
snapshot.assert_match(sql, "out.sql")
76+
77+
5978
def test_min(scalar_types_df: bpd.DataFrame, snapshot):
6079
col_name = "int64_col"
6180
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)