Skip to content

Commit 3c22589

Browse files
committed
refactor: add agg_ops.TimeSeriesDiffOp
1 parent e0aa9cc commit 3c22589

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,20 @@ def _(
293293
# Will be null if all inputs are null. Pandas defaults to zero sum though.
294294
zero = pd.to_timedelta(0) if column.dtype == dtypes.TIMEDELTA_DTYPE else 0
295295
return sge.func("IFNULL", expr, ir._literal(zero, column.dtype))
296+
297+
298+
@UNARY_OP_REGISTRATION.register(agg_ops.TimeSeriesDiffOp)
299+
def _(
300+
op: agg_ops.TimeSeriesDiffOp,
301+
column: typed_expr.TypedExpr,
302+
window: typing.Optional[window_spec.WindowSpec] = None,
303+
) -> sge.Expression:
304+
if column.dtype != dtypes.TIMESTAMP_DTYPE:
305+
raise TypeError(f"Cannot perform time series diff on type {column.dtype}")
306+
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
307+
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
308+
return sge.TimestampDiff(
309+
this=column.expr,
310+
expression=shifted,
311+
unit=sge.Identifier(this="MICROSECOND"),
312+
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`timestamp_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
TIMESTAMP_DIFF(
9+
`bfcol_0`,
10+
LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST),
11+
MICROSECOND
12+
) AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `diff_time`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,14 @@ def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
331331
)
332332

333333
snapshot.assert_match(sql, "out.sql")
334+
335+
336+
def test_time_series_diff(scalar_types_df: bpd.DataFrame, snapshot):
337+
col_name = "timestamp_col"
338+
bf_df = scalar_types_df[[col_name]]
339+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
340+
op = agg_exprs.UnaryAggregation(
341+
agg_ops.TimeSeriesDiffOp(periods=1), expression.deref(col_name)
342+
)
343+
sql = _apply_unary_window_op(bf_df, op, window, "diff_time")
344+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)