1414
1515from __future__ import annotations
1616
17+ import bigframes_vendored .constants as constants
1718import sqlglot .expressions as sge
1819
1920from bigframes import dtypes
@@ -35,8 +36,83 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3536 # String addition
3637 return sge .Concat (expressions = [left .expr , right .expr ])
3738
38- # Numerical addition
39- return sge .Add (this = left .expr , expression = right .expr )
39+ if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
40+ left_expr = left .expr
41+ if left .dtype == dtypes .BOOL_DTYPE :
42+ left_expr = sge .Cast (this = left_expr , to = "INT64" )
43+ right_expr = right .expr
44+ if right .dtype == dtypes .BOOL_DTYPE :
45+ right_expr = sge .Cast (this = right_expr , to = "INT64" )
46+ return sge .Add (this = left_expr , expression = right_expr )
47+
48+ if (
49+ dtypes .is_time_or_date_like (left .dtype )
50+ and right .dtype == dtypes .TIMEDELTA_DTYPE
51+ ):
52+ left_expr = left .expr
53+ if left .dtype == dtypes .DATE_DTYPE :
54+ left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
55+ return sge .TimestampAdd (
56+ this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
57+ )
58+ if (
59+ dtypes .is_time_or_date_like (right .dtype )
60+ and left .dtype == dtypes .TIMEDELTA_DTYPE
61+ ):
62+ right_expr = right .expr
63+ if right .dtype == dtypes .DATE_DTYPE :
64+ right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
65+ return sge .TimestampAdd (
66+ this = right_expr , expression = left .expr , unit = sge .Var (this = "MICROSECOND" )
67+ )
68+ if left .dtype == dtypes .TIMEDELTA_DTYPE and right .dtype == dtypes .TIMEDELTA_DTYPE :
69+ return sge .Add (this = left .expr , expression = right .expr )
70+
71+ raise TypeError (
72+ f"Cannot add type { left .dtype } and { right .dtype } . { constants .FEEDBACK_LINK } "
73+ )
74+
75+
76+ @BINARY_OP_REGISTRATION .register (ops .sub_op )
77+ def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
78+ if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
79+ left_expr = left .expr
80+ if left .dtype == dtypes .BOOL_DTYPE :
81+ left_expr = sge .Cast (this = left_expr , to = "INT64" )
82+ right_expr = right .expr
83+ if right .dtype == dtypes .BOOL_DTYPE :
84+ right_expr = sge .Cast (this = right_expr , to = "INT64" )
85+ return sge .Sub (this = left_expr , expression = right_expr )
86+
87+ if (
88+ dtypes .is_time_or_date_like (left .dtype )
89+ and right .dtype == dtypes .TIMEDELTA_DTYPE
90+ ):
91+ left_expr = left .expr
92+ if left .dtype == dtypes .DATE_DTYPE :
93+ left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
94+ return sge .TimestampSub (
95+ this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
96+ )
97+ if dtypes .is_time_or_date_like (left .dtype ) and dtypes .is_time_or_date_like (
98+ right .dtype
99+ ):
100+ left_expr = left .expr
101+ if left .dtype == dtypes .DATE_DTYPE :
102+ left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
103+ right_expr = right .expr
104+ if right .dtype == dtypes .DATE_DTYPE :
105+ right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
106+ return sge .TimestampDiff (
107+ this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
108+ )
109+
110+ if left .dtype == dtypes .TIMEDELTA_DTYPE and right .dtype == dtypes .TIMEDELTA_DTYPE :
111+ return sge .Sub (this = left .expr , expression = right .expr )
112+
113+ raise TypeError (
114+ f"Cannot subtract type { left .dtype } and { right .dtype } . { constants .FEEDBACK_LINK } "
115+ )
40116
41117
42118@BINARY_OP_REGISTRATION .register (ops .ge_op )
0 commit comments