Skip to content

Commit 522b388

Browse files
committed
resolve the comments
1 parent 761ec5f commit 522b388

File tree

1 file changed

+15
-44
lines changed

1 file changed

+15
-44
lines changed

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from bigframes import dtypes
2020
from bigframes import operations as ops
2121
from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
22+
from bigframes.core.compile.sqlglot import sqlglot_types
2223
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2324
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2425

@@ -301,18 +302,10 @@ def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
301302
elif origin == "start_day":
302303
return sge.func(
303304
"UNIX_MICROS",
304-
sge.Cast(
305-
this=sge.Cast(
306-
this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE)
307-
),
308-
to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
309-
),
305+
sge.Cast(this=sge.Cast(this=y.expr, to="DATE"), to="TIMESTAMP"),
310306
)
311307
elif origin == "start":
312-
return sge.func(
313-
"UNIX_MICROS",
314-
sge.Cast(this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)),
315-
)
308+
return sge.func("UNIX_MICROS", sge.Cast(this=y.expr, to="TIMESTAMP"))
316309
else:
317310
raise ValueError(f"Origin {origin} not supported")
318311

@@ -438,19 +431,6 @@ def _(expr: TypedExpr) -> sge.Expression:
438431
return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr)
439432

440433

441-
def _dtype_to_sql_string(dtype: dtypes.Dtype) -> str:
442-
if dtype == dtypes.TIMESTAMP_DTYPE:
443-
return "TIMESTAMP"
444-
if dtype == dtypes.DATETIME_DTYPE:
445-
return "DATETIME"
446-
if dtype == dtypes.DATE_DTYPE:
447-
return "DATE"
448-
if dtype == dtypes.TIME_DTYPE:
449-
return "TIME"
450-
# Should not be reached in this operator
451-
raise ValueError(f"Unsupported dtype for datetime conversion: {dtype}")
452-
453-
454434
@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True)
455435
def integer_label_to_datetime_op(
456436
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
@@ -477,17 +457,15 @@ def _integer_label_to_datetime_op_fixed_frequency(
477457
sge.Cast(
478458
this=sge.Add(
479459
this=sge.Mul(
480-
this=sge.Cast(this=x.expr, to=sge.DataType.build("BIGNUMERIC")),
460+
this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
481461
expression=sge.convert(int(us)),
482462
),
483-
expression=sge.Cast(
484-
this=first, to=sge.DataType.build("BIGNUMERIC")
485-
),
463+
expression=sge.Cast(this=first, to="BIGNUMERIC"),
486464
),
487-
to=sge.DataType.build("INT64"),
465+
to="INT64",
488466
),
489467
),
490-
to=_dtype_to_sql_string(y.dtype), # type: ignore
468+
to=sqlglot_types.from_bigframes_dtype(y.dtype),
491469
)
492470
return x_label
493471

@@ -507,10 +485,7 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
507485
"UNIX_MICROS",
508486
sge.Add(
509487
this=sge.TimestampTrunc(
510-
this=sge.Cast(
511-
this=y.expr,
512-
to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
513-
),
488+
this=sge.Cast(this=y.expr, to="TIMESTAMP"),
514489
unit=sge.Var(this="WEEK(MONDAY)"),
515490
),
516491
expression=sge.Interval(
@@ -524,19 +499,15 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
524499
sge.Cast(
525500
this=sge.Add(
526501
this=sge.Mul(
527-
this=sge.Cast(
528-
this=x.expr, to=sge.DataType.build("BIGNUMERIC")
529-
),
502+
this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
530503
expression=sge.convert(us),
531504
),
532-
expression=sge.Cast(
533-
this=first, to=sge.DataType.build("BIGNUMERIC")
534-
),
505+
expression=sge.Cast(this=first, to="BIGNUMERIC"),
535506
),
536-
to=sge.DataType.build("INT64"),
507+
to="INT64",
537508
),
538509
),
539-
to=_dtype_to_sql_string(y.dtype), # type: ignore
510+
to=sqlglot_types.from_bigframes_dtype(y.dtype),
540511
)
541512
elif rule_code in ("ME", "M"): # Monthly
542513
one = sge.convert(1)
@@ -556,7 +527,7 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
556527
)
557528
year = sge.Cast(
558529
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, twelve)),
559-
to=sge.DataType.build("INT64"),
530+
to="INT64",
560531
)
561532
month = sge.Add(this=sge.Mod(this=x_val, expression=twelve), expression=one)
562533
next_year = sge.Case(
@@ -614,7 +585,7 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
614585
)
615586
year = sge.Cast(
616587
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, four)),
617-
to=sge.DataType.build("INT64"),
588+
to="INT64",
618589
)
619590
month = sge.Mul( # type: ignore
620591
this=sge.Paren(
@@ -675,4 +646,4 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
675646
)
676647
else:
677648
raise ValueError(rule_code)
678-
return sge.Cast(this=x_label, to=_dtype_to_sql_string(y.dtype)) # type: ignore
649+
return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype))

0 commit comments

Comments
 (0)