Commit b17effb

refactor: add agg_ops.MeanOp for sqlglot compiler
1 parent: fb81eea

4 files changed: +75 −1 lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 20 additions & 0 deletions
@@ -56,6 +56,26 @@ def _(
     return apply_window_if_present(sge.func("MAX", column.expr), window)


+@UNARY_OP_REGISTRATION.register(agg_ops.MeanOp)
+def _(
+    op: agg_ops.MeanOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    expr = column.expr
+    if column.dtype == dtypes.BOOL_DTYPE:
+        expr = sge.Cast(this=expr, to="INT64")
+
+    expr = sge.func("AVG", expr)
+
+    should_floor_result = (
+        op.should_floor_result or column.dtype == dtypes.TIMEDELTA_DTYPE
+    )
+    if should_floor_result:
+        expr = sge.Cast(this=sge.func("FLOOR", expr), to="INT64")
+    return apply_window_if_present(expr, window)
+
+
 @UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
 def _(
     op: agg_ops.MinOp,
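
For orientation, a minimal standalone sketch (not part of this commit) of the SQL shapes these sqlglot expressions render to; the column names are illustrative and only plain sqlglot is assumed:

import sqlglot.expressions as sge

# BOOL columns are cast to INT64 before averaging.
bool_mean = sge.func("AVG", sge.cast(sge.column("bool_col"), "INT64"))
print(bool_mean.sql(dialect="bigquery"))  # roughly: AVG(CAST(bool_col AS INT64))

# TIMEDELTA columns (or should_floor_result=True) get a floored integer mean.
floored_mean = sge.cast(
    sge.func("FLOOR", sge.func("AVG", sge.column("duration_col"))), "INT64"
)
print(floored_mean.sql(dialect="bigquery"))  # roughly: CAST(FLOOR(AVG(duration_col)) AS INT64)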

tests/system/small/engines/test_aggregation.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def test_engines_aggregate_size(
     assert_equivalence_execution(node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 @pytest.mark.parametrize(
     "op",
     [agg_ops.min_op, agg_ops.max_op, agg_ops.mean_op, agg_ops.sum_op, agg_ops.count_op],
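
With bq-sqlglot added to the engine parametrization, this system test now also runs the min/max/mean/sum/count aggregations through the new sqlglot compiler and compares the results against REFERENCE_ENGINE via assert_equivalence_execution. To exercise only the new engine locally, an invocation along these lines should work (test IDs assumed, not verified):

pytest tests/system/small/engines/test_aggregation.py -k "bq-sqlglot"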
(new snapshot out.sql for test_mean; the full file path is not shown in this view)

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `bool_col` AS `bfcol_0`,
+    `int64_col` AS `bfcol_1`,
+    `duration_col` AS `bfcol_2`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    `bfcol_1` AS `bfcol_6`,
+    `bfcol_0` AS `bfcol_7`,
+    `bfcol_2` AS `bfcol_8`
+  FROM `bfcte_0`
+), `bfcte_2` AS (
+  SELECT
+    AVG(`bfcol_6`) AS `bfcol_12`,
+    AVG(CAST(`bfcol_7` AS INT64)) AS `bfcol_13`,
+    FLOOR(AVG(`bfcol_8`)) AS `bfcol_14`,
+    FLOOR(AVG(`bfcol_6`)) AS `bfcol_15`
+  FROM `bfcte_1`
+)
+SELECT
+  `bfcol_12` AS `int64_col`,
+  `bfcol_13` AS `bool_col`,
+  `bfcol_14` AS `duration_col`,
+  `bfcol_15` AS `int64_col_w_floor`
+FROM `bfcte_2`
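
Reading the snapshot against the test below: bfcol_12 is the plain mean of int64_col, bfcol_13 is the mean of bool_col after the INT64 cast, bfcol_14 is the floored mean of the TIMEDELTA duration_col, and bfcol_15 is the floored mean requested explicitly via should_floor_result=True.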

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 27 additions & 0 deletions
@@ -56,6 +56,33 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot):
     snapshot.assert_match(sql, "out.sql")


+def test_mean(scalar_types_df: bpd.DataFrame, snapshot):
+    col_names = ["int64_col", "bool_col", "duration_col"]
+    bf_df = scalar_types_df[col_names]
+    bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us")
+
+    # The `to_timedelta` creates a new mapping for the column id.
+    col_names.insert(0, "rowindex")
+    name2id = {
+        col_name: col_id
+        for col_name, col_id in zip(col_names, bf_df._block.expr.column_ids)
+    }
+
+    agg_ops_map = {
+        "int64_col": agg_ops.MeanOp().as_expr(name2id["int64_col"]),
+        "bool_col": agg_ops.MeanOp().as_expr(name2id["bool_col"]),
+        "duration_col": agg_ops.MeanOp().as_expr(name2id["duration_col"]),
+        "int64_col_w_floor": agg_ops.MeanOp(should_floor_result=True).as_expr(
+            name2id["int64_col"]
+        ),
+    }
+    sql = _apply_unary_agg_ops(
+        bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
+    )
+
+    snapshot.assert_match(sql, "out.sql")
+
+
 def test_min(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "int64_col"
     bf_df = scalar_types_df[[col_name]]
