Skip to content

Commit 1fa0b77

Browse files
committed
refactor: add compile_aggregate
1 parent 07222bf commit 1fa0b77

File tree

5 files changed

+221
-4
lines changed

5 files changed

+221
-4
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import functools
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes import dtypes
22+
from bigframes.core import expression, window_spec
23+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
24+
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
25+
import bigframes.operations as ops
26+
27+
28+
def compile_aggregate(
29+
aggregate: expression.Aggregation,
30+
order_by: tuple[sge.Expression, ...],
31+
) -> sge.Expression:
32+
"""Compiles BigFrames aggregation expression into SQLGlot expression."""
33+
# TODO: try to remove type: ignore
34+
if isinstance(aggregate, expression.NullaryAggregation):
35+
return compile_nullary_agg(aggregate.op)
36+
if isinstance(aggregate, expression.UnaryAggregation):
37+
column = scalar_compiler.compile_scalar_expression(aggregate.arg)
38+
if not aggregate.op.order_independent:
39+
return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) # type: ignore
40+
else:
41+
return compile_unary_agg(aggregate.op, column) # type: ignore
42+
elif isinstance(aggregate, expression.BinaryAggregation):
43+
left = scalar_compiler.compile_scalar_expression(aggregate.left)
44+
right = scalar_compiler.compile_scalar_expression(aggregate.right)
45+
return compile_binary_agg(aggregate.op, left, right) # type: ignore
46+
else:
47+
raise ValueError(f"Unexpected aggregation: {aggregate}")
48+
49+
50+
@functools.singledispatch
51+
def compile_nullary_agg(
52+
op: ops.aggregations.WindowOp,
53+
window: typing.Optional[window_spec.WindowSpec] = None,
54+
) -> sge.Expression:
55+
raise ValueError(f"Can't compile unrecognized operation: {op}")
56+
57+
58+
@functools.singledispatch
59+
def compile_binary_agg(
60+
op: ops.aggregations.WindowOp,
61+
left: sge.Expression,
62+
right: sge.Expression,
63+
window: typing.Optional[window_spec.WindowSpec] = None,
64+
) -> sge.Expression:
65+
raise ValueError(f"Can't compile unrecognized operation: {op}")
66+
67+
68+
@functools.singledispatch
69+
def compile_unary_agg(
70+
op: ops.aggregations.WindowOp,
71+
column: sge.Expression,
72+
window: typing.Optional[window_spec.WindowSpec] = None,
73+
) -> sge.Expression:
74+
raise ValueError(f"Can't compile unrecognized operation: {op}")
75+
76+
77+
@functools.singledispatch
78+
def compile_ordered_unary_agg(
79+
op: ops.aggregations.WindowOp,
80+
column: sge.Expression,
81+
window: typing.Optional[window_spec.WindowSpec] = None,
82+
) -> sge.Expression:
83+
raise ValueError(f"Can't compile unrecognized operation: {op}")
84+
85+
86+
# TODO: check @numeric_op
87+
@compile_unary_agg.register
88+
def _(
89+
op: ops.aggregations.SumOp,
90+
column: sge.Expression,
91+
window: typing.Optional[window_spec.WindowSpec] = None,
92+
) -> sge.Expression:
93+
# Will be null if all inputs are null. Pandas defaults to zero sum though.
94+
expr = _apply_window_if_present(sge.func("SUM", column), window)
95+
return sge.func("IFNULL", expr, ir._literal(0, dtypes.INT_DTYPE))
96+
97+
98+
def _apply_window_if_present(
99+
value: sge.Expression,
100+
window: typing.Optional[window_spec.WindowSpec] = None,
101+
) -> sge.Expression:
102+
if window is not None:
103+
raise NotImplementedError("Can't apply window to the expression.")
104+
return window

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
2424
from bigframes.core.compile import configs
25+
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
2526
from bigframes.core.compile.sqlglot.expressions import typed_expr
2627
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2728
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
@@ -267,6 +268,39 @@ def compile_random_sample(
267268
) -> ir.SQLGlotIR:
268269
return child.sample(node.fraction)
269270

