Skip to content

Commit 17b8c86

Browse files
committed
refactor: fix test_date_series_diff_agg on agg_ops.DiffOp
1 parent a634e97 commit 17b8c86

File tree

4 files changed

+37
-0
lines changed

4 files changed

+37
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from bigframes.core import window_spec
2424
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
2525
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
26+
from bigframes.core.compile.sqlglot.expressions import constants
2627
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2728
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2829
from bigframes.operations import aggregations as agg_ops
@@ -326,6 +327,15 @@ def _(
326327
unit=sge.Identifier(this="MICROSECOND"),
327328
)
328329

330+
if column.dtype == dtypes.DATE_DTYPE:
331+
date_diff = sge.DateDiff(
332+
this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY")
333+
)
334+
return sge.Cast(
335+
this=sge.Floor(this=date_diff * constants._DAY_TO_MICROSECONDS),
336+
to="INT64",
337+
)
338+
329339
raise TypeError(f"Cannot perform diff on type {column.dtype}")
330340

331341

bigframes/core/compile/sqlglot/expressions/constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
2121
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
2222
_NEG_INF = sge.Cast(this=sge.convert("-Infinity"), to="FLOAT64")
23+
_DAY_TO_MICROSECONDS = sge.convert(86400000000)
2324

2425
# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result
2526
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`date_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CAST(FLOOR(
9+
DATE_DIFF(`date_col`, LAG(`date_col`, 1) OVER (ORDER BY `date_col` ASC NULLS LAST), DAY) * 86400000000
10+
) AS INT64) AS `bfcol_1`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_1` AS `diff_date`
15+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,17 @@ def test_diff_w_datetime(scalar_types_df: bpd.DataFrame, snapshot):
247247
snapshot.assert_match(sql, "out.sql")
248248

249249

250+
def test_diff_w_date(scalar_types_df: bpd.DataFrame, snapshot):
251+
col_name = "date_col"
252+
bf_df_date = scalar_types_df[[col_name]]
253+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
254+
op = agg_exprs.UnaryAggregation(
255+
agg_ops.DiffOp(periods=1), expression.deref(col_name)
256+
)
257+
sql = _apply_unary_window_op(bf_df_date, op, window, "diff_date")
258+
snapshot.assert_match(sql, "out.sql")
259+
260+
250261
def test_diff_w_timestamp(scalar_types_df: bpd.DataFrame, snapshot):
251262
col_name = "timestamp_col"
252263
bf_df_timestamp = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)