diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 8ed5510ec2..598a89e4eb 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -47,6 +47,18 @@ def _( return apply_window_if_present(sge.func("COUNT", column.expr), window) +@UNARY_OP_REGISTRATION.register(agg_ops.DenseRankOp) +def _( + op: agg_ops.DenseRankOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # Ranking functions do not support window framing clauses. + return apply_window_if_present( + sge.func("DENSE_RANK"), window, include_framing_clauses=False + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) def _( op: agg_ops.MaxOp, @@ -106,6 +118,18 @@ def _( return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window) +@UNARY_OP_REGISTRATION.register(agg_ops.RankOp) +def _( + op: agg_ops.RankOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # Ranking functions do not support window framing clauses. + return apply_window_if_present( + sge.func("RANK"), window, include_framing_clauses=False + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.SumOp) def _( op: agg_ops.SumOp, diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index 4d7a3f7406..1bfa72b878 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -25,6 +25,7 @@ def apply_window_if_present( value: sge.Expression, window: typing.Optional[window_spec.WindowSpec] = None, + include_framing_clauses: bool = True, ) -> sge.Expression: if window is None: return value @@ -64,6 +65,9 @@ def apply_window_if_present( if not window.bounds and not order: return sge.Window(this=value, partition_by=group_by) + if not window.bounds and not include_framing_clauses: + return sge.Window(this=value, partition_by=group_by, order=order) + kind = ( "ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE" ) diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index 7b6998b90e..f6e8600d42 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -519,6 +519,8 @@ def implicitly_inherits_order(self): @dataclasses.dataclass(frozen=True) class DenseRankOp(UnaryWindowOp): + name: ClassVar[str] = "dense_rank" + @property def skips_nulls(self): return False diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_dense_rank/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_dense_rank/out.sql new file mode 100644 index 0000000000..38b6ed9f5c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_dense_rank/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + DENSE_RANK() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_rank/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_rank/out.sql new file mode 100644 index 0000000000..5de2330ef6 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_rank/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + RANK() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index a5ffda0e65..bf2523930f 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -17,7 +17,14 @@ import pytest from bigframes.core import agg_expressions as agg_exprs -from bigframes.core import array_value, identifiers, nodes +from bigframes.core import ( + array_value, + expression, + identifiers, + nodes, + ordering, + window_spec, +) from bigframes.operations import aggregations as agg_ops import bigframes.pandas as bpd @@ -38,6 +45,24 @@ def _apply_unary_agg_ops( return sql +def _apply_unary_window_op( + obj: bpd.DataFrame, + op: agg_exprs.UnaryAggregation, + window_spec: window_spec.WindowSpec, + new_name: str, +) -> str: + win_node = nodes.WindowOpNode( + obj._block.expr.node, + expression=op, + window_spec=window_spec, + output_name=identifiers.ColumnId(new_name), + ) + result = array_value.ArrayValue(win_node).select_columns([new_name]) + + sql = result.session._executor.to_sql(result, enable_cache=False) + return sql + + def test_count(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] @@ -47,6 +72,18 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation( + agg_ops.DenseRankOp(), expression.deref(col_name) + ) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + def test_max(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] @@ -104,6 +141,17 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_rank(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation(agg_ops.RankOp(), expression.deref(col_name)) + + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + def test_sum(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col"]] agg_ops_map = {