271+
@_compile_node.register
272+
def compile_aggregate(
273+
self, node: nodes.AggregateNode, child: ir.SQLGlotIR
274+
) -> ir.SQLGlotIR:
275+
ordering_cols = tuple(
276+
sge.Ordered(
277+
this=scalar_compiler.compile_scalar_expression(
278+
ordering.scalar_expression
279+
),
280+
desc=ordering.direction.is_ascending is False,
281+
# TODO: _convert_row_ordering_to_table_values for overwrite.
282+
nulls_first=ordering.na_last is False,
283+
)
284+
for ordering in node.order_by
285+
)
286+
aggregations: tuple[tuple[str, sge.Expression], ...] = tuple(
287+
(id.sql, aggregate_compiler.compile_aggregate(agg, order_by=ordering_cols))
288+
for agg, id in node.aggregations
289+
)
290+
by_cols: tuple[sge.Expression, ...] = tuple(
291+
scalar_compiler.compile_scalar_expression(by_col)
292+
for by_col in node.by_column_ids
293+
)
294+
295+
result = child.aggregate(aggregations, by_cols)
296+
# TODO(chelsealin): Support dropna
297+
# TODO: Remove dropna field and use filter node instead
298+
# if node.dropna:
299+
# for key in node.by_column_ids:
300+
# if node.child.field_by_id[key.id].nullable:
301+
# result = result.filter(operations.notnull_op.as_expr(key))
302+
return result
303+
270304

271305
def _replace_unsupported_ops(node: nodes.BigFrameNode):
272306
node = nodes.bottom_up(node, rewrite.rewrite_slice)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@
2525
import sqlglot.expressions as sge
2626

2727
from bigframes import dtypes
28-
from bigframes.core import guid, utils
28+
from bigframes.core import guid, local_data, schema, utils
2929
from bigframes.core.compile.sqlglot.expressions import typed_expr
3030
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
31-
import bigframes.core.local_data as local_data
32-
import bigframes.core.schema as bf_schema
3331

3432
# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
3533
try:
@@ -68,7 +66,7 @@ def sql(self) -> str:
6866
def from_pyarrow(
6967
cls,
7068
pa_table: pa.Table,
71-
schema: bf_schema.ArraySchema,
69+
schema: schema.ArraySchema,
7270
uid_gen: guid.SequentialUIDGenerator,
7371
) -> SQLGlotIR:
7472
"""Builds SQLGlot expression from a pyarrow table.
@@ -364,6 +362,38 @@ def sample(self, fraction: float) -> SQLGlotIR:
364362
).where(condition, append=False)
365363
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
366364

365+
def aggregate(
366+
self,
367+
aggregations: tuple[tuple[str, sge.Expression], ...],
368+
by_column_ids: tuple[sge.Expression, ...],
369+
) -> SQLGlotIR:
370+
"""Applies the aggregation expressions.
371+
372+
Args:
373+
aggregations: output_column_id, aggregation_expr tuples
374+
by_column_ids: column ids of the aggregation key, this is preserved through
375+
the transform
376+
dropna: whether null keys should be dropped
377+
"""
378+
aggregations_expr = [
379+
sge.Alias(
380+
this=expr,
381+
alias=sge.to_identifier(id, quoted=self.quoted),
382+
)
383+
for id, expr in aggregations
384+
]
385+
386+
new_expr = _select_to_cte(
387+
self.expr,
388+
sge.to_identifier(
389+
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
390+
),
391+
)
392+
new_expr = new_expr.group_by(*by_column_ids).select(
393+
*[*by_column_ids, *aggregations_expr], append=False
394+
)
395+
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
396+
367397
def insert(
368398
self,
369399
destination: bigquery.TableReference,
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_too` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
`bfcol_1` AS `bfcol_2`,
10+
`bfcol_0` AS `bfcol_3`
11+
FROM `bfcte_0`
12+
), `bfcte_2` AS (
13+
SELECT
14+
`bfcol_3`,
15+
COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6`
16+
FROM `bfcte_1`
17+
GROUP BY
18+
`bfcol_3`
19+
)
20+
SELECT
21+
`bfcol_3` AS `bool_col`,
22+
`bfcol_6` AS `int64_too`
23+
FROM `bfcte_2`
24+
ORDER BY
25+
`bfcol_3` ASC NULLS LAST
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import bigframes.pandas as bpd
18+
19+
pytest.importorskip("pytest_snapshot")
20+
21+
22+
def test_compile_aggregate(scalar_types_df: bpd.DataFrame, snapshot):
23+
result = scalar_types_df["int64_too"].groupby(scalar_types_df["bool_col"]).sum()
24+
snapshot.assert_match(result.to_frame().sql, "out.sql")

0 commit comments

Comments
 (0)