From 1fa0b77bb42d97e9a162b77ecf204d3dbf6937c1 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 9 Jul 2025 00:02:59 +0000 Subject: [PATCH 1/3] refactor: add compile_aggregate --- .../compile/sqlglot/aggregate_compiler.py | 104 ++++++++++++++++++ bigframes/core/compile/sqlglot/compiler.py | 34 ++++++ bigframes/core/compile/sqlglot/sqlglot_ir.py | 38 ++++++- .../test_compile_aggregate/out.sql | 25 +++++ .../compile/sqlglot/test_compile_aggregate.py | 24 ++++ 5 files changed, 221 insertions(+), 4 deletions(-) create mode 100644 bigframes/core/compile/sqlglot/aggregate_compiler.py create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql create mode 100644 tests/unit/core/compile/sqlglot/test_compile_aggregate.py diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py new file mode 100644 index 0000000000..dae8f2485b --- /dev/null +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +import typing + +import sqlglot.expressions as sge + +from bigframes import dtypes +from bigframes.core import expression, window_spec +import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler +import bigframes.core.compile.sqlglot.sqlglot_ir as ir +import bigframes.operations as ops + + +def compile_aggregate( + aggregate: expression.Aggregation, + order_by: tuple[sge.Expression, ...], +) -> sge.Expression: + """Compiles BigFrames aggregation expression into SQLGlot expression.""" + # TODO: try to remove type: ignore + if isinstance(aggregate, expression.NullaryAggregation): + return compile_nullary_agg(aggregate.op) + if isinstance(aggregate, expression.UnaryAggregation): + column = scalar_compiler.compile_scalar_expression(aggregate.arg) + if not aggregate.op.order_independent: + return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) # type: ignore + else: + return compile_unary_agg(aggregate.op, column) # type: ignore + elif isinstance(aggregate, expression.BinaryAggregation): + left = scalar_compiler.compile_scalar_expression(aggregate.left) + right = scalar_compiler.compile_scalar_expression(aggregate.right) + return compile_binary_agg(aggregate.op, left, right) # type: ignore + else: + raise ValueError(f"Unexpected aggregation: {aggregate}") + + +@functools.singledispatch +def compile_nullary_agg( + op: ops.aggregations.WindowOp, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +@functools.singledispatch +def compile_binary_agg( + op: ops.aggregations.WindowOp, + left: sge.Expression, + right: sge.Expression, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +@functools.singledispatch +def compile_unary_agg( + op: ops.aggregations.WindowOp, + column: sge.Expression, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +@functools.singledispatch +def compile_ordered_unary_agg( + op: ops.aggregations.WindowOp, + column: sge.Expression, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +# TODO: check @numeric_op +@compile_unary_agg.register +def _( + op: ops.aggregations.SumOp, + column: sge.Expression, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # Will be null if all inputs are null. Pandas defaults to zero sum though. + expr = _apply_window_if_present(sge.func("SUM", column), window) + return sge.func("IFNULL", expr, ir._literal(0, dtypes.INT_DTYPE)) + + +def _apply_window_if_present( + value: sge.Expression, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + if window is not None: + raise NotImplementedError("Can't apply window to the expression.") + return window diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 93f072973c..329defd9a8 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -22,6 +22,7 @@ from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite from bigframes.core.compile import configs +import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler import bigframes.core.compile.sqlglot.sqlglot_ir as ir @@ -267,6 +268,39 @@ def compile_random_sample( ) -> ir.SQLGlotIR: return child.sample(node.fraction) + @_compile_node.register + def compile_aggregate( + self, node: nodes.AggregateNode, child: ir.SQLGlotIR + ) -> ir.SQLGlotIR: + ordering_cols = tuple( + sge.Ordered( + this=scalar_compiler.compile_scalar_expression( + ordering.scalar_expression + ), + desc=ordering.direction.is_ascending is False, + # TODO: _convert_row_ordering_to_table_values for overwrite. + nulls_first=ordering.na_last is False, + ) + for ordering in node.order_by + ) + aggregations: tuple[tuple[str, sge.Expression], ...] = tuple( + (id.sql, aggregate_compiler.compile_aggregate(agg, order_by=ordering_cols)) + for agg, id in node.aggregations + ) + by_cols: tuple[sge.Expression, ...] = tuple( + scalar_compiler.compile_scalar_expression(by_col) + for by_col in node.by_column_ids + ) + + result = child.aggregate(aggregations, by_cols) + # TODO(chelsealin): Support dropna + # TODO: Remove dropna field and use filter node instead + # if node.dropna: + # for key in node.by_column_ids: + # if node.child.field_by_id[key.id].nullable: + # result = result.filter(operations.notnull_op.as_expr(key)) + return result + def _replace_unsupported_ops(node: nodes.BigFrameNode): node = nodes.bottom_up(node, rewrite.rewrite_slice) diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index c0bed4090c..4ee62cbd68 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -25,11 +25,9 @@ import sqlglot.expressions as sge from bigframes import dtypes -from bigframes.core import guid, utils +from bigframes.core import guid, local_data, schema, utils from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.sqlglot_types as sgt -import bigframes.core.local_data as local_data -import bigframes.core.schema as bf_schema # shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0. try: @@ -68,7 +66,7 @@ def sql(self) -> str: def from_pyarrow( cls, pa_table: pa.Table, - schema: bf_schema.ArraySchema, + schema: schema.ArraySchema, uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Builds SQLGlot expression from a pyarrow table. @@ -364,6 +362,38 @@ def sample(self, fraction: float) -> SQLGlotIR: ).where(condition, append=False) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def aggregate( + self, + aggregations: tuple[tuple[str, sge.Expression], ...], + by_column_ids: tuple[sge.Expression, ...], + ) -> SQLGlotIR: + """Applies the aggregation expressions. + + Args: + aggregations: output_column_id, aggregation_expr tuples + by_column_ids: column ids of the aggregation key, this is preserved through + the transform + dropna: whether null keys should be dropped + """ + aggregations_expr = [ + sge.Alias( + this=expr, + alias=sge.to_identifier(id, quoted=self.quoted), + ) + for id, expr in aggregations + ] + + new_expr = _select_to_cte( + self.expr, + sge.to_identifier( + next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + ), + ) + new_expr = new_expr.group_by(*by_column_ids).select( + *[*by_column_ids, *aggregations_expr], append=False + ) + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def insert( self, destination: bigquery.TableReference, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql new file mode 100644 index 0000000000..b8e127eb77 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql @@ -0,0 +1,25 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_too` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + `bfcol_3`, + COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6` + FROM `bfcte_1` + GROUP BY + `bfcol_3` +) +SELECT + `bfcol_3` AS `bool_col`, + `bfcol_6` AS `int64_too` +FROM `bfcte_2` +ORDER BY + `bfcol_3` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_aggregate.py b/tests/unit/core/compile/sqlglot/test_compile_aggregate.py new file mode 100644 index 0000000000..bce5b26b3a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_aggregate.py @@ -0,0 +1,24 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def test_compile_aggregate(scalar_types_df: bpd.DataFrame, snapshot): + result = scalar_types_df["int64_too"].groupby(scalar_types_df["bool_col"]).sum() + snapshot.assert_match(result.to_frame().sql, "out.sql") From 094becc0b2a629f234e783d0542bc818d9da7b9c Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 11 Jul 2025 20:58:12 +0000 Subject: [PATCH 2/3] resolve aggregation nodes for dtype and support dropna --- .../compile/sqlglot/aggregate_compiler.py | 34 +++++++++----- bigframes/core/compile/sqlglot/compiler.py | 17 ++++--- bigframes/core/compile/sqlglot/sqlglot_ir.py | 26 +++++++++-- bigframes/core/rewrite/schema_binding.py | 46 +++++++++++++++++++ .../test_compile_aggregate/out.sql | 8 +++- .../test_compile_aggregate_wo_dropna/out.sql | 25 ++++++++++ .../test_compile_join/out.sql | 4 +- .../test_compile_join_w_on/bool_col/out.sql | 4 +- .../float64_col/out.sql | 4 +- .../test_compile_join_w_on/int64_col/out.sql | 4 +- .../numeric_col/out.sql | 4 +- .../test_compile_join_w_on/string_col/out.sql | 4 +- .../test_compile_join_w_on/time_col/out.sql | 4 +- .../compile/sqlglot/test_compile_aggregate.py | 9 ++++ 14 files changed, 154 insertions(+), 39 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate_wo_dropna/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index dae8f2485b..befbcb8ce3 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -18,8 +18,8 @@ import sqlglot.expressions as sge -from bigframes import dtypes from bigframes.core import expression, window_spec +from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler import bigframes.core.compile.sqlglot.sqlglot_ir as ir import bigframes.operations as ops @@ -34,14 +34,23 @@ def compile_aggregate( if isinstance(aggregate, expression.NullaryAggregation): return compile_nullary_agg(aggregate.op) if isinstance(aggregate, expression.UnaryAggregation): - column = scalar_compiler.compile_scalar_expression(aggregate.arg) + column = typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(aggregate.arg), + aggregate.arg.output_type, + ) if not aggregate.op.order_independent: return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) # type: ignore else: return compile_unary_agg(aggregate.op, column) # type: ignore elif isinstance(aggregate, expression.BinaryAggregation): - left = scalar_compiler.compile_scalar_expression(aggregate.left) - right = scalar_compiler.compile_scalar_expression(aggregate.right) + left = typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(aggregate.left), + aggregate.left.output_type, + ) + right = typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(aggregate.right), + aggregate.right.output_type, + ) return compile_binary_agg(aggregate.op, left, right) # type: ignore else: raise ValueError(f"Unexpected aggregation: {aggregate}") @@ -58,8 +67,8 @@ def compile_nullary_agg( @functools.singledispatch def compile_binary_agg( op: ops.aggregations.WindowOp, - left: sge.Expression, - right: sge.Expression, + left: typed_expr.TypedExpr, + right: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: raise ValueError(f"Can't compile unrecognized operation: {op}") @@ -68,7 +77,7 @@ def compile_binary_agg( @functools.singledispatch def compile_unary_agg( op: ops.aggregations.WindowOp, - column: sge.Expression, + column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: raise ValueError(f"Can't compile unrecognized operation: {op}") @@ -77,7 +86,7 @@ def compile_unary_agg( @functools.singledispatch def compile_ordered_unary_agg( op: ops.aggregations.WindowOp, - column: sge.Expression, + column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: raise ValueError(f"Can't compile unrecognized operation: {op}") @@ -87,12 +96,13 @@ def compile_ordered_unary_agg( @compile_unary_agg.register def _( op: ops.aggregations.SumOp, - column: sge.Expression, + column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: # Will be null if all inputs are null. Pandas defaults to zero sum though. - expr = _apply_window_if_present(sge.func("SUM", column), window) - return sge.func("IFNULL", expr, ir._literal(0, dtypes.INT_DTYPE)) + expr = _apply_window_if_present(sge.func("SUM", column.expr), window) + # TODO (b/430350912): check `column.dtype` works for all? + return sge.func("IFNULL", expr, ir._literal(0, column.dtype)) def _apply_window_if_present( @@ -101,4 +111,4 @@ def _apply_window_if_present( ) -> sge.Expression: if window is not None: raise NotImplementedError("Can't apply window to the expression.") - return window + return value diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 329defd9a8..fe16dc8f74 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -18,6 +18,7 @@ import typing from google.cloud import bigquery +import sqlglot as sg import sqlglot.expressions as sge from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite @@ -218,7 +219,7 @@ def compile_filter( self, node: nodes.FilterNode, child: ir.SQLGlotIR ) -> ir.SQLGlotIR: condition = scalar_compiler.compile_scalar_expression(node.predicate) - return child.filter(condition) + return child.filter(tuple([condition])) @_compile_node.register def compile_join( @@ -293,12 +294,14 @@ def compile_aggregate( ) result = child.aggregate(aggregations, by_cols) - # TODO(chelsealin): Support dropna - # TODO: Remove dropna field and use filter node instead - # if node.dropna: - # for key in node.by_column_ids: - # if node.child.field_by_id[key.id].nullable: - # result = result.filter(operations.notnull_op.as_expr(key)) + if node.dropna: + conditions = [] + for key, by_col in zip(node.by_column_ids, by_cols): + if node.child.field_by_id[key.id].nullable: + conditions.append( + sg.not_(sge.Is(this=by_col, expression=sge.Null())) + ) + result = result.filter(tuple(conditions)) return result diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 4ee62cbd68..d272ee489a 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -15,6 +15,7 @@ from __future__ import annotations import dataclasses +import functools import typing from google.cloud import bigquery @@ -278,9 +279,13 @@ def limit( def filter( self, - condition: sge.Expression, + conditions: tuple[sge.Expression, ...], ) -> SQLGlotIR: """Filters the query by adding a WHERE clause.""" + condition = _and(conditions) + if condition is None: + return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen) + new_expr = _select_to_cte( self.expr, sge.to_identifier( @@ -314,10 +319,11 @@ def join( right_ctes = right_select.args.pop("with", []) merged_ctes = [*left_ctes, *right_ctes] - join_conditions = [ - _join_condition(left, right, joins_nulls) for left, right in conditions - ] - join_on = sge.And(expressions=join_conditions) if join_conditions else None + join_on = _and( + tuple( + _join_condition(left, right, joins_nulls) for left, right in conditions + ) + ) join_type_str = join_type if join_type != "outer" else "full outer" new_expr = ( @@ -582,6 +588,16 @@ def _table(table: bigquery.TableReference) -> sge.Table: ) +def _and(conditions: tuple[sge.Expression, ...]) -> typing.Optional[sge.Expression]: + """Chains multiple expressions together using a logical AND.""" + if not conditions: + return None + + return functools.reduce( + lambda left, right: sge.And(this=left, expression=right), conditions + ) + + def _join_condition( left: typed_expr.TypedExpr, right: typed_expr.TypedExpr, diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py index af0593211c..40a00ff8f6 100644 --- a/bigframes/core/rewrite/schema_binding.py +++ b/bigframes/core/rewrite/schema_binding.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses +import typing from bigframes.core import bigframe_node from bigframes.core import expression as ex @@ -65,4 +66,49 @@ def bind_schema_to_node( conditions=conditions, ) + if isinstance(node, nodes.AggregateNode): + aggregations = [] + for aggregation, id in node.aggregations: + if isinstance(aggregation, ex.UnaryAggregation): + replaced = typing.cast( + ex.Aggregation, + dataclasses.replace( + aggregation, + arg=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields( + aggregation.arg, node.child.field_by_id + ), + ), + ), + ) + aggregations.append((replaced, id)) + elif isinstance(aggregation, ex.BinaryAggregation): + replaced = typing.cast( + ex.Aggregation, + dataclasses.replace( + aggregation, + left=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields( + aggregation.left, node.child.field_by_id + ), + ), + right=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields( + aggregation.right, node.child.field_by_id + ), + ), + ), + ) + aggregations.append((replaced, id)) + else: + aggregations.append((aggregation, id)) + + return dataclasses.replace( + node, + aggregations=tuple(aggregations), + ) + return node diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql index b8e127eb77..db413b4c79 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql @@ -16,10 +16,16 @@ WITH `bfcte_0` AS ( FROM `bfcte_1` GROUP BY `bfcol_3` +), `bfcte_3` AS ( + SELECT + * + FROM `bfcte_2` + WHERE + NOT `bfcol_3` IS NULL ) SELECT `bfcol_3` AS `bool_col`, `bfcol_6` AS `int64_too` -FROM `bfcte_2` +FROM `bfcte_3` ORDER BY `bfcol_3` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate_wo_dropna/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate_wo_dropna/out.sql new file mode 100644 index 0000000000..b8e127eb77 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate_wo_dropna/out.sql @@ -0,0 +1,25 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_too` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + `bfcol_3`, + COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6` + FROM `bfcte_1` + GROUP BY + `bfcol_3` +) +SELECT + `bfcol_3` AS `bool_col`, + `bfcol_6` AS `int64_too` +FROM `bfcte_2` +ORDER BY + `bfcol_3` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql index 85eab4487a..04ee767f8a 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` LEFT JOIN `bfcte_3` - ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0) - AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1) + ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0) + AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1) ) SELECT `bfcol_3` AS `int64_col`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql index a073e35c69..05d5fd0695 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0') - AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1') + ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0') + AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1') ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql index 1d04343f31..9e6a4094b2 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0)) - AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1)) + ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0)) + AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1)) ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql index 80ec5d19d1..bd03e05cba 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0) - AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1) + ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0) + AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1) ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql index 22ce6f5b29..6b77ead97c 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC)) - AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC)) + ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC)) + AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC)) ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql index 5e8d072d46..1903d5fc22 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql @@ -18,8 +18,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_1` INNER JOIN `bfcte_2` - ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') - AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') + ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') + AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') ) SELECT `bfcol_0` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql index b0df619f25..9e3477d4a9 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql @@ -18,8 +18,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_1` INNER JOIN `bfcte_2` - ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') - AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') + ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') + AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') ) SELECT `bfcol_0` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/test_compile_aggregate.py b/tests/unit/core/compile/sqlglot/test_compile_aggregate.py index bce5b26b3a..d59c5e5068 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_aggregate.py +++ b/tests/unit/core/compile/sqlglot/test_compile_aggregate.py @@ -22,3 +22,12 @@ def test_compile_aggregate(scalar_types_df: bpd.DataFrame, snapshot): result = scalar_types_df["int64_too"].groupby(scalar_types_df["bool_col"]).sum() snapshot.assert_match(result.to_frame().sql, "out.sql") + + +def test_compile_aggregate_wo_dropna(scalar_types_df: bpd.DataFrame, snapshot): + result = ( + scalar_types_df["int64_too"] + .groupby(scalar_types_df["bool_col"], dropna=False) + .sum() + ) + snapshot.assert_match(result.to_frame().sql, "out.sql") From c8e14a08a9a05fcdba0ff26d9c2acc8087c64997 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 11 Jul 2025 21:17:27 +0000 Subject: [PATCH 3/3] generate more compact aggregation SQL --- .../compile/sqlglot/aggregate_compiler.py | 10 ++++----- bigframes/core/compile/sqlglot/compiler.py | 13 ++++-------- bigframes/core/compile/sqlglot/sqlglot_ir.py | 21 +++++++++++++------ .../test_compile_aggregate/out.sql | 10 +++------ 4 files changed, 26 insertions(+), 28 deletions(-) diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index befbcb8ce3..888b3756b5 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -30,7 +30,6 @@ def compile_aggregate( order_by: tuple[sge.Expression, ...], ) -> sge.Expression: """Compiles BigFrames aggregation expression into SQLGlot expression.""" - # TODO: try to remove type: ignore if isinstance(aggregate, expression.NullaryAggregation): return compile_nullary_agg(aggregate.op) if isinstance(aggregate, expression.UnaryAggregation): @@ -39,9 +38,9 @@ def compile_aggregate( aggregate.arg.output_type, ) if not aggregate.op.order_independent: - return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) # type: ignore + return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) else: - return compile_unary_agg(aggregate.op, column) # type: ignore + return compile_unary_agg(aggregate.op, column) elif isinstance(aggregate, expression.BinaryAggregation): left = typed_expr.TypedExpr( scalar_compiler.compile_scalar_expression(aggregate.left), @@ -51,7 +50,7 @@ def compile_aggregate( scalar_compiler.compile_scalar_expression(aggregate.right), aggregate.right.output_type, ) - return compile_binary_agg(aggregate.op, left, right) # type: ignore + return compile_binary_agg(aggregate.op, left, right) else: raise ValueError(f"Unexpected aggregation: {aggregate}") @@ -88,11 +87,11 @@ def compile_ordered_unary_agg( op: ops.aggregations.WindowOp, column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, + order_by: typing.Sequence[sge.Expression] = [], ) -> sge.Expression: raise ValueError(f"Can't compile unrecognized operation: {op}") -# TODO: check @numeric_op @compile_unary_agg.register def _( op: ops.aggregations.SumOp, @@ -101,7 +100,6 @@ def _( ) -> sge.Expression: # Will be null if all inputs are null. Pandas defaults to zero sum though. expr = _apply_window_if_present(sge.func("SUM", column.expr), window) - # TODO (b/430350912): check `column.dtype` works for all? return sge.func("IFNULL", expr, ir._literal(0, column.dtype)) diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index fe16dc8f74..1c5aaf50a8 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -18,7 +18,6 @@ import typing from google.cloud import bigquery -import sqlglot as sg import sqlglot.expressions as sge from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite @@ -279,7 +278,6 @@ def compile_aggregate( ordering.scalar_expression ), desc=ordering.direction.is_ascending is False, - # TODO: _convert_row_ordering_to_table_values for overwrite. nulls_first=ordering.na_last is False, ) for ordering in node.order_by @@ -293,16 +291,13 @@ def compile_aggregate( for by_col in node.by_column_ids ) - result = child.aggregate(aggregations, by_cols) + dropna_cols = [] if node.dropna: - conditions = [] for key, by_col in zip(node.by_column_ids, by_cols): if node.child.field_by_id[key.id].nullable: - conditions.append( - sg.not_(sge.Is(this=by_col, expression=sge.Null())) - ) - result = result.filter(tuple(conditions)) - return result + dropna_cols.append(by_col) + + return child.aggregate(aggregations, by_cols, tuple(dropna_cols)) def _replace_unsupported_ops(node: nodes.BigFrameNode): diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index d272ee489a..b194fe9e5d 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -371,15 +371,15 @@ def sample(self, fraction: float) -> SQLGlotIR: def aggregate( self, aggregations: tuple[tuple[str, sge.Expression], ...], - by_column_ids: tuple[sge.Expression, ...], + by_cols: tuple[sge.Expression, ...], + dropna_cols: tuple[sge.Expression, ...], ) -> SQLGlotIR: """Applies the aggregation expressions. Args: aggregations: output_column_id, aggregation_expr tuples - by_column_ids: column ids of the aggregation key, this is preserved through - the transform - dropna: whether null keys should be dropped + by_cols: column expressions for aggregation + dropna_cols: columns whether null keys should be dropped """ aggregations_expr = [ sge.Alias( @@ -395,9 +395,18 @@ def aggregate( next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ), ) - new_expr = new_expr.group_by(*by_column_ids).select( - *[*by_column_ids, *aggregations_expr], append=False + new_expr = new_expr.group_by(*by_cols).select( + *[*by_cols, *aggregations_expr], append=False ) + + condition = _and( + tuple( + sg.not_(sge.Is(this=drop_col, expression=sge.Null())) + for drop_col in dropna_cols + ) + ) + if condition is not None: + new_expr = new_expr.where(condition, append=False) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) def insert( diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql index db413b4c79..02bba41a22 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql @@ -14,18 +14,14 @@ WITH `bfcte_0` AS ( `bfcol_3`, COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6` FROM `bfcte_1` - GROUP BY - `bfcol_3` -), `bfcte_3` AS ( - SELECT - * - FROM `bfcte_2` WHERE NOT `bfcol_3` IS NULL + GROUP BY + `bfcol_3` ) SELECT `bfcol_3` AS `bool_col`, `bfcol_6` AS `int64_too` -FROM `bfcte_3` +FROM `bfcte_2` ORDER BY `bfcol_3` ASC NULLS LAST \ No newline at end of file