diff --git a/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py index dea30ec206..9024a9ec89 100644 --- a/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py @@ -14,11 +14,8 @@ from __future__ import annotations -import typing - import sqlglot.expressions as sge -from bigframes.core import window_spec import bigframes.core.compile.sqlglot.aggregations.op_registration as reg import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr from bigframes.operations import aggregations as agg_ops @@ -29,9 +26,35 @@ def compile( op: agg_ops.WindowOp, column: typed_expr.TypedExpr, - window: typing.Optional[window_spec.WindowSpec] = None, - order_by: typing.Sequence[sge.Expression] = [], + *, + order_by: tuple[sge.Expression, ...], +) -> sge.Expression: + return ORDERED_UNARY_OP_REGISTRATION[op](op, column, order_by=order_by) + + +@ORDERED_UNARY_OP_REGISTRATION.register(agg_ops.ArrayAggOp) +def _( + op: agg_ops.ArrayAggOp, + column: typed_expr.TypedExpr, + *, + order_by: tuple[sge.Expression, ...], ) -> sge.Expression: - return ORDERED_UNARY_OP_REGISTRATION[op]( - op, column, window=window, order_by=order_by - ) + expr = column.expr + if len(order_by) > 0: + expr = sge.Order(this=column.expr, expressions=list(order_by)) + return sge.IgnoreNulls(this=sge.ArrayAgg(this=expr)) + + +@ORDERED_UNARY_OP_REGISTRATION.register(agg_ops.StringAggOp) +def _( + op: agg_ops.StringAggOp, + column: typed_expr.TypedExpr, + *, + order_by: tuple[sge.Expression, ...], +) -> sge.Expression: + expr = column.expr + if len(order_by) > 0: + expr = sge.Order(this=expr, expressions=list(order_by)) + + expr = sge.GroupConcat(this=expr, separator=sge.convert(op.sep)) + return sge.func("COALESCE", expr, sge.convert("")) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_array_agg/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_array_agg/out.sql new file mode 100644 index 0000000000..43e0a03db4 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_array_agg/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + ARRAY_AGG(`bfcol_0` IGNORE NULLS ORDER BY `bfcol_0` IS NULL ASC, `bfcol_0` ASC) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_string_agg/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_string_agg/out.sql new file mode 100644 index 0000000000..115d7e37ee --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_ordered_unary_compiler/test_string_agg/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + COALESCE(STRING_AGG(`bfcol_0`, ',' + ORDER BY + `bfcol_0` IS NULL ASC, + `bfcol_0` ASC), '') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_ordered_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_ordered_unary_compiler.py new file mode 100644 index 0000000000..2f88fb5d0c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/test_ordered_unary_compiler.py @@ -0,0 +1,80 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import typing + +import pytest + +from bigframes.core import agg_expressions as agg_exprs +from bigframes.core import array_value, identifiers, nodes, ordering +from bigframes.operations import aggregations as agg_ops +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def _apply_ordered_unary_agg_ops( + obj: bpd.DataFrame, + ops_list: typing.Sequence[agg_exprs.UnaryAggregation], + new_names: typing.Sequence[str], + ordering_args: typing.Sequence[str], +) -> str: + ordering_exprs = tuple(ordering.ascending_over(arg) for arg in ordering_args) + aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)] + + agg_node = nodes.AggregateNode( + obj._block.expr.node, + aggregations=tuple(aggs), + by_column_ids=(), + order_by=ordering_exprs, + ) + result = array_value.ArrayValue(agg_node) + + sql = result.session._executor.to_sql(result, enable_cache=False) + return sql + + +def test_array_agg(scalar_types_df: bpd.DataFrame, snapshot): + # TODO: Verify "NULL LAST" syntax issue on Python < 3.12 + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) + + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.ArrayAggOp().as_expr(col_name) + sql = _apply_ordered_unary_agg_ops( + bf_df, [agg_expr], [col_name], ordering_args=[col_name] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_string_agg(scalar_types_df: bpd.DataFrame, snapshot): + # TODO: Verify "NULL LAST" syntax issue on Python < 3.12 + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) + + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.StringAggOp(sep=",").as_expr(col_name) + sql = _apply_ordered_unary_agg_ops( + bf_df, [agg_expr], [col_name], ordering_args=[col_name] + ) + + snapshot.assert_match(sql, "out.sql")