Skip to content

Commit 38bcf9a

Browse files
committed
refactor: support agg_ops.ShiftOp for the sqlglot compiler
1 parent 5e1e809 commit 38bcf9a

File tree

5 files changed

+83
-0
lines changed

5 files changed

+83
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,23 @@ def _(
240240
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
241241

242242

243+
@UNARY_OP_REGISTRATION.register(agg_ops.ShiftOp)
244+
def _(
245+
op: agg_ops.ShiftOp,
246+
column: typed_expr.TypedExpr,
247+
window: typing.Optional[window_spec.WindowSpec] = None,
248+
) -> sge.Expression:
249+
if op.periods == 0: # No-op
250+
return column.expr
251+
if op.periods > 0:
252+
return apply_window_if_present(
253+
sge.func("LAG", column.expr, sge.convert(op.periods)), window
254+
)
255+
return apply_window_if_present(
256+
sge.func("LEAD", column.expr, sge.convert(-op.periods)), window
257+
)
258+
259+
243260
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
244261
def _(
245262
op: agg_ops.SumOp,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `lag`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LEAD(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `lead`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0` AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `noop`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,33 @@ def test_rank(scalar_types_df: bpd.DataFrame, snapshot):
271271
snapshot.assert_match(sql, "out.sql")
272272

273273

274+
def test_shift(scalar_types_df: bpd.DataFrame, snapshot):
275+
col_name = "int64_col"
276+
bf_df = scalar_types_df[[col_name]]
277+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
278+
279+
# Test lag
280+
lag_op = agg_exprs.UnaryAggregation(
281+
agg_ops.ShiftOp(periods=1), expression.deref(col_name)
282+
)
283+
lag_sql = _apply_unary_window_op(bf_df, lag_op, window, "lag")
284+
snapshot.assert_match(lag_sql, "lag.sql")
285+
286+
# Test lead
287+
lead_op = agg_exprs.UnaryAggregation(
288+
agg_ops.ShiftOp(periods=-1), expression.deref(col_name)
289+
)
290+
lead_sql = _apply_unary_window_op(bf_df, lead_op, window, "lead")
291+
snapshot.assert_match(lead_sql, "lead.sql")
292+
293+
# Test no-op
294+
noop_op = agg_exprs.UnaryAggregation(
295+
agg_ops.ShiftOp(periods=0), expression.deref(col_name)
296+
)
297+
noop_sql = _apply_unary_window_op(bf_df, noop_op, window, "noop")
298+
snapshot.assert_match(noop_sql, "noop.sql")
299+
300+
274301
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
275302
bf_df = scalar_types_df[["int64_col", "bool_col"]]
276303
agg_ops_map = {

0 commit comments

Comments
 (0)