diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 8d415032fb..24acda35dc 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -20,6 +20,7 @@ from bigframes.core import array_value, bigframe_node, expression, local_data, nodes import bigframes.operations +from bigframes.operations import aggregations as agg_ops from bigframes.session import executor, semi_executor if TYPE_CHECKING: @@ -32,9 +33,11 @@ nodes.ReversedNode, nodes.SelectionNode, nodes.SliceNode, + nodes.AggregateNode, ) _COMPATIBLE_SCALAR_OPS = () +_COMPATIBLE_AGG_OPS = (agg_ops.SizeOp, agg_ops.SizeUnaryOp) def _get_expr_ops(expr: expression.Expression) -> set[bigframes.operations.ScalarOp]: @@ -48,7 +51,8 @@ def _is_node_polars_executable(node: nodes.BigFrameNode): return False for expr in node._node_expressions: if isinstance(expr, expression.Aggregation): - return False + if not type(expr.op) in _COMPATIBLE_AGG_OPS: + return False if isinstance(expr, expression.Expression): if not _get_expr_ops(expr).issubset(_COMPATIBLE_SCALAR_OPS): return False diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py new file mode 100644 index 0000000000..2c323a5f28 --- /dev/null +++ b/tests/system/small/engines/test_aggregation.py @@ -0,0 +1,82 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from bigframes.core import array_value, expression, identifiers, nodes +import bigframes.operations.aggregations as agg_ops +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_aggregate_size( + scalars_array_value: array_value.ArrayValue, + engine, +): + node = nodes.AggregateNode( + scalars_array_value.node, + aggregations=( + ( + expression.NullaryAggregation(agg_ops.SizeOp()), + identifiers.ColumnId("size_op"), + ), + ( + expression.UnaryAggregation( + agg_ops.SizeUnaryOp(), expression.deref("string_col") + ), + identifiers.ColumnId("unary_size_op"), + ), + ), + ) + assert_equivalence_execution(node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize( + "grouping_cols", + [ + ["bool_col"], + ["string_col", "int64_col"], + ["date_col"], + ["datetime_col"], + ["timestamp_col"], + ["bytes_col"], + ], +) +def test_engines_grouped_aggregate( + scalars_array_value: array_value.ArrayValue, engine, grouping_cols +): + node = nodes.AggregateNode( + scalars_array_value.node, + aggregations=( + ( + expression.NullaryAggregation(agg_ops.SizeOp()), + identifiers.ColumnId("size_op"), + ), + ( + expression.UnaryAggregation( + agg_ops.SizeUnaryOp(), expression.deref("string_col") + ), + identifiers.ColumnId("unary_size_op"), + ), + ), + by_column_ids=tuple(expression.deref(id) for id in grouping_cols), + ) + assert_equivalence_execution(node, REFERENCE_ENGINE, engine)