Skip to content

Commit 305e57d

Browse files
authored
refactor: support agg_ops.ArrayAggOp and StringAggOp to sqlglot compiler (#2163)
1 parent bbfdb20 commit 305e57d

File tree

4 files changed

+138
-8
lines changed

4 files changed

+138
-8
lines changed

bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
1917
import sqlglot.expressions as sge
2018

21-
from bigframes.core import window_spec
2219
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
2320
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2421
from bigframes.operations import aggregations as agg_ops
@@ -29,9 +26,35 @@
2926
def compile(
3027
op: agg_ops.WindowOp,
3128
column: typed_expr.TypedExpr,
32-
window: typing.Optional[window_spec.WindowSpec] = None,
33-
order_by: typing.Sequence[sge.Expression] = [],
29+
*,
30+
order_by: tuple[sge.Expression, ...],
31+
) -> sge.Expression:
32+
return ORDERED_UNARY_OP_REGISTRATION[op](op, column, order_by=order_by)
33+
34+
35+
@ORDERED_UNARY_OP_REGISTRATION.register(agg_ops.ArrayAggOp)
36+
def _(
37+
op: agg_ops.ArrayAggOp,
38+
column: typed_expr.TypedExpr,
39+
*,
40+
order_by: tuple[sge.Expression, ...],
3441
) -> sge.Expression:
35-
return ORDERED_UNARY_OP_REGISTRATION[op](
36-
op, column, window=window, order_by=order_by
37-
)
42+
expr = column.expr
43+
if len(order_by) > 0:
44+
expr = sge.Order(this=column.expr, expressions=list(order_by))
45+
return sge.IgnoreNulls(this=sge.ArrayAgg(this=expr))
46+
47+
48+
@ORDERED_UNARY_OP_REGISTRATION.register(agg_ops.StringAggOp)
49+
def _(
50+
op: agg_ops.StringAggOp,
51+
column: typed_expr.TypedExpr,
52+
*,
53+
order_by: tuple[sge.Expression, ...],
54+
) -> sge.Expression:
55+
expr = column.expr
56+
if len(order_by) > 0:
57+
expr = sge.Order(this=expr, expressions=list(order_by))
58+
59+
expr = sge.GroupConcat(this=expr, separator=sge.convert(op.sep))
60+
return sge.func("COALESCE", expr, sge.convert(""))
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
ARRAY_AGG(`bfcol_0` IGNORE NULLS ORDER BY `bfcol_0` IS NULL ASC, `bfcol_0` ASC) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
COALESCE(STRING_AGG(`bfcol_0`, ','
8+
ORDER BY
9+
`bfcol_0` IS NULL ASC,
10+
`bfcol_0` ASC), '') AS `bfcol_1`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_1` AS `string_col`
15+
FROM `bfcte_1`
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
import typing
17+
18+
import pytest
19+
20+
from bigframes.core import agg_expressions as agg_exprs
21+
from bigframes.core import array_value, identifiers, nodes, ordering
22+
from bigframes.operations import aggregations as agg_ops
23+
import bigframes.pandas as bpd
24+
25+
pytest.importorskip("pytest_snapshot")
26+
27+
28+
def _apply_ordered_unary_agg_ops(
29+
obj: bpd.DataFrame,
30+
ops_list: typing.Sequence[agg_exprs.UnaryAggregation],
31+
new_names: typing.Sequence[str],
32+
ordering_args: typing.Sequence[str],
33+
) -> str:
34+
ordering_exprs = tuple(ordering.ascending_over(arg) for arg in ordering_args)
35+
aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)]
36+
37+
agg_node = nodes.AggregateNode(
38+
obj._block.expr.node,
39+
aggregations=tuple(aggs),
40+
by_column_ids=(),
41+
order_by=ordering_exprs,
42+
)
43+
result = array_value.ArrayValue(agg_node)
44+
45+
sql = result.session._executor.to_sql(result, enable_cache=False)
46+
return sql
47+
48+
49+
def test_array_agg(scalar_types_df: bpd.DataFrame, snapshot):
50+
# TODO: Verify "NULL LAST" syntax issue on Python < 3.12
51+
if sys.version_info < (3, 12):
52+
pytest.skip(
53+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
54+
)
55+
56+
col_name = "int64_col"
57+
bf_df = scalar_types_df[[col_name]]
58+
agg_expr = agg_ops.ArrayAggOp().as_expr(col_name)
59+
sql = _apply_ordered_unary_agg_ops(
60+
bf_df, [agg_expr], [col_name], ordering_args=[col_name]
61+
)
62+
63+
snapshot.assert_match(sql, "out.sql")
64+
65+
66+
def test_string_agg(scalar_types_df: bpd.DataFrame, snapshot):
67+
# TODO: Verify "NULL LAST" syntax issue on Python < 3.12
68+
if sys.version_info < (3, 12):
69+
pytest.skip(
70+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
71+
)
72+
73+
col_name = "string_col"
74+
bf_df = scalar_types_df[[col_name]]
75+
agg_expr = agg_ops.StringAggOp(sep=",").as_expr(col_name)
76+
sql = _apply_ordered_unary_agg_ops(
77+
bf_df, [agg_expr], [col_name], ordering_args=[col_name]
78+
)
79+
80+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)