Skip to content

Commit 876d2a4

Browse files
committed
refactor: support agg_ops.DenseRankOp and RankOp for sqlglot compiler
1 parent caa824a commit 876d2a4

File tree

4 files changed

+68
-1
lines changed

4 files changed

+68
-1
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ def _(
4747
return apply_window_if_present(sge.func("COUNT", column.expr), window)
4848

4949

50+
@UNARY_OP_REGISTRATION.register(agg_ops.DenseRankOp)
51+
def _(
52+
op: agg_ops.DenseRankOp,
53+
column: typed_expr.TypedExpr,
54+
window: typing.Optional[window_spec.WindowSpec] = None,
55+
) -> sge.Expression:
56+
return apply_window_if_present(sge.func("DENSE_RANK"), window)
57+
58+
5059
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
5160
def _(
5261
op: agg_ops.MaxOp,

bigframes/operations/aggregations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,8 @@ def implicitly_inherits_order(self):
519519

520520
@dataclasses.dataclass(frozen=True)
521521
class DenseRankOp(UnaryWindowOp):
522+
name: ClassVar[str] = "dense_rank"
523+
522524
@property
523525
def skips_nulls(self):
524526
return False
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
DENSE_RANK() OVER (
10+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
11+
RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
12+
) AS `bfcol_4`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `bfuid_col_1`,
17+
`bfcol_0` AS `int64_col`,
18+
`bfcol_4` AS `agg_int64`
19+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
import pytest
1818

1919
from bigframes.core import agg_expressions as agg_exprs
20-
from bigframes.core import array_value, identifiers, nodes
20+
from bigframes.core import (
21+
array_value,
22+
expression,
23+
identifiers,
24+
nodes,
25+
ordering,
26+
window_spec,
27+
)
2128
from bigframes.operations import aggregations as agg_ops
2229
import bigframes.pandas as bpd
2330

@@ -38,6 +45,24 @@ def _apply_unary_agg_ops(
3845
return sql
3946

4047

48+
def _apply_unary_window_op(
49+
obj: bpd.DataFrame,
50+
op: agg_exprs.UnaryAggregation,
51+
window_spec: window_spec.WindowSpec,
52+
new_name: str,
53+
) -> str:
54+
win_node = nodes.WindowOpNode(
55+
obj._block.expr.node,
56+
expression=op,
57+
window_spec=window_spec,
58+
output_name=identifiers.ColumnId(new_name),
59+
)
60+
result = array_value.ArrayValue(win_node)
61+
62+
sql = result.session._executor.to_sql(result, enable_cache=False)
63+
return sql
64+
65+
4166
def test_count(scalar_types_df: bpd.DataFrame, snapshot):
4267
col_name = "int64_col"
4368
bf_df = scalar_types_df[[col_name]]
@@ -47,6 +72,18 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot):
4772
snapshot.assert_match(sql, "out.sql")
4873

4974

75+
def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
76+
col_name = "int64_col"
77+
bf_df = scalar_types_df[[col_name]]
78+
agg_expr = agg_exprs.UnaryAggregation(
79+
agg_ops.DenseRankOp(), expression.deref("int64_col")
80+
)
81+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
82+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
83+
84+
snapshot.assert_match(sql, "out.sql")
85+
86+
5087
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
5188
col_name = "int64_col"
5289
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)