@@ -81,6 +81,51 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
8181 )
8282
8383
84+ @BINARY_OP_REGISTRATION .register (ops .sub_op )
85+ def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
86+ if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
87+ left_expr = left .expr
88+ if left .dtype == dtypes .BOOL_DTYPE :
89+ left_expr = sge .Cast (this = left_expr , to = "INT64" )
90+ right_expr = right .expr
91+ if right .dtype == dtypes .BOOL_DTYPE :
92+ right_expr = sge .Cast (this = right_expr , to = "INT64" )
93+ return sge .Sub (this = left_expr , expression = right_expr )
94+
95+ if (
96+ dtypes .is_datetime_like (left .dtype ) or dtypes .is_date_like (left .dtype )
97+ ) and right .dtype == dtypes .TIMEDELTA_DTYPE :
98+ left_expr = left .expr
99+ if left .dtype == dtypes .DATE_DTYPE :
100+ left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
101+ return sge .TimestampSub (
102+ this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
103+ )
104+ if (dtypes .is_datetime_like (left .dtype ) or dtypes .is_date_like (left .dtype )) and (
105+ dtypes .is_datetime_like (right .dtype ) or dtypes .is_date_like (right .dtype )
106+ ):
107+ left_expr = left .expr
108+ if left .dtype == dtypes .DATE_DTYPE :
109+ left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
110+ right_expr = right .expr
111+ if right .dtype == dtypes .DATE_DTYPE :
112+ right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
113+ return sge .TimestampDiff (
114+ this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
115+ )
116+
117+ if left .dtype == dtypes .DATE_DTYPE and right .dtype == dtypes .INT_DTYPE :
118+ return sge .DateSub (
119+ this = left .expr , expression = right .expr , unit = sge .Var (this = "DAY" )
120+ )
121+ if left .dtype == dtypes .TIMEDELTA_DTYPE and right .dtype == dtypes .TIMEDELTA_DTYPE :
122+ return sge .Sub (this = left .expr , expression = right .expr )
123+
124+ raise TypeError (
125+ f"Cannot subtract type { left .dtype } and { right .dtype } . { constants .FEEDBACK_LINK } "
126+ )
127+
128+
84129@BINARY_OP_REGISTRATION .register (ops .ge_op )
85130def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
86131 return sge .GTE (this = left .expr , expression = right .expr )
0 commit comments