Skip to content

Commit 4d5f0b5

Browse files
committed
generate more compact aggregation SQL
1 parent 094becc commit 4d5f0b5

File tree

4 files changed

+26
-26
lines changed

4 files changed

+26
-26
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def compile_aggregate(
3030
order_by: tuple[sge.Expression, ...],
3131
) -> sge.Expression:
3232
"""Compiles BigFrames aggregation expression into SQLGlot expression."""
33-
# TODO: try to remove type: ignore
3433
if isinstance(aggregate, expression.NullaryAggregation):
3534
return compile_nullary_agg(aggregate.op)
3635
if isinstance(aggregate, expression.UnaryAggregation):
@@ -39,9 +38,9 @@ def compile_aggregate(
3938
aggregate.arg.output_type,
4039
)
4140
if not aggregate.op.order_independent:
42-
return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) # type: ignore
41+
return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by)
4342
else:
44-
return compile_unary_agg(aggregate.op, column) # type: ignore
43+
return compile_unary_agg(aggregate.op, column)
4544
elif isinstance(aggregate, expression.BinaryAggregation):
4645
left = typed_expr.TypedExpr(
4746
scalar_compiler.compile_scalar_expression(aggregate.left),
@@ -51,7 +50,7 @@ def compile_aggregate(
5150
scalar_compiler.compile_scalar_expression(aggregate.right),
5251
aggregate.right.output_type,
5352
)
54-
return compile_binary_agg(aggregate.op, left, right) # type: ignore
53+
return compile_binary_agg(aggregate.op, left, right)
5554
else:
5655
raise ValueError(f"Unexpected aggregation: {aggregate}")
5756

@@ -88,6 +87,7 @@ def compile_ordered_unary_agg(
8887
op: ops.aggregations.WindowOp,
8988
column: typed_expr.TypedExpr,
9089
window: typing.Optional[window_spec.WindowSpec] = None,
90+
order_by: typing.Sequence[sge.Expression] = [],
9191
) -> sge.Expression:
9292
raise ValueError(f"Can't compile unrecognized operation: {op}")
9393

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import typing
1919

2020
from google.cloud import bigquery
21-
import sqlglot as sg
2221
import sqlglot.expressions as sge
2322

2423
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
@@ -279,7 +278,6 @@ def compile_aggregate(
279278
ordering.scalar_expression
280279
),
281280
desc=ordering.direction.is_ascending is False,
282-
# TODO: _convert_row_ordering_to_table_values for overwrite.
283281
nulls_first=ordering.na_last is False,
284282
)
285283
for ordering in node.order_by
@@ -293,16 +291,13 @@ def compile_aggregate(
293291
for by_col in node.by_column_ids
294292
)
295293

296-
result = child.aggregate(aggregations, by_cols)
294+
dropna_cols = []
297295
if node.dropna:
298-
conditions = []
299296
for key, by_col in zip(node.by_column_ids, by_cols):
300297
if node.child.field_by_id[key.id].nullable:
301-
conditions.append(
302-
sg.not_(sge.Is(this=by_col, expression=sge.Null()))
303-
)
304-
result = result.filter(tuple(conditions))
305-
return result
298+
dropna_cols.append(by_col)
299+
300+
return child.aggregate(aggregations, by_cols, tuple(dropna_cols))
306301

307302

308303
def _replace_unsupported_ops(node: nodes.BigFrameNode):

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,15 +371,15 @@ def sample(self, fraction: float) -> SQLGlotIR:
371371
def aggregate(
372372
self,
373373
aggregations: tuple[tuple[str, sge.Expression], ...],
374-
by_column_ids: tuple[sge.Expression, ...],
374+
by_cols: tuple[sge.Expression, ...],
375+
dropna_cols: tuple[sge.Expression, ...],
375376
) -> SQLGlotIR:
376377
"""Applies the aggregation expressions.
377378
378379
Args:
379380
aggregations: output_column_id, aggregation_expr tuples
380-
by_column_ids: column ids of the aggregation key, this is preserved through
381-
the transform
382-
dropna: whether null keys should be dropped
381+
by_cols: column expressions of the aggregation (group-by) key; these are also kept in the output
382+
dropna_cols: key column expressions for which rows with NULL values are filtered out
383383
"""
384384
aggregations_expr = [
385385
sge.Alias(
@@ -395,9 +395,18 @@ def aggregate(
395395
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
396396
),
397397
)
398-
new_expr = new_expr.group_by(*by_column_ids).select(
399-
*[*by_column_ids, *aggregations_expr], append=False
398+
new_expr = new_expr.group_by(*by_cols).select(
399+
*[*by_cols, *aggregations_expr], append=False
400400
)
401+
402+
condition = _and(
403+
tuple(
404+
sg.not_(sge.Is(this=drop_col, expression=sge.Null()))
405+
for drop_col in dropna_cols
406+
)
407+
)
408+
if condition is not None:
409+
new_expr = new_expr.where(condition, append=False)
401410
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
402411

403412
def insert(

tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,14 @@ WITH `bfcte_0` AS (
1414
`bfcol_3`,
1515
COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6`
1616
FROM `bfcte_1`
17-
GROUP BY
18-
`bfcol_3`
19-
), `bfcte_3` AS (
20-
SELECT
21-
*
22-
FROM `bfcte_2`
2317
WHERE
2418
NOT `bfcol_3` IS NULL
19+
GROUP BY
20+
`bfcol_3`
2521
)
2622
SELECT
2723
`bfcol_3` AS `bool_col`,
2824
`bfcol_6` AS `int64_too`
29-
FROM `bfcte_3`
25+
FROM `bfcte_2`
3026
ORDER BY
3127
`bfcol_3` ASC NULLS LAST

0 commit comments

Comments
 (0)