Skip to content

Commit 4cab482

Browse files
committed
refactor: support agg_ops.LastOp, LastNonNullOp in the sqlglot compiler
1 parent bab9576 commit 4cab482

File tree

4 files changed

+85
-2
lines changed

4 files changed

+85
-2
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def _(
104104
column: typed_expr.TypedExpr,
105105
window: typing.Optional[window_spec.WindowSpec] = None,
106106
) -> sge.Expression:
107-
return apply_window_if_present(sge.func("DENSE_RANK"), window, include_framing_clauses=False)
107+
return apply_window_if_present(
108+
sge.func("DENSE_RANK"), window, include_framing_clauses=False
109+
)
108110

109111

110112
@UNARY_OP_REGISTRATION.register(agg_ops.FirstOp)
@@ -128,6 +130,27 @@ def _(
128130
)
129131

130132

133+
@UNARY_OP_REGISTRATION.register(agg_ops.LastOp)
134+
def _(
135+
op: agg_ops.LastOp,
136+
column: typed_expr.TypedExpr,
137+
window: typing.Optional[window_spec.WindowSpec] = None,
138+
) -> sge.Expression:
139+
# LAST_VALUE in BQ respects nulls by default.
140+
return apply_window_if_present(sge.LastValue(this=column.expr), window)
141+
142+
143+
@UNARY_OP_REGISTRATION.register(agg_ops.LastNonNullOp)
144+
def _(
145+
op: agg_ops.LastNonNullOp,
146+
column: typed_expr.TypedExpr,
147+
window: typing.Optional[window_spec.WindowSpec] = None,
148+
) -> sge.Expression:
149+
return apply_window_if_present(
150+
sge.IgnoreNulls(this=sge.LastValue(this=column.expr)), window
151+
)
152+
153+
131154
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
132155
def _(
133156
op: agg_ops.MaxOp,
@@ -203,7 +226,9 @@ def _(
203226
column: typed_expr.TypedExpr,
204227
window: typing.Optional[window_spec.WindowSpec] = None,
205228
) -> sge.Expression:
206-
return apply_window_if_present(sge.func("RANK"), window, include_framing_clauses=False)
229+
return apply_window_if_present(
230+
sge.func("RANK"), window, include_framing_clauses=False
231+
)
207232

208233

209234
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `bfcol_0` IS NULL
10+
THEN NULL
11+
ELSE LAST_VALUE(`bfcol_0`) OVER (
12+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
13+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
14+
)
15+
END AS `bfcol_1`
16+
FROM `bfcte_0`
17+
)
18+
SELECT
19+
`bfcol_1` AS `agg_int64`
20+
FROM `bfcte_1`
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LAST_VALUE(`bfcol_0` IGNORE NULLS) OVER (
9+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
10+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
11+
) AS `bfcol_1`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_1` AS `agg_int64`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,28 @@ def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot):
148148
snapshot.assert_match(sql, "out.sql")
149149

150150

151+
def test_last(scalar_types_df: bpd.DataFrame, snapshot):
152+
col_name = "int64_col"
153+
bf_df = scalar_types_df[[col_name]]
154+
agg_expr = agg_exprs.UnaryAggregation(agg_ops.LastOp(), expression.deref(col_name))
155+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
156+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
157+
158+
snapshot.assert_match(sql, "out.sql")
159+
160+
161+
def test_last_non_null(scalar_types_df: bpd.DataFrame, snapshot):
162+
col_name = "int64_col"
163+
bf_df = scalar_types_df[[col_name]]
164+
agg_expr = agg_exprs.UnaryAggregation(
165+
agg_ops.LastNonNullOp(), expression.deref(col_name)
166+
)
167+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
168+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
169+
170+
snapshot.assert_match(sql, "out.sql")
171+
172+
151173
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
152174
col_name = "int64_col"
153175
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)