Skip to content

Commit 3e9156b

Browse files
committed
refactor: add agg_ops.DiffOp
1 parent 308fa1c commit 3e9156b

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,23 @@ def _(
107107
return apply_window_if_present(sge.func("DENSE_RANK"), window)
108108

109109

110+
@UNARY_OP_REGISTRATION.register(agg_ops.DiffOp)
111+
def _(
112+
op: agg_ops.DiffOp,
113+
column: typed_expr.TypedExpr,
114+
window: typing.Optional[window_spec.WindowSpec] = None,
115+
) -> sge.Expression:
116+
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
117+
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
118+
if column.dtype in (dtypes.BOOL_DTYPE, dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE):
119+
if column.dtype == dtypes.BOOL_DTYPE:
120+
return sge.NEQ(this=column.expr, expression=shifted)
121+
else:
122+
return sge.Sub(this=column.expr, expression=shifted)
123+
else:
124+
raise TypeError(f"Cannot perform diff on type {column.dtype}")
125+
126+
110127
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
111128
def _(
112129
op: agg_ops.MaxOp,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0` <> LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `diff_bool`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0` - LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `diff_int`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,28 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
126126
snapshot.assert_match(sql, "out.sql")
127127

128128

129+
def test_diff(scalar_types_df: bpd.DataFrame, snapshot):
130+
# Test integer
131+
int_col = "int64_col"
132+
bf_df_int = scalar_types_df[[int_col]]
133+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(int_col),))
134+
int_op = agg_exprs.UnaryAggregation(
135+
agg_ops.DiffOp(periods=1), expression.deref(int_col)
136+
)
137+
int_sql = _apply_unary_window_op(bf_df_int, int_op, window, "diff_int")
138+
snapshot.assert_match(int_sql, "diff_int.sql")
139+
140+
# Test boolean
141+
bool_col = "bool_col"
142+
bf_df_bool = scalar_types_df[[bool_col]]
143+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(bool_col),))
144+
bool_op = agg_exprs.UnaryAggregation(
145+
agg_ops.DiffOp(periods=1), expression.deref(bool_col)
146+
)
147+
bool_sql = _apply_unary_window_op(bf_df_bool, bool_op, window, "diff_bool")
148+
snapshot.assert_match(bool_sql, "diff_bool.sql")
149+
150+
129151
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
130152
col_name = "int64_col"
131153
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)