From 5ed694a62c4e86b61984b3651d45f5d465141420 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 25 Sep 2025 21:42:56 +0000 Subject: [PATCH] refactor: support agg_ops.RowNumberOp for sqlglot compiler --- .../compile/sqlglot/aggregate_compiler.py | 2 +- .../sqlglot/aggregations/nullary_compiler.py | 12 +++ .../sqlglot/aggregations/unary_compiler.py | 10 +-- .../compile/sqlglot/aggregations/windows.py | 3 +- .../test_row_number/out.sql | 13 +++ .../test_row_number_with_window/out.sql | 13 +++ .../test_nullary_compiler/test_size/out.sql | 12 +++ .../aggregations/test_nullary_compiler.py | 85 +++++++++++++++++++ 8 files changed, 139 insertions(+), 11 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number_with_window/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index 08bca535a8..b86ae196f6 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -63,7 +63,7 @@ def compile_analytic( window: window_spec.WindowSpec, ) -> sge.Expression: if isinstance(aggregate, agg_expressions.NullaryAggregation): - return nullary_compiler.compile(aggregate.op) + return nullary_compiler.compile(aggregate.op, window) if isinstance(aggregate, agg_expressions.UnaryAggregation): column = typed_expr.TypedExpr( scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg), diff --git a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py index 99e3562b42..c6418591ba 100644 --- a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py @@ -39,3 +39,15 @@ def _( window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window) + + +@NULLARY_OP_REGISTRATION.register(agg_ops.RowNumberOp) +def _( + op: agg_ops.RowNumberOp, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + result: sge.Expression = sge.func("ROW_NUMBER") + if window is None: + # ROW_NUMBER always needs an OVER clause. + return sge.Window(this=result) + return apply_window_if_present(result, window) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 11d53cdd4c..e8baa15bce 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -84,10 +84,7 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - # Ranking functions do not support window framing clauses. - return apply_window_if_present( - sge.func("DENSE_RANK"), window, include_framing_clauses=False - ) + return apply_window_if_present(sge.func("DENSE_RANK"), window) @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) @@ -165,10 +162,7 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - # Ranking functions do not support window framing clauses. - return apply_window_if_present( - sge.func("RANK"), window, include_framing_clauses=False - ) + return apply_window_if_present(sge.func("RANK"), window) @UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp) diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index 1bfa72b878..5e38bf120e 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -25,7 +25,6 @@ def apply_window_if_present( value: sge.Expression, window: typing.Optional[window_spec.WindowSpec] = None, - include_framing_clauses: bool = True, ) -> sge.Expression: if window is None: return value @@ -65,7 +64,7 @@ def apply_window_if_present( if not window.bounds and not order: return sge.Window(this=value, partition_by=group_by) - if not window.bounds and not include_framing_clauses: + if not window.bounds: return sge.Window(this=value, partition_by=group_by, order=order) kind = ( diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql new file mode 100644 index 0000000000..d20a635e3d --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ROW_NUMBER() OVER () AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `row_number` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number_with_window/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number_with_window/out.sql new file mode 100644 index 0000000000..2cee8a228f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number_with_window/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `row_number` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql new file mode 100644 index 0000000000..19ae8aa3fd --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + COUNT(1) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `size` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py new file mode 100644 index 0000000000..2348b95496 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py @@ -0,0 +1,85 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import pytest + +from bigframes.core import agg_expressions as agg_exprs +from bigframes.core import array_value, identifiers, nodes, ordering, window_spec +from bigframes.operations import aggregations as agg_ops +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def _apply_nullary_agg_ops( + obj: bpd.DataFrame, + ops_list: typing.Sequence[agg_exprs.NullaryAggregation], + new_names: typing.Sequence[str], +) -> str: + aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)] + + agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=tuple(aggs)) + result = array_value.ArrayValue(agg_node) + + sql = result.session._executor.to_sql(result, enable_cache=False) + return sql + + +def _apply_nullary_window_op( + obj: bpd.DataFrame, + op: agg_exprs.NullaryAggregation, + window_spec: window_spec.WindowSpec, + new_name: str, +) -> str: + win_node = nodes.WindowOpNode( + obj._block.expr.node, + expression=op, + window_spec=window_spec, + output_name=identifiers.ColumnId(new_name), + ) + result = array_value.ArrayValue(win_node).select_columns([new_name]) + + sql = result.session._executor.to_sql(result, enable_cache=False) + return sql + + +def test_size(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + agg_expr = agg_ops.SizeOp().as_expr() + sql = _apply_nullary_agg_ops(bf_df, [agg_expr], ["size"]) + + snapshot.assert_match(sql, "out.sql") + + +def test_row_number(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + agg_expr = agg_exprs.NullaryAggregation(agg_ops.RowNumberOp()) + window = window_spec.WindowSpec() + sql = _apply_nullary_window_op(bf_df, agg_expr, window, "row_number") + + snapshot.assert_match(sql, "out.sql") + + +def test_row_number_with_window(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name, "int64_too"]] + agg_expr = agg_exprs.NullaryAggregation(agg_ops.RowNumberOp()) + + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + # window = window_spec.unbound(ordering=(ordering.ascending_over(col_name),ordering.ascending_over("int64_too"))) + sql = _apply_nullary_window_op(bf_df, agg_expr, window, "row_number") + + snapshot.assert_match(sql, "out.sql")