Skip to content

Commit af06fdd

Browse files
committed
chore: Migrate IntegerLabelToDatetimeOp operator to SQLGlot
1 parent 719b278 commit af06fdd

File tree

3 files changed

+339
-0
lines changed

3 files changed

+339
-0
lines changed

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2424

2525
# Registration decorators taken from the shared scalar-op compiler; the
# compiler functions in this module are decorated with them, one per operator.
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op
2627

2728

2829
@register_unary_op(ops.FloorDtOp, pass_op=True)
@@ -51,6 +52,28 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression:
5152
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq))
5253

5354

55+
def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resample origin as microseconds since epoch.

    Args:
        y: Typed expression holding the reference datetime value.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        A literal ``0`` for ``"epoch"``; otherwise a ``UNIX_MICROS(...)`` call
        on ``y`` cast to a timestamp (via ``DATE`` first for ``"start_day"``,
        which drops the time-of-day part).

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)

    if origin == "start_day":
        # Round-trip through DATE so the origin is midnight of y's day.
        as_date = sge.Cast(
            this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE)
        )
        day_start = sge.Cast(
            this=as_date,
            to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
        )
        return sge.func("UNIX_MICROS", day_start)

    if origin == "start":
        as_timestamp = sge.Cast(
            this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
        )
        return sge.func("UNIX_MICROS", as_timestamp)

    raise ValueError(f"Origin {origin} not supported")
