1919from bigframes import dtypes
2020from bigframes import operations as ops
2121from bigframes .core .compile .constants import UNIT_TO_US_CONVERSION_FACTORS
22+ from bigframes .core .compile .sqlglot import sqlglot_types
2223from bigframes .core .compile .sqlglot .expressions .typed_expr import TypedExpr
2324import bigframes .core .compile .sqlglot .scalar_compiler as scalar_compiler
2425
@@ -301,18 +302,10 @@ def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
301302 elif origin == "start_day" :
302303 return sge .func (
303304 "UNIX_MICROS" ,
304- sge .Cast (
305- this = sge .Cast (
306- this = y .expr , to = sge .DataType (this = sge .DataType .Type .DATE )
307- ),
308- to = sge .DataType (this = sge .DataType .Type .TIMESTAMPTZ ),
309- ),
305+ sge .Cast (this = sge .Cast (this = y .expr , to = "DATE" ), to = "TIMESTAMP" ),
310306 )
311307 elif origin == "start" :
312- return sge .func (
313- "UNIX_MICROS" ,
314- sge .Cast (this = y .expr , to = sge .DataType (this = sge .DataType .Type .TIMESTAMPTZ )),
315- )
308+ return sge .func ("UNIX_MICROS" , sge .Cast (this = y .expr , to = "TIMESTAMP" ))
316309 else :
317310 raise ValueError (f"Origin { origin } not supported" )
318311
@@ -438,19 +431,6 @@ def _(expr: TypedExpr) -> sge.Expression:
438431 return sge .Extract (this = sge .Identifier (this = "YEAR" ), expression = expr .expr )
439432
440433
441- def _dtype_to_sql_string (dtype : dtypes .Dtype ) -> str :
442- if dtype == dtypes .TIMESTAMP_DTYPE :
443- return "TIMESTAMP"
444- if dtype == dtypes .DATETIME_DTYPE :
445- return "DATETIME"
446- if dtype == dtypes .DATE_DTYPE :
447- return "DATE"
448- if dtype == dtypes .TIME_DTYPE :
449- return "TIME"
450- # Should not be reached in this operator
451- raise ValueError (f"Unsupported dtype for datetime conversion: { dtype } " )
452-
453-
454434@register_binary_op (ops .IntegerLabelToDatetimeOp , pass_op = True )
455435def integer_label_to_datetime_op (
456436 x : TypedExpr , y : TypedExpr , op : ops .IntegerLabelToDatetimeOp
@@ -477,17 +457,15 @@ def _integer_label_to_datetime_op_fixed_frequency(
477457 sge .Cast (
478458 this = sge .Add (
479459 this = sge .Mul (
480- this = sge .Cast (this = x .expr , to = sge . DataType . build ( "BIGNUMERIC" ) ),
460+ this = sge .Cast (this = x .expr , to = "BIGNUMERIC" ),
481461 expression = sge .convert (int (us )),
482462 ),
483- expression = sge .Cast (
484- this = first , to = sge .DataType .build ("BIGNUMERIC" )
485- ),
463+ expression = sge .Cast (this = first , to = "BIGNUMERIC" ),
486464 ),
487- to = sge . DataType . build ( "INT64" ) ,
465+ to = "INT64" ,
488466 ),
489467 ),
490- to = _dtype_to_sql_string (y .dtype ), # type: ignore
468+ to = sqlglot_types . from_bigframes_dtype (y .dtype ),
491469 )
492470 return x_label
493471
@@ -507,10 +485,7 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
507485 "UNIX_MICROS" ,
508486 sge .Add (
509487 this = sge .TimestampTrunc (
510- this = sge .Cast (
511- this = y .expr ,
512- to = sge .DataType (this = sge .DataType .Type .TIMESTAMPTZ ),
513- ),
488+ this = sge .Cast (this = y .expr , to = "TIMESTAMP" ),
514489 unit = sge .Var (this = "WEEK(MONDAY)" ),
515490 ),
516491 expression = sge .Interval (
@@ -524,19 +499,15 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
524499 sge .Cast (
525500 this = sge .Add (
526501 this = sge .Mul (
527- this = sge .Cast (
528- this = x .expr , to = sge .DataType .build ("BIGNUMERIC" )
529- ),
502+ this = sge .Cast (this = x .expr , to = "BIGNUMERIC" ),
530503 expression = sge .convert (us ),
531504 ),
532- expression = sge .Cast (
533- this = first , to = sge .DataType .build ("BIGNUMERIC" )
534- ),
505+ expression = sge .Cast (this = first , to = "BIGNUMERIC" ),
535506 ),
536- to = sge . DataType . build ( "INT64" ) ,
507+ to = "INT64" ,
537508 ),
538509 ),
539- to = _dtype_to_sql_string (y .dtype ), # type: ignore
510+ to = sqlglot_types . from_bigframes_dtype (y .dtype ),
540511 )
541512 elif rule_code in ("ME" , "M" ): # Monthly
542513 one = sge .convert (1 )
@@ -556,7 +527,7 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
556527 )
557528 year = sge .Cast (
558529 this = sge .Floor (this = sge .func ("IEEE_DIVIDE" , x_val , twelve )),
559- to = sge . DataType . build ( "INT64" ) ,
530+ to = "INT64" ,
560531 )
561532 month = sge .Add (this = sge .Mod (this = x_val , expression = twelve ), expression = one )
562533 next_year = sge .Case (
@@ -614,7 +585,7 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
614585 )
615586 year = sge .Cast (
616587 this = sge .Floor (this = sge .func ("IEEE_DIVIDE" , x_val , four )),
617- to = sge . DataType . build ( "INT64" ) ,
588+ to = "INT64" ,
618589 )
619590 month = sge .Mul ( # type: ignore
620591 this = sge .Paren (
@@ -675,4 +646,4 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
675646 )
676647 else :
677648 raise ValueError (rule_code )
678- return sge .Cast (this = x_label , to = _dtype_to_sql_string (y .dtype )) # type: ignore
649+ return sge .Cast (this = x_label , to = sqlglot_types . from_bigframes_dtype (y .dtype ))
0 commit comments