diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py new file mode 100644 index 0000000000..888b3756b5 --- /dev/null +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -0,0 +1,112 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +import typing + +import sqlglot.expressions as sge + +from bigframes.core import expression, window_spec +from bigframes.core.compile.sqlglot.expressions import typed_expr +import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler +import bigframes.core.compile.sqlglot.sqlglot_ir as ir +import bigframes.operations as ops + + +def compile_aggregate( + aggregate: expression.Aggregation, + order_by: tuple[sge.Expression, ...], +) -> sge.Expression: + """Compiles BigFrames aggregation expression into SQLGlot expression.""" + if isinstance(aggregate, expression.NullaryAggregation): + return compile_nullary_agg(aggregate.op) + if isinstance(aggregate, expression.UnaryAggregation): + column = typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(aggregate.arg), + aggregate.arg.output_type, + ) + if not aggregate.op.order_independent: + return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) + else: + return compile_unary_agg(aggregate.op, column) + elif isinstance(aggregate, expression.BinaryAggregation): + left = typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(aggregate.left), + aggregate.left.output_type, + ) + right = typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(aggregate.right), + aggregate.right.output_type, + ) + return compile_binary_agg(aggregate.op, left, right) + else: + raise ValueError(f"Unexpected aggregation: {aggregate}") + + +@functools.singledispatch +def compile_nullary_agg( + op: ops.aggregations.WindowOp, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +@functools.singledispatch +def compile_binary_agg( + op: ops.aggregations.WindowOp, + left: typed_expr.TypedExpr, + right: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +@functools.singledispatch +def compile_unary_agg( + op: ops.aggregations.WindowOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +@functools.singledispatch +def compile_ordered_unary_agg( + op: ops.aggregations.WindowOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, + order_by: typing.Sequence[sge.Expression] = [], +) -> sge.Expression: + raise ValueError(f"Can't compile unrecognized operation: {op}") + + +@compile_unary_agg.register +def _( + op: ops.aggregations.SumOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # Will be null if all inputs are null. Pandas defaults to zero sum though. + expr = _apply_window_if_present(sge.func("SUM", column.expr), window) + return sge.func("IFNULL", expr, ir._literal(0, column.dtype)) + + +def _apply_window_if_present( + value: sge.Expression, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + if window is not None: + raise NotImplementedError("Can't apply window to the expression.") + return value diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 93f072973c..1c5aaf50a8 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -22,6 +22,7 @@ from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite from bigframes.core.compile import configs +import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler import bigframes.core.compile.sqlglot.sqlglot_ir as ir @@ -217,7 +218,7 @@ def compile_filter( self, node: nodes.FilterNode, child: ir.SQLGlotIR ) -> ir.SQLGlotIR: condition = scalar_compiler.compile_scalar_expression(node.predicate) - return child.filter(condition) + return child.filter(tuple([condition])) @_compile_node.register def compile_join( @@ -267,6 +268,37 @@ def compile_random_sample( ) -> ir.SQLGlotIR: return child.sample(node.fraction) + @_compile_node.register + def compile_aggregate( + self, node: nodes.AggregateNode, child: ir.SQLGlotIR + ) -> ir.SQLGlotIR: + ordering_cols = tuple( + sge.Ordered( + this=scalar_compiler.compile_scalar_expression( + ordering.scalar_expression + ), + desc=ordering.direction.is_ascending is False, + nulls_first=ordering.na_last is False, + ) + for ordering in node.order_by + ) + aggregations: tuple[tuple[str, sge.Expression], ...] = tuple( + (id.sql, aggregate_compiler.compile_aggregate(agg, order_by=ordering_cols)) + for agg, id in node.aggregations + ) + by_cols: tuple[sge.Expression, ...] = tuple( + scalar_compiler.compile_scalar_expression(by_col) + for by_col in node.by_column_ids + ) + + dropna_cols = [] + if node.dropna: + for key, by_col in zip(node.by_column_ids, by_cols): + if node.child.field_by_id[key.id].nullable: + dropna_cols.append(by_col) + + return child.aggregate(aggregations, by_cols, tuple(dropna_cols)) + def _replace_unsupported_ops(node: nodes.BigFrameNode): node = nodes.bottom_up(node, rewrite.rewrite_slice) diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index c0bed4090c..b194fe9e5d 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -15,6 +15,7 @@ from __future__ import annotations import dataclasses +import functools import typing from google.cloud import bigquery @@ -25,11 +26,9 @@ import sqlglot.expressions as sge from bigframes import dtypes -from bigframes.core import guid, utils +from bigframes.core import guid, local_data, schema, utils from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.sqlglot_types as sgt -import bigframes.core.local_data as local_data -import bigframes.core.schema as bf_schema # shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0. try: @@ -68,7 +67,7 @@ def sql(self) -> str: def from_pyarrow( cls, pa_table: pa.Table, - schema: bf_schema.ArraySchema, + schema: schema.ArraySchema, uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Builds SQLGlot expression from a pyarrow table. @@ -280,9 +279,13 @@ def limit( def filter( self, - condition: sge.Expression, + conditions: tuple[sge.Expression, ...], ) -> SQLGlotIR: """Filters the query by adding a WHERE clause.""" + condition = _and(conditions) + if condition is None: + return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen) + new_expr = _select_to_cte( self.expr, sge.to_identifier( @@ -316,10 +319,11 @@ def join( right_ctes = right_select.args.pop("with", []) merged_ctes = [*left_ctes, *right_ctes] - join_conditions = [ - _join_condition(left, right, joins_nulls) for left, right in conditions - ] - join_on = sge.And(expressions=join_conditions) if join_conditions else None + join_on = _and( + tuple( + _join_condition(left, right, joins_nulls) for left, right in conditions + ) + ) join_type_str = join_type if join_type != "outer" else "full outer" new_expr = ( @@ -364,6 +368,47 @@ def sample(self, fraction: float) -> SQLGlotIR: ).where(condition, append=False) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def aggregate( + self, + aggregations: tuple[tuple[str, sge.Expression], ...], + by_cols: tuple[sge.Expression, ...], + dropna_cols: tuple[sge.Expression, ...], + ) -> SQLGlotIR: + """Applies the aggregation expressions. + + Args: + aggregations: output_column_id, aggregation_expr tuples + by_cols: column expressions for aggregation + dropna_cols: columns whether null keys should be dropped + """ + aggregations_expr = [ + sge.Alias( + this=expr, + alias=sge.to_identifier(id, quoted=self.quoted), + ) + for id, expr in aggregations + ] + + new_expr = _select_to_cte( + self.expr, + sge.to_identifier( + next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + ), + ) + new_expr = new_expr.group_by(*by_cols).select( + *[*by_cols, *aggregations_expr], append=False + ) + + condition = _and( + tuple( + sg.not_(sge.Is(this=drop_col, expression=sge.Null())) + for drop_col in dropna_cols + ) + ) + if condition is not None: + new_expr = new_expr.where(condition, append=False) + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def insert( self, destination: bigquery.TableReference, @@ -552,6 +597,16 @@ def _table(table: bigquery.TableReference) -> sge.Table: ) +def _and(conditions: tuple[sge.Expression, ...]) -> typing.Optional[sge.Expression]: + """Chains multiple expressions together using a logical AND.""" + if not conditions: + return None + + return functools.reduce( + lambda left, right: sge.And(this=left, expression=right), conditions + ) + + def _join_condition( left: typed_expr.TypedExpr, right: typed_expr.TypedExpr, diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py index af0593211c..40a00ff8f6 100644 --- a/bigframes/core/rewrite/schema_binding.py +++ b/bigframes/core/rewrite/schema_binding.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses +import typing from bigframes.core import bigframe_node from bigframes.core import expression as ex @@ -65,4 +66,49 @@ def bind_schema_to_node( conditions=conditions, ) + if isinstance(node, nodes.AggregateNode): + aggregations = [] + for aggregation, id in node.aggregations: + if isinstance(aggregation, ex.UnaryAggregation): + replaced = typing.cast( + ex.Aggregation, + dataclasses.replace( + aggregation, + arg=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields( + aggregation.arg, node.child.field_by_id + ), + ), + ), + ) + aggregations.append((replaced, id)) + elif isinstance(aggregation, ex.BinaryAggregation): + replaced = typing.cast( + ex.Aggregation, + dataclasses.replace( + aggregation, + left=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields( + aggregation.left, node.child.field_by_id + ), + ), + right=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields( + aggregation.right, node.child.field_by_id + ), + ), + ), + ) + aggregations.append((replaced, id)) + else: + aggregations.append((aggregation, id)) + + return dataclasses.replace( + node, + aggregations=tuple(aggregations), + ) + return node diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql new file mode 100644 index 0000000000..02bba41a22 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql @@ -0,0 +1,27 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_too` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + `bfcol_3`, + COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6` + FROM `bfcte_1` + WHERE + NOT `bfcol_3` IS NULL + GROUP BY + `bfcol_3` +) +SELECT + `bfcol_3` AS `bool_col`, + `bfcol_6` AS `int64_too` +FROM `bfcte_2` +ORDER BY + `bfcol_3` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate_wo_dropna/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate_wo_dropna/out.sql new file mode 100644 index 0000000000..b8e127eb77 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate_wo_dropna/out.sql @@ -0,0 +1,25 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_too` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + `bfcol_3`, + COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6` + FROM `bfcte_1` + GROUP BY + `bfcol_3` +) +SELECT + `bfcol_3` AS `bool_col`, + `bfcol_6` AS `int64_too` +FROM `bfcte_2` +ORDER BY + `bfcol_3` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql index 85eab4487a..04ee767f8a 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` LEFT JOIN `bfcte_3` - ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0) - AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1) + ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0) + AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1) ) SELECT `bfcol_3` AS `int64_col`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql index a073e35c69..05d5fd0695 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0') - AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1') + ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0') + AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1') ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql index 1d04343f31..9e6a4094b2 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0)) - AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1)) + ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0)) + AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1)) ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql index 80ec5d19d1..bd03e05cba 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0) - AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1) + ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0) + AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1) ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql index 22ce6f5b29..6b77ead97c 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql @@ -23,8 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` INNER JOIN `bfcte_3` - ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC)) - AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC)) + ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC)) + AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC)) ) SELECT `bfcol_2` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql index 5e8d072d46..1903d5fc22 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql @@ -18,8 +18,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_1` INNER JOIN `bfcte_2` - ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') - AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') + ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') + AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') ) SELECT `bfcol_0` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql index b0df619f25..9e3477d4a9 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql @@ -18,8 +18,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_1` INNER JOIN `bfcte_2` - ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') - AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') + ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') + AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') ) SELECT `bfcol_0` AS `rowindex_x`, diff --git a/tests/unit/core/compile/sqlglot/test_compile_aggregate.py b/tests/unit/core/compile/sqlglot/test_compile_aggregate.py new file mode 100644 index 0000000000..d59c5e5068 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_aggregate.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def test_compile_aggregate(scalar_types_df: bpd.DataFrame, snapshot): + result = scalar_types_df["int64_too"].groupby(scalar_types_df["bool_col"]).sum() + snapshot.assert_match(result.to_frame().sql, "out.sql") + + +def test_compile_aggregate_wo_dropna(scalar_types_df: bpd.DataFrame, snapshot): + result = ( + scalar_types_df["int64_too"] + .groupby(scalar_types_df["bool_col"], dropna=False) + .sum() + ) + snapshot.assert_match(result.to_frame().sql, "out.sql")