Skip to content

Commit 5ed694a

Browse files
committed
refactor: support agg_ops.RowNumberOp for sqlglot compiler
1 parent a3c2522 commit 5ed694a

File tree

8 files changed

+139
-11
lines changed

8 files changed

+139
-11
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def compile_analytic(
6363
window: window_spec.WindowSpec,
6464
) -> sge.Expression:
6565
if isinstance(aggregate, agg_expressions.NullaryAggregation):
66-
return nullary_compiler.compile(aggregate.op)
66+
return nullary_compiler.compile(aggregate.op, window)
6767
if isinstance(aggregate, agg_expressions.UnaryAggregation):
6868
column = typed_expr.TypedExpr(
6969
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg),

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,15 @@ def _(
3939
window: typing.Optional[window_spec.WindowSpec] = None,
4040
) -> sge.Expression:
4141
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
42+
43+
44+
@NULLARY_OP_REGISTRATION.register(agg_ops.RowNumberOp)
45+
def _(
46+
op: agg_ops.RowNumberOp,
47+
window: typing.Optional[window_spec.WindowSpec] = None,
48+
) -> sge.Expression:
49+
result: sge.Expression = sge.func("ROW_NUMBER")
50+
if window is None:
51+
# ROW_NUMBER always needs an OVER clause.
52+
return sge.Window(this=result)
53+
return apply_window_if_present(result, window)

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,7 @@ def _(
8484
column: typed_expr.TypedExpr,
8585
window: typing.Optional[window_spec.WindowSpec] = None,
8686
) -> sge.Expression:
87-
# Ranking functions do not support window framing clauses.
88-
return apply_window_if_present(
89-
sge.func("DENSE_RANK"), window, include_framing_clauses=False
90-
)
87+
return apply_window_if_present(sge.func("DENSE_RANK"), window)
9188

9289

9390
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
@@ -165,10 +162,7 @@ def _(
165162
column: typed_expr.TypedExpr,
166163
window: typing.Optional[window_spec.WindowSpec] = None,
167164
) -> sge.Expression:
168-
# Ranking functions do not support window framing clauses.
169-
return apply_window_if_present(
170-
sge.func("RANK"), window, include_framing_clauses=False
171-
)
165+
return apply_window_if_present(sge.func("RANK"), window)
172166

173167

174168
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
28-
include_framing_clauses: bool = True,
2928
) -> sge.Expression:
3029
if window is None:
3130
return value
@@ -65,7 +64,7 @@ def apply_window_if_present(
6564
if not window.bounds and not order:
6665
return sge.Window(this=value, partition_by=group_by)
6766

68-
if not window.bounds and not include_framing_clauses:
67+
if not window.bounds:
6968
return sge.Window(this=value, partition_by=group_by, order=order)
7069

7170
kind = (
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ROW_NUMBER() OVER () AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `row_number`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ROW_NUMBER() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `row_number`
13+
FROM `bfcte_1`
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`rowindex` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
COUNT(1) AS `bfcol_2`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_2` AS `size`
12+
FROM `bfcte_1`
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import typing
16+
17+
import pytest
18+
19+
from bigframes.core import agg_expressions as agg_exprs
20+
from bigframes.core import array_value, identifiers, nodes, ordering, window_spec
21+
from bigframes.operations import aggregations as agg_ops
22+
import bigframes.pandas as bpd
23+
24+
pytest.importorskip("pytest_snapshot")
25+
26+
27+
def _apply_nullary_agg_ops(
28+
obj: bpd.DataFrame,
29+
ops_list: typing.Sequence[agg_exprs.NullaryAggregation],
30+
new_names: typing.Sequence[str],
31+
) -> str:
32+
aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)]
33+
34+
agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=tuple(aggs))
35+
result = array_value.ArrayValue(agg_node)
36+
37+
sql = result.session._executor.to_sql(result, enable_cache=False)
38+
return sql
39+
40+
41+
def _apply_nullary_window_op(
42+
obj: bpd.DataFrame,
43+
op: agg_exprs.NullaryAggregation,
44+
window_spec: window_spec.WindowSpec,
45+
new_name: str,
46+
) -> str:
47+
win_node = nodes.WindowOpNode(
48+
obj._block.expr.node,
49+
expression=op,
50+
window_spec=window_spec,
51+
output_name=identifiers.ColumnId(new_name),
52+
)
53+
result = array_value.ArrayValue(win_node).select_columns([new_name])
54+
55+
sql = result.session._executor.to_sql(result, enable_cache=False)
56+
return sql
57+
58+
59+
def test_size(scalar_types_df: bpd.DataFrame, snapshot):
60+
bf_df = scalar_types_df
61+
agg_expr = agg_ops.SizeOp().as_expr()
62+
sql = _apply_nullary_agg_ops(bf_df, [agg_expr], ["size"])
63+
64+
snapshot.assert_match(sql, "out.sql")
65+
66+
67+
def test_row_number(scalar_types_df: bpd.DataFrame, snapshot):
68+
bf_df = scalar_types_df
69+
agg_expr = agg_exprs.NullaryAggregation(agg_ops.RowNumberOp())
70+
window = window_spec.WindowSpec()
71+
sql = _apply_nullary_window_op(bf_df, agg_expr, window, "row_number")
72+
73+
snapshot.assert_match(sql, "out.sql")
74+
75+
76+
def test_row_number_with_window(scalar_types_df: bpd.DataFrame, snapshot):
77+
col_name = "int64_col"
78+
bf_df = scalar_types_df[[col_name, "int64_too"]]
79+
agg_expr = agg_exprs.NullaryAggregation(agg_ops.RowNumberOp())
80+
81+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
82+
# window = window_spec.unbound(ordering=(ordering.ascending_over(col_name),ordering.ascending_over("int64_too")))
83+
sql = _apply_nullary_window_op(bf_df, agg_expr, window, "row_number")
84+
85+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)