@@ -37,26 +37,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3737 return sge .Concat (expressions = [left .expr , right .expr ])
3838
3939 if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
40- left_expr , right_expr = _coerce_bools (left , right )
40+ left_expr = left .coerce_bool_to_int ()
41+ right_expr = right .coerce_bool_to_int ()
4142 return sge .Add (this = left_expr , expression = right_expr )
4243
4344 if (
4445 dtypes .is_time_or_date_like (left .dtype )
4546 and right .dtype == dtypes .TIMEDELTA_DTYPE
4647 ):
47- left_expr = left .expr
48- if left .dtype == dtypes .DATE_DTYPE :
49- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
48+ left_expr = left .coerce_date_to_datetime ()
5049 return sge .TimestampAdd (
5150 this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
5251 )
5352 if (
5453 dtypes .is_time_or_date_like (right .dtype )
5554 and left .dtype == dtypes .TIMEDELTA_DTYPE
5655 ):
57- right_expr = right .expr
58- if right .dtype == dtypes .DATE_DTYPE :
59- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
56+ right_expr = right .coerce_date_to_datetime ()
6057 return sge .TimestampAdd (
6158 this = right_expr , expression = left .expr , unit = sge .Var (this = "MICROSECOND" )
6259 )
@@ -70,19 +67,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7067
7168@BINARY_OP_REGISTRATION .register (ops .eq_op )
7269def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
73- left_expr , right_expr = _coerce_bools (left , right )
70+ left_expr = left .coerce_bool_to_int ()
71+ right_expr = right .coerce_bool_to_int ()
7472 return sge .EQ (this = left_expr , expression = right_expr )
7573
7674
7775@BINARY_OP_REGISTRATION .register (ops .eq_null_match_op )
7876def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
7977 left_expr = left .expr
80- if left . dtype == dtypes . BOOL_DTYPE and right .dtype != dtypes .BOOL_DTYPE :
81- left_expr = sge . Cast ( this = left_expr , to = "INT64" )
78+ if right .dtype != dtypes .BOOL_DTYPE :
79+ left_expr = left . coerce_bool_to_int ( )
8280
8381 right_expr = right .expr
84- if right . dtype == dtypes . BOOL_DTYPE and left .dtype != dtypes .BOOL_DTYPE :
85- right_expr = sge . Cast ( this = right_expr , to = "INT64" )
82+ if left .dtype != dtypes .BOOL_DTYPE :
83+ right_expr = right . coerce_bool_to_int ( )
8684
8785 sentinel = sge .convert ("$NULL_SENTINEL$" )
8886 left_coalesce = sge .Coalesce (
@@ -96,7 +94,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
9694
9795@BINARY_OP_REGISTRATION .register (ops .div_op )
9896def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
99- left_expr , right_expr = _coerce_bools (left , right )
97+ left_expr = left .coerce_bool_to_int ()
98+ right_expr = right .coerce_bool_to_int ()
10099
101100 result = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )
102101 if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
@@ -117,7 +116,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
117116
118117@BINARY_OP_REGISTRATION .register (ops .mul_op )
119118def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
120- left_expr , right_expr = _coerce_bools (left , right )
119+ left_expr = left .coerce_bool_to_int ()
120+ right_expr = right .coerce_bool_to_int ()
121121
122122 result = sge .Mul (this = left_expr , expression = right_expr )
123123
@@ -131,35 +131,31 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
131131
132132@BINARY_OP_REGISTRATION .register (ops .ne_op )
133133def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
134- left_expr , right_expr = _coerce_bools (left , right )
134+ left_expr = left .coerce_bool_to_int ()
135+ right_expr = right .coerce_bool_to_int ()
135136 return sge .NEQ (this = left_expr , expression = right_expr )
136137
137138
138139@BINARY_OP_REGISTRATION .register (ops .sub_op )
139140def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
140141 if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
141- left_expr , right_expr = _coerce_bools (left , right )
142+ left_expr = left .coerce_bool_to_int ()
143+ right_expr = right .coerce_bool_to_int ()
142144 return sge .Sub (this = left_expr , expression = right_expr )
143145
144146 if (
145147 dtypes .is_time_or_date_like (left .dtype )
146148 and right .dtype == dtypes .TIMEDELTA_DTYPE
147149 ):
148- left_expr = left .expr
149- if left .dtype == dtypes .DATE_DTYPE :
150- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
150+ left_expr = left .coerce_date_to_datetime ()
151151 return sge .TimestampSub (
152152 this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
153153 )
154154 if dtypes .is_time_or_date_like (left .dtype ) and dtypes .is_time_or_date_like (
155155 right .dtype
156156 ):
157- left_expr = left .expr
158- if left .dtype == dtypes .DATE_DTYPE :
159- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
160- right_expr = right .expr
161- if right .dtype == dtypes .DATE_DTYPE :
162- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
157+ left_expr = left .coerce_date_to_datetime ()
158+ right_expr = right .coerce_date_to_datetime ()
163159 return sge .TimestampDiff (
164160 this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
165161 )
@@ -175,16 +171,3 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
175171@BINARY_OP_REGISTRATION .register (ops .obj_make_ref_op )
176172def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
177173 return sge .func ("OBJ.MAKE_REF" , left .expr , right .expr )
178-
179-
180- def _coerce_bools (
181- left : TypedExpr , right : TypedExpr
182- ) -> tuple [sge .Expression , sge .Expression ]:
183- """Coerce boolean expressions to INT64 for binary operations."""
184- left_expr = left .expr
185- if left .dtype == dtypes .BOOL_DTYPE :
186- left_expr = sge .Cast (this = left_expr , to = "INT64" )
187- right_expr = right .expr
188- if right .dtype == dtypes .BOOL_DTYPE :
189- right_expr = sge .Cast (this = right_expr , to = "INT64" )
190- return left_expr , right_expr
0 commit comments