Skip to content

Commit b55602e

Browse files
committed
generate more compact aggregation SQL
1 parent 094becc commit b55602e

File tree

4 files changed

+28
-28
lines changed

4 files changed

+28
-28
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import sqlglot.expressions as sge
2020

21+
from bigframes import dtypes, constants
2122
from bigframes.core import expression, window_spec
2223
from bigframes.core.compile.sqlglot.expressions import typed_expr
2324
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
@@ -30,7 +31,6 @@ def compile_aggregate(
3031
order_by: tuple[sge.Expression, ...],
3132
) -> sge.Expression:
3233
"""Compiles BigFrames aggregation expression into SQLGlot expression."""
33-
# TODO: try to remove type: ignore
3434
if isinstance(aggregate, expression.NullaryAggregation):
3535
return compile_nullary_agg(aggregate.op)
3636
if isinstance(aggregate, expression.UnaryAggregation):
@@ -39,9 +39,9 @@ def compile_aggregate(
3939
aggregate.arg.output_type,
4040
)
4141
if not aggregate.op.order_independent:
42-
return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) # type: ignore
42+
return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by)
4343
else:
44-
return compile_unary_agg(aggregate.op, column) # type: ignore
44+
return compile_unary_agg(aggregate.op, column)
4545
elif isinstance(aggregate, expression.BinaryAggregation):
4646
left = typed_expr.TypedExpr(
4747
scalar_compiler.compile_scalar_expression(aggregate.left),
@@ -51,7 +51,7 @@ def compile_aggregate(
5151
scalar_compiler.compile_scalar_expression(aggregate.right),
5252
aggregate.right.output_type,
5353
)
54-
return compile_binary_agg(aggregate.op, left, right) # type: ignore
54+
return compile_binary_agg(aggregate.op, left, right)
5555
else:
5656
raise ValueError(f"Unexpected aggregation: {aggregate}")
5757

@@ -88,11 +88,11 @@ def compile_ordered_unary_agg(
8888
op: ops.aggregations.WindowOp,
8989
column: typed_expr.TypedExpr,
9090
window: typing.Optional[window_spec.WindowSpec] = None,
91+
order_by: typing.Sequence[sge.Expression] = [],
9192
) -> sge.Expression:
9293
raise ValueError(f"Can't compile unrecognized operation: {op}")
9394

9495

95-
# TODO: check @numeric_op
9696
@compile_unary_agg.register
9797
def _(
9898
op: ops.aggregations.SumOp,
@@ -101,7 +101,6 @@ def _(
101101
) -> sge.Expression:
102102
# Will be null if all inputs are null. Pandas defaults to zero sum though.
103103
expr = _apply_window_if_present(sge.func("SUM", column.expr), window)
104-
# TODO (b/430350912): check `column.dtype` works for all?
105104
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))
106105

107106

@@ -112,3 +111,4 @@ def _apply_window_if_present(
112111
if window is not None:
113112
raise NotImplementedError("Can't apply window to the expression.")
114113
return value
114+

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import typing
1919

2020
from google.cloud import bigquery
21-
import sqlglot as sg
2221
import sqlglot.expressions as sge
2322

2423
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
@@ -279,7 +278,6 @@ def compile_aggregate(
279278
ordering.scalar_expression
280279
),
281280
desc=ordering.direction.is_ascending is False,
282-
# TODO: _convert_row_ordering_to_table_values for overwrite.
283281
nulls_first=ordering.na_last is False,
284282
)
285283
for ordering in node.order_by
@@ -293,16 +291,13 @@ def compile_aggregate(
293291
for by_col in node.by_column_ids
294292
)
295293

296-
result = child.aggregate(aggregations, by_cols)
294+
dropna_cols = []
297295
if node.dropna:
298-
conditions = []
299296
for key, by_col in zip(node.by_column_ids, by_cols):
300297
if node.child.field_by_id[key.id].nullable:
301-
conditions.append(
302-
sg.not_(sge.Is(this=by_col, expression=sge.Null()))
303-
)
304-
result = result.filter(tuple(conditions))
305-
return result
298+
dropna_cols.append(by_col)
299+
300+
return child.aggregate(aggregations, by_cols, tuple(dropna_cols))
306301

307302

308303
def _replace_unsupported_ops(node: nodes.BigFrameNode):

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,15 +371,15 @@ def sample(self, fraction: float) -> SQLGlotIR:
371371
def aggregate(
372372
self,
373373
aggregations: tuple[tuple[str, sge.Expression], ...],
374-
by_column_ids: tuple[sge.Expression, ...],
374+
by_cols: tuple[sge.Expression, ...],
375+
dropna_cols: tuple[sge.Expression, ...],
375376
) -> SQLGlotIR:
376377
"""Applies the aggregation expressions.
377378
378379
Args:
379380
aggregations: output_column_id, aggregation_expr tuples
380-
by_column_ids: column ids of the aggregation key, this is preserved through
381-
the transform
382-
dropna: whether null keys should be dropped
381+
by_cols: column expressions for aggregation
382+
dropna_cols: key columns whose null values should be filtered out
383383
"""
384384
aggregations_expr = [
385385
sge.Alias(
@@ -395,9 +395,18 @@ def aggregate(
395395
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
396396
),
397397
)
398-
new_expr = new_expr.group_by(*by_column_ids).select(
399-
*[*by_column_ids, *aggregations_expr], append=False
398+
new_expr = new_expr.group_by(*by_cols).select(
399+
*[*by_cols, *aggregations_expr], append=False
400400
)
401+
402+
condition = _and(
403+
tuple(
404+
sg.not_(sge.Is(this=drop_col, expression=sge.Null()))
405+
for drop_col in dropna_cols
406+
)
407+
)
408+
if condition is not None:
409+
new_expr = new_expr.where(condition, append=False)
401410
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
402411

403412
def insert(

tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,14 @@ WITH `bfcte_0` AS (
1414
`bfcol_3`,
1515
COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6`
1616
FROM `bfcte_1`
17-
GROUP BY
18-
`bfcol_3`
19-
), `bfcte_3` AS (
20-
SELECT
21-
*
22-
FROM `bfcte_2`
2317
WHERE
2418
NOT `bfcol_3` IS NULL
19+
GROUP BY
20+
`bfcol_3`
2521
)
2622
SELECT
2723
`bfcol_3` AS `bool_col`,
2824
`bfcol_6` AS `int64_too`
29-
FROM `bfcte_3`
25+
FROM `bfcte_2`
3026
ORDER BY
3127
`bfcol_3` ASC NULLS LAST

0 commit comments

Comments
 (0)