Skip to content

Commit 8a7b85d

Browse files
committed
refactor: add agg_ops.DateSeriesDiffOp
1 parent 3c22589 commit 8a7b85d

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,27 @@ def _(
9898
return apply_window_if_present(sge.func("COUNT", column.expr), window)
9999

100100

101+
@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
102+
def _(
103+
op: agg_ops.DateSeriesDiffOp,
104+
column: typed_expr.TypedExpr,
105+
window: typing.Optional[window_spec.WindowSpec] = None,
106+
) -> sge.Expression:
107+
if column.dtype != dtypes.DATE_DTYPE:
108+
raise TypeError(f"Cannot perform date series diff on type {column.dtype}")
109+
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
110+
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
111+
# Conversion factor from days to microseconds
112+
conversion_factor = 24 * 60 * 60 * 1_000_000
113+
return sge.Cast(
114+
this=sge.DateDiff(
115+
this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY")
116+
)
117+
* sge.convert(conversion_factor),
118+
to="INT64",
119+
)
120+
121+
101122
@UNARY_OP_REGISTRATION.register(agg_ops.DenseRankOp)
102123
def _(
103124
op: agg_ops.DenseRankOp,
@@ -306,7 +327,7 @@ def _(
306327
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
307328
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
308329
return sge.TimestampDiff(
309-
this=column.expr,
310-
expression=shifted,
311-
unit=sge.Identifier(this="MICROSECOND"),
312-
)
330+
this=column.expr,
331+
expression=shifted,
332+
unit=sge.Identifier(this="MICROSECOND"),
333+
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`date_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CAST(DATE_DIFF(
9+
`bfcol_0`,
10+
LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST),
11+
day
12+
) * 86400000000 AS INT64) AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `diff_date`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
127127
snapshot.assert_match(sql, "out.sql")
128128

129129

130+
def test_date_series_diff(scalar_types_df: bpd.DataFrame, snapshot):
131+
col_name = "date_col"
132+
bf_df = scalar_types_df[[col_name]]
133+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
134+
op = agg_exprs.UnaryAggregation(
135+
agg_ops.DateSeriesDiffOp(periods=1), expression.deref(col_name)
136+
)
137+
sql = _apply_unary_window_op(bf_df, op, window, "diff_date")
138+
snapshot.assert_match(sql, "out.sql")
139+
140+
130141
def test_diff(scalar_types_df: bpd.DataFrame, snapshot):
131142
# Test integer
132143
int_col = "int64_col"

0 commit comments

Comments
 (0)