@@ -37,26 +37,258 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
3737 return expr .op .as_expr (larg , rarg )
3838
3939
class LowerAddRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``numeric_ops.AddOp`` expressions.

    Bool operands are cast to int before adding, string-like pairs are
    rewritten to ``ops.strconcat_op``, and a date operand paired with a
    timedelta operand is cast to datetime.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to."""
        return numeric_ops.AddOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten expression for the given add expression.

        Args:
            expr: An op expression whose op is an ``AddOp``; its first two
                children are the operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.AddOp)

        def cast(operand: expression.Expression, to_type) -> expression.Expression:
            return ops.AsTypeOp(to_type=to_type).as_expr(operand)

        lhs = expr.children[0]
        rhs = expr.children[1]
        # Types are captured once: the casts below only touch bool operands,
        # which can never match the string or date/timedelta checks.
        lhs_type = lhs.output_type
        rhs_type = rhs.output_type
        lhs_is_bool = lhs_type == dtypes.BOOL_DTYPE
        rhs_is_bool = rhs_type == dtypes.BOOL_DTYPE

        # bool + bool: add as ints, then cast the sum back to bool.
        if lhs_is_bool and rhs_is_bool:
            summed = expr.op.as_expr(
                cast(lhs, dtypes.INT_DTYPE),
                cast(rhs, dtypes.INT_DTYPE),
            )
            return cast(summed, dtypes.BOOL_DTYPE)

        # string + string is a concatenation.
        if dtypes.is_string_like(lhs_type) and dtypes.is_string_like(rhs_type):
            return ops.strconcat_op.as_expr(lhs, rhs)

        # A lone bool operand is added as an int.
        if lhs_is_bool:
            lhs = cast(lhs, dtypes.INT_DTYPE)
        if rhs_is_bool:
            rhs = cast(rhs, dtypes.INT_DTYPE)

        # date + timedelta (either order): cast the date side to datetime.
        if lhs_type == dtypes.DATE_DTYPE and rhs_type == dtypes.TIMEDELTA_DTYPE:
            lhs = cast(lhs, dtypes.DATETIME_DTYPE)
        elif lhs_type == dtypes.TIMEDELTA_DTYPE and rhs_type == dtypes.DATE_DTYPE:
            rhs = cast(rhs, dtypes.DATETIME_DTYPE)

        return expr.op.as_expr(lhs, rhs)
82+
83+
class LowerSubRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``numeric_ops.SubOp`` expressions.

    Bool operands are cast to int before subtracting, and a date minuend
    paired with a timedelta subtrahend is cast to datetime.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to."""
        return numeric_ops.SubOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten expression for the given subtract expression.

        Args:
            expr: An op expression whose op is a ``SubOp``; its first two
                children are the operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.SubOp)

        def cast(operand: expression.Expression, to_type) -> expression.Expression:
            return ops.AsTypeOp(to_type=to_type).as_expr(operand)

        minuend = expr.children[0]
        subtrahend = expr.children[1]
        # Types are captured once: the casts below only touch bool operands,
        # which can never match the date/timedelta check.
        minuend_type = minuend.output_type
        subtrahend_type = subtrahend.output_type
        minuend_is_bool = minuend_type == dtypes.BOOL_DTYPE
        subtrahend_is_bool = subtrahend_type == dtypes.BOOL_DTYPE

        # bool - bool: subtract as ints, then cast the difference back to bool.
        if minuend_is_bool and subtrahend_is_bool:
            difference = expr.op.as_expr(
                cast(minuend, dtypes.INT_DTYPE),
                cast(subtrahend, dtypes.INT_DTYPE),
            )
            return cast(difference, dtypes.BOOL_DTYPE)

        # A lone bool operand is subtracted as an int.
        if minuend_is_bool:
            minuend = cast(minuend, dtypes.INT_DTYPE)
        if subtrahend_is_bool:
            subtrahend = cast(subtrahend, dtypes.INT_DTYPE)

        # date - timedelta: cast the date side to datetime.
        if (
            minuend_type == dtypes.DATE_DTYPE
            and subtrahend_type == dtypes.TIMEDELTA_DTYPE
        ):
            minuend = cast(minuend, dtypes.DATETIME_DTYPE)

        return expr.op.as_expr(minuend, subtrahend)
