@@ -38,31 +38,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3838 return sge .Concat (expressions = [left .expr , right .expr ])
3939
4040 if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
41- left_expr = left .expr
42- if left .dtype == dtypes .BOOL_DTYPE :
43- left_expr = sge .Cast (this = left_expr , to = "INT64" )
44- right_expr = right .expr
45- if right .dtype == dtypes .BOOL_DTYPE :
46- right_expr = sge .Cast (this = right_expr , to = "INT64" )
41+ left_expr = _coerce_bool_to_int (left )
42+ right_expr = _coerce_bool_to_int (right )
4743 return sge .Add (this = left_expr , expression = right_expr )
4844
4945 if (
5046 dtypes .is_time_or_date_like (left .dtype )
5147 and right .dtype == dtypes .TIMEDELTA_DTYPE
5248 ):
53- left_expr = left .expr
54- if left .dtype == dtypes .DATE_DTYPE :
55- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
49+ left_expr = _coerce_date_to_datetime (left )
5650 return sge .TimestampAdd (
5751 this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
5852 )
5953 if (
6054 dtypes .is_time_or_date_like (right .dtype )
6155 and left .dtype == dtypes .TIMEDELTA_DTYPE
6256 ):
63- right_expr = right .expr
64- if right .dtype == dtypes .DATE_DTYPE :
65- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
57+ right_expr = _coerce_date_to_datetime (right )
6658 return sge .TimestampAdd (
6759 this = right_expr , expression = left .expr , unit = sge .Var (this = "MICROSECOND" )
6860 )
@@ -74,14 +66,37 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7466 )
7567
7668
77- @BINARY_OP_REGISTRATION .register (ops .div_op )
@BINARY_OP_REGISTRATION.register(ops.eq_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile the equality op into a sqlglot ``EQ`` expression.

    Each operand is passed through ``_coerce_bool_to_int`` before the
    comparison is built, so a BOOL-typed operand is compared as INT64.

    Args:
        op: The operator being compiled (not used by this handler).
        left: Left-hand operand.
        right: Right-hand operand.

    Returns:
        ``sge.EQ`` over the (possibly cast) operand expressions.
    """
    return sge.EQ(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
74+
75+
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile the null-matching equality op.

    Both operands are cast to STRING and wrapped in ``COALESCE`` with a fixed
    sentinel literal, then compared with ``=``; two NULL operands therefore
    both become the sentinel and compare equal.

    Args:
        op: The operator being compiled (not used by this handler).
        left: Left-hand operand.
        right: Right-hand operand.

    Returns:
        ``sge.EQ`` over the two ``COALESCE(CAST(... AS STRING), sentinel)``
        expressions.
    """
    # A BOOL operand is widened to INT64 only when the *other* operand is not
    # BOOL; when both sides are BOOL they are left untouched.
    lhs = (
        _coerce_bool_to_int(left)
        if right.dtype != dtypes.BOOL_DTYPE
        else left.expr
    )
    rhs = (
        _coerce_bool_to_int(right)
        if left.dtype != dtypes.BOOL_DTYPE
        else right.expr
    )

    # NOTE(review): a non-NULL value whose STRING form equals this literal
    # would compare equal to NULL -- confirm that is acceptable.
    null_marker = sge.convert("$NULL_SENTINEL$")

    def _as_nullsafe_string(expr: sge.Expression) -> sge.Expression:
        # CAST to STRING first so both sides share a type, then replace NULL
        # with the marker literal.
        return sge.Coalesce(
            this=sge.Cast(this=expr, to="STRING"), expressions=[null_marker]
        )

    return sge.EQ(
        this=_as_nullsafe_string(lhs), expression=_as_nullsafe_string(rhs)
    )
94+
95+
96+ @BINARY_OP_REGISTRATION .register (ops .div_op )
97+ def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
98+ left_expr = _coerce_bool_to_int (left )
99+ right_expr = _coerce_bool_to_int (right )
85100
86101 result = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )
87102 if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
@@ -92,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
92107
93108@BINARY_OP_REGISTRATION .register (ops .floordiv_op )
94109def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
95- left_expr = left .expr
96- if left .dtype == dtypes .BOOL_DTYPE :
97- left_expr = sge .Cast (this = left_expr , to = "INT64" )
98- right_expr = right .expr
99- if right .dtype == dtypes .BOOL_DTYPE :
100- right_expr = sge .Cast (this = right_expr , to = "INT64" )
110+ left_expr = _coerce_bool_to_int (left )
111+ right_expr = _coerce_bool_to_int (right )
101112
102113 result : sge .Expression = sge .Cast (
103114 this = sge .Floor (this = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )), to = "INT64"
@@ -139,12 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
139150
140151@BINARY_OP_REGISTRATION .register (ops .mul_op )
141152def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
142- left_expr = left .expr
143- if left .dtype == dtypes .BOOL_DTYPE :
144- left_expr = sge .Cast (this = left_expr , to = "INT64" )
145- right_expr = right .expr
146- if right .dtype == dtypes .BOOL_DTYPE :
147- right_expr = sge .Cast (this = right_expr , to = "INT64" )
153+ left_expr = _coerce_bool_to_int (left )
154+ right_expr = _coerce_bool_to_int (right )
148155
149156 result = sge .Mul (this = left_expr , expression = right_expr )
150157
@@ -156,36 +163,33 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
156163 return result
157164
158165
@BINARY_OP_REGISTRATION.register(ops.ne_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile the inequality op into a sqlglot ``NEQ`` expression.

    Each operand is passed through ``_coerce_bool_to_int`` before the
    comparison is built, so a BOOL-typed operand is compared as INT64.

    Args:
        op: The operator being compiled (not used by this handler).
        left: Left-hand operand.
        right: Right-hand operand.

    Returns:
        ``sge.NEQ`` over the (possibly cast) operand expressions.
    """
    return sge.NEQ(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
171+
172+
159173@BINARY_OP_REGISTRATION .register (ops .sub_op )
160174def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
161175 if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
162- left_expr = left .expr
163- if left .dtype == dtypes .BOOL_DTYPE :
164- left_expr = sge .Cast (this = left_expr , to = "INT64" )
165- right_expr = right .expr
166- if right .dtype == dtypes .BOOL_DTYPE :
167- right_expr = sge .Cast (this = right_expr , to = "INT64" )
176+ left_expr = _coerce_bool_to_int (left )
177+ right_expr = _coerce_bool_to_int (right )
168178 return sge .Sub (this = left_expr , expression = right_expr )
169179
170180 if (
171181 dtypes .is_time_or_date_like (left .dtype )
172182 and right .dtype == dtypes .TIMEDELTA_DTYPE
173183 ):
174- left_expr = left .expr
175- if left .dtype == dtypes .DATE_DTYPE :
176- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
184+ left_expr = _coerce_date_to_datetime (left )
177185 return sge .TimestampSub (
178186 this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
179187 )
180188 if dtypes .is_time_or_date_like (left .dtype ) and dtypes .is_time_or_date_like (
181189 right .dtype
182190 ):
183- left_expr = left .expr
184- if left .dtype == dtypes .DATE_DTYPE :
185- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
186- right_expr = right .expr
187- if right .dtype == dtypes .DATE_DTYPE :
188- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
191+ left_expr = _coerce_date_to_datetime (left )
192+ right_expr = _coerce_date_to_datetime (right )
189193 return sge .TimestampDiff (
190194 this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
191195 )
@@ -201,3 +205,17 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile the op into a call to the ``OBJ.MAKE_REF`` function.

    Args:
        op: The operator being compiled (not used by this handler).
        left: Operand whose expression becomes the first argument.
        right: Operand whose expression becomes the second argument.

    Returns:
        The function-call expression built by ``sge.func``.
    """
    call_args = (left.expr, right.expr)
    return sge.func("OBJ.MAKE_REF", *call_args)
208+
209+
def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
    """Return the operand's expression, cast to INT64 if its dtype is BOOL.

    Args:
        typed_expr: Operand carrying a sqlglot expression and its dtype.

    Returns:
        ``CAST(expr AS INT64)`` when ``typed_expr.dtype`` equals
        ``dtypes.BOOL_DTYPE``; otherwise the expression unchanged.
    """
    is_bool = typed_expr.dtype == dtypes.BOOL_DTYPE
    return (
        sge.Cast(this=typed_expr.expr, to="INT64") if is_bool else typed_expr.expr
    )
215+
216+
def _coerce_date_to_datetime(typed_expr: TypedExpr) -> sge.Expression:
    """Return the operand's expression, cast to DATETIME if its dtype is DATE.

    Args:
        typed_expr: Operand carrying a sqlglot expression and its dtype.

    Returns:
        ``CAST(expr AS DATETIME)`` when ``typed_expr.dtype`` equals
        ``dtypes.DATE_DTYPE``; otherwise the expression unchanged.
    """
    is_date = typed_expr.dtype == dtypes.DATE_DTYPE
    return (
        sge.Cast(this=typed_expr.expr, to="DATETIME")
        if is_date
        else typed_expr.expr
    )
0 commit comments