75+
76+
5477
@register_unary_op(ops.hour_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``hour_op`` into ``EXTRACT(HOUR FROM <expr>)``."""
    hour_part = sge.Identifier(this="HOUR")
    return sge.Extract(this=hour_part, expression=expr.expr)
@@ -170,3 +193,243 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression:
170193
@register_unary_op(ops.year_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``year_op`` into ``EXTRACT(YEAR FROM <expr>)``."""
    year_part = sge.Identifier(this="YEAR")
    return sge.Extract(this=year_part, expression=expr.expr)
196+
197+
198+
def _dtype_to_sql_string(dtype: dtypes.Dtype) -> str:
    """Return the SQL type name used when casting to ``dtype``.

    Args:
        dtype: One of the timestamp, datetime, date or time dtypes.

    Returns:
        ``"TIMESTAMP"``, ``"DATETIME"``, ``"DATE"`` or ``"TIME"``.

    Raises:
        ValueError: If ``dtype`` is none of the four supported dtypes.
    """
    # Checked in order with ``==`` (not a dict lookup) so the comparison
    # semantics of the dtype objects are the ones actually used.
    known_types = (
        (dtypes.TIMESTAMP_DTYPE, "TIMESTAMP"),
        (dtypes.DATETIME_DTYPE, "DATETIME"),
        (dtypes.DATE_DTYPE, "DATE"),
        (dtypes.TIME_DTYPE, "TIME"),
    )
    for candidate, sql_name in known_types:
        if dtype == candidate:
            return sql_name
    # Should not be reached in this operator
    raise ValueError(f"Unsupported dtype for datetime conversion: {dtype}")
209+
210+
211+
@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True)
def integer_label_to_datetime_op(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Compile ``IntegerLabelToDatetimeOp`` by dispatching on frequency kind.

    Args:
        x: Integer bin label expression.
        y: Reference datetime expression; its dtype is the result dtype.
        op: The operator, carrying ``freq`` and ``origin``.

    Returns:
        The SQLGlot expression that turns the integer label into a datetime.

    Raises:
        ValueError: If the origin, dtype or frequency rule is unsupported.
    """
    # A frequency is fixed when ``op.freq.nanos`` is defined; for non-fixed
    # offsets reading it raises ValueError. Only that probe is guarded, so a
    # ValueError raised while compiling the fixed path (unsupported origin or
    # dtype) propagates with its own message instead of being masked by a
    # fallback to the non-fixed compiler.
    try:
        _ = op.freq.nanos
    except ValueError:
        return _integer_label_to_datetime_op_non_fixed_frequency(x, y, op)
    return _integer_label_to_datetime_op_fixed_frequency(x, y, op)
220+
221+
222+
def _integer_label_to_datetime_op_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Compile the label conversion for fixed frequencies (microseconds to days).

    The result is ``x * step + first`` microseconds, converted with
    ``TIMESTAMP_MICROS`` and cast to the dtype of ``y``.

    Args:
        x: Integer bin label expression.
        y: Reference datetime expression; supplies the origin and result dtype.
        op: The operator; ``op.freq.nanos`` gives the step, ``op.origin`` the
            origin mode.

    Raises:
        ValueError: If ``op.freq`` has no ``nanos`` (non-fixed frequency), or
            the origin or dtype is unsupported.
    """
    step_us = int(op.freq.nanos / 1000)
    first = _calculate_resample_first(y, op.origin)  # type: ignore

    bignumeric = "BIGNUMERIC"
    scaled_label = sge.Mul(
        this=sge.Cast(this=x.expr, to=sge.DataType.build(bignumeric)),
        expression=sge.convert(step_us),
    )
    # The multiply-add runs in BIGNUMERIC and is narrowed to INT64 afterwards
    # (presumably to avoid overflow in the intermediate product).
    total_us = sge.Add(
        this=scaled_label,
        expression=sge.Cast(this=first, to=sge.DataType.build(bignumeric)),
    )
    as_timestamp = sge.func(
        "TIMESTAMP_MICROS",
        sge.Cast(this=total_us, to=sge.DataType.build("INT64")),
    )
    return sge.Cast(
        this=as_timestamp,
        to=_dtype_to_sql_string(y.dtype),  # type: ignore
    )
250+
251+
252+
def _label_period_index(
    x: TypedExpr, n: int, origin_index: sge.Expression
) -> sge.Expression:
    """Return ``x * n + origin_index``: the absolute period index of label ``x``."""
    return sge.Add(
        this=sge.Mul(this=x.expr, expression=sge.convert(n)),
        expression=origin_index,
    )


def _label_floor_div(
    value: sge.Expression, divisor: sge.Expression
) -> sge.Expression:
    """Return ``CAST(FLOOR(IEEE_DIVIDE(value, divisor)) AS INT64)``."""
    return sge.Cast(
        this=sge.Floor(this=sge.func("IEEE_DIVIDE", value, divisor)),
        to=sge.DataType.build("INT64"),
    )


def _label_month_start(
    year: sge.Expression, month: sge.Expression
) -> sge.Expression:
    """Return ``DATETIME(year, month, 1, 0, 0, 0)``."""
    return sge.Anonymous(
        this="DATETIME",
        expressions=[
            year,
            month,
            sge.convert(1),
            sge.convert(0),
            sge.convert(0),
            sge.convert(0),
        ],
    )


def _label_start_of_next_month(
    year: sge.Expression, month: sge.Expression
) -> sge.Expression:
    """Return midnight on the first day of the month after ``(year, month)``.

    December rolls over to January of the following year.
    """
    next_year = sge.Case(
        ifs=[
            sge.If(
                this=sge.EQ(this=month, expression=sge.convert(12)),
                true=sge.Add(this=year, expression=sge.convert(1)),
            )
        ],
        default=year,
    )
    next_month = sge.Case(
        ifs=[
            sge.If(
                this=sge.EQ(this=month, expression=sge.convert(12)),
                true=sge.convert(1),
            )
        ],
        default=sge.Add(this=month, expression=sge.convert(1)),
    )
    return _label_month_start(next_year, next_month)


def _label_previous_day(moment: sge.Expression) -> sge.Expression:
    """Return ``moment - INTERVAL 1 DAY``."""
    return sge.Sub(
        this=moment, expression=sge.Interval(this=sge.convert(1), unit="DAY")
    )


def _label_week_end(x: TypedExpr, y: TypedExpr, n: int) -> sge.Expression:
    """Build the ``W-SUN`` label: ``n``-week steps from the Sunday of y's week."""
    step_us = n * 7 * 24 * 60 * 60 * 1000000
    # Truncate to the Monday of y's week, then add six days to reach Sunday.
    origin_week_end = sge.Add(
        this=sge.TimestampTrunc(
            this=sge.Cast(
                this=y.expr,
                to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
            ),
            unit=sge.Var(this="WEEK(MONDAY)"),
        ),
        expression=sge.Interval(
            this=sge.convert(6), unit=sge.Identifier(this="DAY")
        ),
    )
    first = sge.func("UNIX_MICROS", origin_week_end)
    # The multiply-add runs in BIGNUMERIC and is narrowed to INT64 afterwards
    # (presumably to avoid overflow in the intermediate product).
    total_us = sge.Add(
        this=sge.Mul(
            this=sge.Cast(this=x.expr, to=sge.DataType.build("BIGNUMERIC")),
            expression=sge.convert(step_us),
        ),
        expression=sge.Cast(this=first, to=sge.DataType.build("BIGNUMERIC")),
    )
    return sge.Cast(
        this=sge.func(
            "TIMESTAMP_MICROS",
            sge.Cast(this=total_us, to=sge.DataType.build("INT64")),
        ),
        to=_dtype_to_sql_string(y.dtype),  # type: ignore
    )


def _label_month_end(x: TypedExpr, y: TypedExpr, n: int) -> sge.Expression:
    """Build the ``ME`` label: the last day of the month ``x * n`` months after y's."""
    months_per_year = sge.convert(12)
    # Zero-based month count of y: YEAR * 12 + MONTH - 1.
    origin_index = sge.Sub(
        this=sge.Add(
            this=sge.Mul(
                this=sge.Extract(this="YEAR", expression=y.expr),
                expression=months_per_year,
            ),
            expression=sge.Extract(this="MONTH", expression=y.expr),
        ),
        expression=sge.convert(1),
    )
    index = _label_period_index(x, n, origin_index)
    year = _label_floor_div(index, months_per_year)
    month = sge.Add(
        this=sge.Mod(this=index, expression=months_per_year),
        expression=sge.convert(1),
    )
    # Month end = first day of the following month minus one day.
    return _label_previous_day(
        sge.func("TIMESTAMP", _label_start_of_next_month(year, month))
    )


def _label_quarter_end(x: TypedExpr, y: TypedExpr, n: int) -> sge.Expression:
    """Build the ``QE-DEC`` label: the last day of the quarter ``x * n`` quarters after y's."""
    quarters_per_year = sge.convert(4)
    # Zero-based quarter count of y: YEAR * 4 + QUARTER - 1.
    origin_index = sge.Sub(
        this=sge.Add(
            this=sge.Mul(
                this=sge.Extract(this="YEAR", expression=y.expr),
                expression=quarters_per_year,
            ),
            expression=sge.Extract(this="QUARTER", expression=y.expr),
        ),
        expression=sge.convert(1),
    )
    index = _label_period_index(x, n, origin_index)
    year = _label_floor_div(index, quarters_per_year)
    # Last month of the quarter: (quarter_in_year + 1) * 3.
    month = sge.Mul(
        this=sge.Paren(
            this=sge.Add(
                this=sge.Mod(this=index, expression=quarters_per_year),
                expression=sge.convert(1),
            )
        ),
        expression=sge.convert(3),
    )
    # Unlike the monthly/yearly builders there is no TIMESTAMP(...) wrapper
    # here; this is kept as-is so the generated SQL stays unchanged.
    return _label_previous_day(_label_start_of_next_month(year, month))


def _label_year_end(x: TypedExpr, y: TypedExpr, n: int) -> sge.Expression:
    """Build the ``YE-DEC`` label: December 31st of the year ``x * n`` years after y's."""
    index = _label_period_index(
        x, n, sge.Extract(this="YEAR", expression=y.expr)
    )
    next_year = sge.Add(this=index, expression=sge.convert(1))
    # Year end = January 1st of the following year minus one day.
    return _label_previous_day(
        sge.func("TIMESTAMP", _label_month_start(next_year, sge.convert(1)))
    )


# Maps each supported rule code of ``op.freq`` to the builder of its label.
_NON_FIXED_LABEL_BUILDERS = {
    "W-SUN": _label_week_end,  # Weekly
    "ME": _label_month_end,  # Monthly
    "QE-DEC": _label_quarter_end,  # Quarterly
    "YE-DEC": _label_year_end,  # Yearly
}


def _integer_label_to_datetime_op_non_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Compile the label conversion for non-fixed frequencies (weeks to years).

    Args:
        x: Integer bin label expression.
        y: Reference datetime expression; supplies the origin and result dtype.
        op: The operator; ``op.freq.rule_code`` selects the calendar rule and
            ``op.freq.n`` the number of periods per step.

    Returns:
        The label expression cast to the SQL type matching ``y.dtype``.

    Raises:
        ValueError: If the rule code is not one of ``W-SUN``, ``ME``,
            ``QE-DEC`` or ``YE-DEC``, or the dtype of ``y`` is unsupported.
    """
    rule_code = op.freq.rule_code
    build_label = _NON_FIXED_LABEL_BUILDERS.get(rule_code)
    if build_label is None:
        raise ValueError(
            f"Unsupported non-fixed frequency rule code: {rule_code!r}"
        )
    x_label = build_label(x, y, op.freq.n)
    # The weekly builder already returns a value of the target type; the cast
    # is then redundant but retained so the generated SQL is unchanged.
    return sge.Cast(this=x_label, to=_dtype_to_sql_string(y.dtype))  # type: ignore
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`rowindex`,
4+
`timestamp_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
CAST(TIMESTAMP_MICROS(
10+
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
11+
) AS TIMESTAMP) AS `bfcol_2`,
12+
CAST(DATETIME(
13+
CASE
14+
WHEN (
15+
MOD(
16+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
17+
4
18+
) + 1
19+
) * 3 = 12
20+
THEN CAST(FLOOR(
21+
IEEE_DIVIDE(
22+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
23+
4
24+
)
25+
) AS INT64) + 1
26+
ELSE CAST(FLOOR(
27+
IEEE_DIVIDE(
28+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
29+
4
30+
)
31+
) AS INT64)
32+
END,
33+
CASE
34+
WHEN (
35+
MOD(
36+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
37+
4
38+
) + 1
39+
) * 3 = 12
40+
THEN 1
41+
ELSE (
42+
MOD(
43+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
44+
4
45+
) + 1
46+
) * 3 + 1
47+
END,
48+
1,
49+
0,
50+
0,
51+
0
52+
) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_3`
53+
FROM `bfcte_0`
54+
)
55+
SELECT
56+
`bfcol_2` AS `fixed_freq`,
57+
`bfcol_3` AS `non_fixed_freq`
58+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,21 @@ def test_sub_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
277277
bf_df["timedelta_sub_timedelta"] = bf_df["duration_col"] - bf_df["duration_col"]
278278

279279
snapshot.assert_match(bf_df.sql, "out.sql")
280+
281+
282+
def test_integer_label_to_datetime(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL for one fixed (daily) and one non-fixed (quarter-end) frequency."""
    bf_df = scalar_types_df[["rowindex", "timestamp_col"]]

    fixed_freq = ops.IntegerLabelToDatetimeOp(
        freq=pd.tseries.offsets.Day(), origin="start", label="left"  # type: ignore
    ).as_expr("rowindex", "timestamp_col")
    non_fixed_freq = ops.IntegerLabelToDatetimeOp(
        freq=pd.tseries.offsets.QuarterEnd(startingMonth=12),  # type: ignore
        origin="start",
        label="left",
    ).as_expr("rowindex", "timestamp_col")

    # Output column order matters: it is part of the snapshotted SQL.
    sql = utils._apply_ops_to_sql(
        bf_df,
        [fixed_freq, non_fixed_freq],
        ["fixed_freq", "non_fixed_freq"],
    )
    snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)