115+
116+
class LowerMulRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``numeric_ops.MulOp`` expressions.

    Bool operands are cast to int before multiplying; when both operands
    are bool the product is cast back to bool.

    This is a plain class like the other lowering rules. It declares no
    fields, so a ``@dataclasses.dataclass`` decorator would only add a
    generated ``__eq__`` and set ``__hash__`` to ``None``, making instances
    unhashable.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to."""
        return numeric_ops.MulOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten expression for the given multiply expression.

        Args:
            expr: An op expression whose op is a ``MulOp``; its first two
                children are the operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.MulOp)
        larg, rarg = expr.children[0], expr.children[1]

        # bool * bool: multiply as ints, then cast the product back to bool.
        if (
            larg.output_type == dtypes.BOOL_DTYPE
            and rarg.output_type == dtypes.BOOL_DTYPE
        ):
            int_result = expr.op.as_expr(
                ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg),
                ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg),
            )
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result)

        # At most one operand is bool from here on (the bool * bool case
        # returned above), so no check on the other operand is needed.
        if larg.output_type == dtypes.BOOL_DTYPE:
            larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
        if rarg.output_type == dtypes.BOOL_DTYPE:
            rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)

        return expr.op.as_expr(larg, rarg)
149+
150+
class LowerDivRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``numeric_ops.DivOp`` expressions.

    Handles timedelta / int and bool / bool as special cases, and otherwise
    casts bool operands to int and NUMERIC/BIGNUMERIC dividends to float
    before dividing.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to."""
        return numeric_ops.DivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten expression for the given divide expression.

        Args:
            expr: An op expression whose op is a ``DivOp``; its first two
                children are the dividend and the divisor.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.DivOp)

        def cast(operand: expression.Expression, to_type) -> expression.Expression:
            return ops.AsTypeOp(to_type=to_type).as_expr(operand)

        numerator = expr.children[0]
        denominator = expr.children[1]
        # Types are captured once; each branch below either returns or casts
        # an operand whose original type already selected the branch.
        numerator_type = numerator.output_type
        denominator_type = denominator.output_type

        # timedelta / int: divide the int-cast dividend, cast back to timedelta.
        if (
            numerator_type == dtypes.TIMEDELTA_DTYPE
            and denominator_type == dtypes.INT_DTYPE
        ):
            quotient = expr.op.as_expr(
                cast(numerator, dtypes.INT_DTYPE), denominator
            )
            return cast(quotient, dtypes.TIMEDELTA_DTYPE)

        # bool / bool: divide as ints, then cast the quotient back to bool.
        if (
            numerator_type == dtypes.BOOL_DTYPE
            and denominator_type == dtypes.BOOL_DTYPE
        ):
            quotient = expr.op.as_expr(
                cast(numerator, dtypes.INT_DTYPE),
                cast(denominator, dtypes.INT_DTYPE),
            )
            return cast(quotient, dtypes.BOOL_DTYPE)

        # polars divide doesn't like bools, so bool operands always become int;
        # NUMERIC/BIGNUMERIC dividends become float.
        # NOTE(review): only the dividend gets the numeric -> float cast; the
        # divisor is left as-is - confirm this is intended.
        if numerator_type == dtypes.BOOL_DTYPE:
            numerator = cast(numerator, dtypes.INT_DTYPE)
        elif numerator_type in (dtypes.BIGNUMERIC_DTYPE, dtypes.NUMERIC_DTYPE):
            numerator = cast(numerator, dtypes.FLOAT_DTYPE)
        if denominator_type == dtypes.BOOL_DTYPE:
            denominator = cast(denominator, dtypes.INT_DTYPE)

        return numeric_ops.div_op.as_expr(numerator, denominator)
