@@ -38,26 +38,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3838 return sge .Concat (expressions = [left .expr , right .expr ])
3939
4040 if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
41- left_expr , right_expr = _coerce_bools (left , right )
41+ left_expr = left .coerce_bool_to_int ()
42+ right_expr = right .coerce_bool_to_int ()
4243 return sge .Add (this = left_expr , expression = right_expr )
4344
4445 if (
4546 dtypes .is_time_or_date_like (left .dtype )
4647 and right .dtype == dtypes .TIMEDELTA_DTYPE
4748 ):
48- left_expr = left .expr
49- if left .dtype == dtypes .DATE_DTYPE :
50- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
49+ left_expr = left .coerce_date_to_datetime ()
5150 return sge .TimestampAdd (
5251 this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
5352 )
5453 if (
5554 dtypes .is_time_or_date_like (right .dtype )
5655 and left .dtype == dtypes .TIMEDELTA_DTYPE
5756 ):
58- right_expr = right .expr
59- if right .dtype == dtypes .DATE_DTYPE :
60- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
57+ right_expr = right .coerce_date_to_datetime ()
6158 return sge .TimestampAdd (
6259 this = right_expr , expression = left .expr , unit = sge .Var (this = "MICROSECOND" )
6360 )
@@ -71,19 +68,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7168
7269@BINARY_OP_REGISTRATION .register (ops .eq_op )
7370def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
74- left_expr , right_expr = _coerce_bools (left , right )
71+ left_expr = left .coerce_bool_to_int ()
72+ right_expr = right .coerce_bool_to_int ()
7573 return sge .EQ (this = left_expr , expression = right_expr )
7674
7775
7876@BINARY_OP_REGISTRATION .register (ops .eq_null_match_op )
7977def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
8078 left_expr = left .expr
81- if left . dtype == dtypes . BOOL_DTYPE and right .dtype != dtypes .BOOL_DTYPE :
82- left_expr = sge . Cast ( this = left_expr , to = "INT64" )
79+ if right .dtype != dtypes .BOOL_DTYPE :
80+ left_expr = left . coerce_bool_to_int ( )
8381
8482 right_expr = right .expr
85- if right . dtype == dtypes . BOOL_DTYPE and left .dtype != dtypes .BOOL_DTYPE :
86- right_expr = sge . Cast ( this = right_expr , to = "INT64" )
83+ if left .dtype != dtypes .BOOL_DTYPE :
84+ right_expr = right . coerce_bool_to_int ( )
8785
8886 sentinel = sge .convert ("$NULL_SENTINEL$" )
8987 left_coalesce = sge .Coalesce (
@@ -97,7 +95,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
9795
9896@BINARY_OP_REGISTRATION .register (ops .div_op )
9997def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
100- left_expr , right_expr = _coerce_bools (left , right )
98+ left_expr = left .coerce_bool_to_int ()
99+ right_expr = right .coerce_bool_to_int ()
101100
102101 result = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )
103102 if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
@@ -108,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
108107
109108@BINARY_OP_REGISTRATION .register (ops .floordiv_op )
110109def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
111- left_expr = left .expr
112- if left .dtype == dtypes .BOOL_DTYPE :
113- left_expr = sge .Cast (this = left_expr , to = "INT64" )
114- right_expr = right .expr
115- if right .dtype == dtypes .BOOL_DTYPE :
116- right_expr = sge .Cast (this = right_expr , to = "INT64" )
110+ left_expr = left .coerce_bool_to_int ()
111+ right_expr = right .coerce_bool_to_int ()
117112
118113 result : sge .Expression = sge .Cast (
119114 this = sge .Floor (this = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )), to = "INT64"
@@ -155,7 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
155150
156151@BINARY_OP_REGISTRATION .register (ops .mul_op )
157152def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
158- left_expr , right_expr = _coerce_bools (left , right )
153+ left_expr = left .coerce_bool_to_int ()
154+ right_expr = right .coerce_bool_to_int ()
159155
160156 result = sge .Mul (this = left_expr , expression = right_expr )
161157
@@ -169,35 +165,31 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
169165
170166@BINARY_OP_REGISTRATION .register (ops .ne_op )
171167def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
172- left_expr , right_expr = _coerce_bools (left , right )
168+ left_expr = left .coerce_bool_to_int ()
169+ right_expr = right .coerce_bool_to_int ()
173170 return sge .NEQ (this = left_expr , expression = right_expr )
174171
175172
176173@BINARY_OP_REGISTRATION .register (ops .sub_op )
177174def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
178175 if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
179- left_expr , right_expr = _coerce_bools (left , right )
176+ left_expr = left .coerce_bool_to_int ()
177+ right_expr = right .coerce_bool_to_int ()
180178 return sge .Sub (this = left_expr , expression = right_expr )
181179
182180 if (
183181 dtypes .is_time_or_date_like (left .dtype )
184182 and right .dtype == dtypes .TIMEDELTA_DTYPE
185183 ):
186- left_expr = left .expr
187- if left .dtype == dtypes .DATE_DTYPE :
188- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
184+ left_expr = left .coerce_date_to_datetime ()
189185 return sge .TimestampSub (
190186 this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
191187 )
192188 if dtypes .is_time_or_date_like (left .dtype ) and dtypes .is_time_or_date_like (
193189 right .dtype
194190 ):
195- left_expr = left .expr
196- if left .dtype == dtypes .DATE_DTYPE :
197- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
198- right_expr = right .expr
199- if right .dtype == dtypes .DATE_DTYPE :
200- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
191+ left_expr = left .coerce_date_to_datetime ()
192+ right_expr = right .coerce_date_to_datetime ()
201193 return sge .TimestampDiff (
202194 this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
203195 )
@@ -213,16 +205,3 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
213205@BINARY_OP_REGISTRATION .register (ops .obj_make_ref_op )
214206def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
215207 return sge .func ("OBJ.MAKE_REF" , left .expr , right .expr )
216-
217-
218- def _coerce_bools (
219- left : TypedExpr , right : TypedExpr
220- ) -> tuple [sge .Expression , sge .Expression ]:
221- """Coerce boolean expressions to INT64 for binary operations."""
222- left_expr = left .expr
223- if left .dtype == dtypes .BOOL_DTYPE :
224- left_expr = sge .Cast (this = left_expr , to = "INT64" )
225- right_expr = right .expr
226- if right .dtype == dtypes .BOOL_DTYPE :
227- right_expr = sge .Cast (this = right_expr , to = "INT64" )
228- return left_expr , right_expr
0 commit comments