191+
192+
class LowerFloorDivRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``numeric_ops.FloorDivOp`` expressions.

    Handles timedelta dividends as special cases, casts bool operands to
    int, and wraps non-float floor divisions in a zero-divisor guard.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to."""
        return numeric_ops.FloorDivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten expression for the given floor-divide expression.

        Args:
            expr: An op expression whose op is a ``FloorDivOp``; its first
                two children are the dividend and the divisor.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.FloorDivOp)

        def cast(operand: expression.Expression, to_type) -> expression.Expression:
            return ops.AsTypeOp(to_type=to_type).as_expr(operand)

        numerator = expr.children[0]
        denominator = expr.children[1]

        if numerator.output_type == dtypes.TIMEDELTA_DTYPE:
            # timedelta // timedelta: floor-divide the int-cast operands.
            if denominator.output_type == dtypes.TIMEDELTA_DTYPE:
                return expr.op.as_expr(
                    cast(numerator, dtypes.INT_DTYPE),
                    cast(denominator, dtypes.INT_DTYPE),
                )
            # timedelta // numeric: divide the int-cast dividend, then cast the
            # result to int and back to timedelta.
            if dtypes.is_numeric(denominator.output_type):
                # this is pretty fragile as zero will break it, and the result
                # must fit back into int
                raw_quotient = expr.op.as_expr(
                    cast(numerator, dtypes.INT_DTYPE), denominator
                )
                as_int = cast(raw_quotient, dtypes.INT_DTYPE)
                return cast(as_int, dtypes.TIMEDELTA_DTYPE)

        # Bool operands are floor-divided as ints.
        if numerator.output_type == dtypes.BOOL_DTYPE:
            numerator = cast(numerator, dtypes.INT_DTYPE)
        if denominator.output_type == dtypes.BOOL_DTYPE:
            denominator = cast(denominator, dtypes.INT_DTYPE)

        # Float results are rebuilt with the original op and no guard.
        if expr.output_type == dtypes.FLOAT_DTYPE:
            return expr.op.as_expr(numerator, denominator)

        # Need to guard against a zero divisor. The zero-divisor value is
        # built by multiplying the dividend by 0 so that nulls propagate.
        zero_case = ops.mul_op.as_expr(numerator, expression.const(0))
        divisor_is_zero = ops.eq_op.as_expr(denominator, expression.const(0))
        quotient = numeric_ops.floordiv_op.as_expr(numerator, denominator)
        return ops.where_op.as_expr(zero_case, divisor_is_zero, quotient)
238+
239+
class LowerModRule(op_lowering.OpLoweringRule):
    """Lowering rule for ``numeric_ops.ModOp`` expressions.

    Handles timedelta % timedelta as a special case, casts bool operands
    to int, and guards int-typed results against a zero divisor.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to."""
        return numeric_ops.ModOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return a rewritten expression for the given modulo expression.

        Args:
            expr: An op expression whose op is a ``ModOp``; its first two
                children are the operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.ModOp)

        def cast(operand: expression.Expression, to_type) -> expression.Expression:
            return ops.AsTypeOp(to_type=to_type).as_expr(operand)

        def guard_zero_divisor(
            remainder: expression.Expression, divisor: expression.Expression
        ) -> expression.Expression:
            # Keep the remainder where the divisor is non-zero; elsewhere use
            # divisor * 0.
            return ops.where_op.as_expr(
                remainder,
                ops.ne_op.as_expr(divisor, expression.const(0)),
                ops.mul_op.as_expr(divisor, expression.const(0)),
            )

        lhs = expr.children[0]
        rhs = expr.children[1]

        # timedelta % timedelta: compute on int-cast operands, guard the zero
        # divisor, then cast the result back to timedelta.
        if (
            lhs.output_type == dtypes.TIMEDELTA_DTYPE
            and rhs.output_type == dtypes.TIMEDELTA_DTYPE
        ):
            lhs_as_int = cast(lhs, dtypes.INT_DTYPE)
            rhs_as_int = cast(rhs, dtypes.INT_DTYPE)
            int_remainder = expr.op.as_expr(lhs_as_int, rhs_as_int)
            guarded = guard_zero_divisor(int_remainder, rhs_as_int)
            return cast(guarded, dtypes.TIMEDELTA_DTYPE)

        # Bool operands take part in the modulo as ints.
        if lhs.output_type == dtypes.BOOL_DTYPE:
            lhs = cast(lhs, dtypes.INT_DTYPE)
        if rhs.output_type == dtypes.BOOL_DTYPE:
            rhs = cast(rhs, dtypes.INT_DTYPE)

        remainder = expr.op.as_expr(lhs, rhs)

        # Only results typed as int (per the incoming expression) get the
        # zero-divisor guard.
        if expr.output_type == dtypes.INT_DTYPE:
            return guard_zero_divisor(remainder, rhs)
        return remainder
57278
58279
59- def _coerce_comparables (expr1 : expression .Expression , expr2 : expression .Expression ):
280+ def _coerce_comparables (
281+ expr1 : expression .Expression ,
282+ expr2 : expression .Expression ,
283+ * ,
284+ bools_only : bool = False
285+ ):
286+ if bools_only :
287+ if (
288+ expr1 .output_type != dtypes .BOOL_DTYPE
289+ and expr2 .output_type != dtypes .BOOL_DTYPE
290+ ):
291+ return expr1 , expr2
60292
61293 target_type = dtypes .coerce_to_common (expr1 .output_type , expr2 .output_type )
62294 if expr1 .output_type != target_type :
@@ -90,7 +322,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
90322
# Lowering rule instances grouped under the Polars name: the comparison rules
# first, then one rule per arithmetic op (add, sub, mul, div, floordiv, mod).
# How this tuple is consumed is not visible in this part of the file.
POLARS_LOWERING_RULES = (
    *LOWER_COMPARISONS,
    LowerAddRule(),
    LowerSubRule(),
    LowerMulRule(),
    LowerDivRule(),
    LowerFloorDivRule(),
    LowerModRule(),
)
95332
96333
0 commit comments