Skip to content

Commit 74c3051

Browse files
feat: Allow local arithmetic execution in hybrid engine
1 parent 07222bf commit 74c3051

File tree

8 files changed

+519
-48
lines changed

8 files changed

+519
-48
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import bigframes.operations.comparison_ops as comp_ops
3636
import bigframes.operations.generic_ops as gen_ops
3737
import bigframes.operations.numeric_ops as num_ops
38+
import bigframes.operations.string_ops as string_ops
3839

3940
polars_installed = True
4041
if TYPE_CHECKING:
@@ -146,6 +147,14 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
146147
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
147148
return input.abs()
148149

150+
@compile_op.register(num_ops.FloorOp)
151+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
152+
return input.floor()
153+
154+
@compile_op.register(num_ops.CeilOp)
155+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
156+
return input.ceil()
157+
149158
@compile_op.register(num_ops.PosOp)
150159
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
151160
return input.__pos__()
@@ -182,10 +191,6 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
182191
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
183192
return l_input // r_input
184193

185-
@compile_op.register(num_ops.FloorDivOp)
186-
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
187-
return l_input // r_input
188-
189194
@compile_op.register(num_ops.ModOp)
190195
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
191196
return l_input % r_input
@@ -270,6 +275,11 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
270275
# eg. We want "True" instead of "true" for bool to strin
271276
return input.cast(_DTYPE_MAPPING[op.to_type], strict=not op.safe)
272277

278+
@compile_op.register(string_ops.StrConcatOp)
279+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
280+
assert isinstance(op, string_ops.StrConcatOp)
281+
return pl.concat_str(l_input, r_input)
282+
273283
@dataclasses.dataclass(frozen=True)
274284
class PolarsAggregateCompiler:
275285
scalar_compiler = PolarsExpressionCompiler()

bigframes/core/compile/polars/lowering.py

Lines changed: 247 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,258 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
3737
return expr.op.as_expr(larg, rarg)
3838

3939

40+
class LowerAddRule(op_lowering.OpLoweringRule):
41+
@property
42+
def op(self) -> type[ops.ScalarOp]:
43+
return numeric_ops.AddOp
44+
45+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
46+
assert isinstance(expr.op, numeric_ops.AddOp)
47+
larg, rarg = expr.children[0], expr.children[1]
48+
49+
if (
50+
larg.output_type == dtypes.BOOL_DTYPE
51+
and rarg.output_type == dtypes.BOOL_DTYPE
52+
):
53+
int_result = expr.op.as_expr(
54+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg),
55+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg),
56+
)
57+
return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result)
58+
59+
if dtypes.is_string_like(larg.output_type) and dtypes.is_string_like(
60+
rarg.output_type
61+
):
62+
return ops.strconcat_op.as_expr(larg, rarg)
63+
64+
if larg.output_type == dtypes.BOOL_DTYPE:
65+
larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
66+
if rarg.output_type == dtypes.BOOL_DTYPE:
67+
rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)
68+
69+
if (
70+
larg.output_type == dtypes.DATE_DTYPE
71+
and rarg.output_type == dtypes.TIMEDELTA_DTYPE
72+
):
73+
larg = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(larg)
74+
75+
if (
76+
larg.output_type == dtypes.TIMEDELTA_DTYPE
77+
and rarg.output_type == dtypes.DATE_DTYPE
78+
):
79+
rarg = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(rarg)
80+
81+
return expr.op.as_expr(larg, rarg)
82+
83+
84+
class LowerSubRule(op_lowering.OpLoweringRule):
85+
@property
86+
def op(self) -> type[ops.ScalarOp]:
87+
return numeric_ops.SubOp
88+
89+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
90+
assert isinstance(expr.op, numeric_ops.SubOp)
91+
larg, rarg = expr.children[0], expr.children[1]
92+
93+
if (
94+
larg.output_type == dtypes.BOOL_DTYPE
95+
and rarg.output_type == dtypes.BOOL_DTYPE
96+
):
97+
int_result = expr.op.as_expr(
98+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg),
99+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg),
100+
)
101+
return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result)
102+
103+
if larg.output_type == dtypes.BOOL_DTYPE:
104+
larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
105+
if rarg.output_type == dtypes.BOOL_DTYPE:
106+
rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)
107+
108+
if (
109+
larg.output_type == dtypes.DATE_DTYPE
110+
and rarg.output_type == dtypes.TIMEDELTA_DTYPE
111+
):
112+
larg = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(larg)
113+
114+
return expr.op.as_expr(larg, rarg)
115+
116+
117+
@dataclasses.dataclass
118+
class LowerMulRule(op_lowering.OpLoweringRule):
119+
@property
120+
def op(self) -> type[ops.ScalarOp]:
121+
return numeric_ops.MulOp
122+
123+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
124+
assert isinstance(expr.op, numeric_ops.MulOp)
125+
larg, rarg = expr.children[0], expr.children[1]
126+
127+
if (
128+
larg.output_type == dtypes.BOOL_DTYPE
129+
and rarg.output_type == dtypes.BOOL_DTYPE
130+
):
131+
int_result = expr.op.as_expr(
132+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg),
133+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg),
134+
)
135+
return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result)
136+
137+
if (
138+
larg.output_type == dtypes.BOOL_DTYPE
139+
and rarg.output_type != dtypes.BOOL_DTYPE
140+
):
141+
larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
142+
if (
143+
rarg.output_type == dtypes.BOOL_DTYPE
144+
and larg.output_type != dtypes.BOOL_DTYPE
145+
):
146+
rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)
147+
148+
return expr.op.as_expr(larg, rarg)
149+
150+
151+
class LowerDivRule(op_lowering.OpLoweringRule):
152+
@property
153+
def op(self) -> type[ops.ScalarOp]:
154+
return numeric_ops.DivOp
155+
156+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
157+
assert isinstance(expr.op, numeric_ops.DivOp)
158+
159+
dividend = expr.children[0]
160+
divisor = expr.children[1]
161+
162+
if (
163+
dividend.output_type == dtypes.TIMEDELTA_DTYPE
164+
and divisor.output_type == dtypes.INT_DTYPE
165+
):
166+
int_result = expr.op.as_expr(
167+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), divisor
168+
)
169+
return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(int_result)
170+
171+
if (
172+
dividend.output_type == dtypes.BOOL_DTYPE
173+
and divisor.output_type == dtypes.BOOL_DTYPE
174+
):
175+
int_result = expr.op.as_expr(
176+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend),
177+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor),
178+
)
179+
return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_result)
180+
181+
# polars divide doesn't like bools, convert to int always
182+
# convert numerics to float always
183+
if dividend.output_type == dtypes.BOOL_DTYPE:
184+
dividend = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend)
185+
elif dividend.output_type in (dtypes.BIGNUMERIC_DTYPE, dtypes.NUMERIC_DTYPE):
186+
dividend = ops.AsTypeOp(to_type=dtypes.FLOAT_DTYPE).as_expr(dividend)
187+
if divisor.output_type == dtypes.BOOL_DTYPE:
188+
divisor = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor)
189+
190+
return numeric_ops.div_op.as_expr(dividend, divisor)
191+
192+
40193
class LowerFloorDivRule(op_lowering.OpLoweringRule):
41194
@property
42195
def op(self) -> type[ops.ScalarOp]:
43196
return numeric_ops.FloorDivOp
44197

45198
def lower(self, expr: expression.OpExpression) -> expression.Expression:
199+
assert isinstance(expr.op, numeric_ops.FloorDivOp)
200+
46201
dividend = expr.children[0]
47202
divisor = expr.children[1]
48-
using_floats = (dividend.output_type == dtypes.FLOAT_DTYPE) or (
49-
divisor.output_type == dtypes.FLOAT_DTYPE
50-
)
51-
inf_or_zero = (
52-
expression.const(float("INF")) if using_floats else expression.const(0)
53-
)
54-
zero_result = ops.mul_op.as_expr(inf_or_zero, dividend)
55-
divisor_is_zero = ops.eq_op.as_expr(divisor, expression.const(0))
56-
return ops.where_op.as_expr(zero_result, divisor_is_zero, expr)
203+
204+
if (
205+
dividend.output_type == dtypes.TIMEDELTA_DTYPE
206+
and divisor.output_type == dtypes.TIMEDELTA_DTYPE
207+
):
208+
int_result = expr.op.as_expr(
209+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend),
210+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor),
211+
)
212+
return int_result
213+
if dividend.output_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(
214+
divisor.output_type
215+
):
216+
# this is pretty fragile as zero will break it, and must fit back into int
217+
numeric_result = expr.op.as_expr(
218+
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), divisor
219+
)
220+
int_result = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(numeric_result)
221+
return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(int_result)
222+
223+
if dividend.output_type == dtypes.BOOL_DTYPE:
224+
dividend = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend)
225+
if divisor.output_type == dtypes.BOOL_DTYPE:
226+
divisor = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(divisor)
227+
228+
if expr.output_type != dtypes.FLOAT_DTYPE:
229+
# need to guard against zero divisor
230+
# multiply dividend in this case to propagate nulls
231+
return ops.where_op.as_expr(
232+
ops.mul_op.as_expr(dividend, expression.const(0)),
233+
ops.eq_op.as_expr(divisor, expression.const(0)),
234+
numeric_ops.floordiv_op.as_expr(dividend, divisor),
235+
)
236+
else:
237+
return expr.op.as_expr(dividend, divisor)
238+
239+
240+
class LowerModRule(op_lowering.OpLoweringRule):
241+
@property
242+
def op(self) -> type[ops.ScalarOp]:
243+
return numeric_ops.ModOp
244+
245+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
246+
og_expr = expr
247+
assert isinstance(expr.op, numeric_ops.ModOp)
248+
larg, rarg = expr.children[0], expr.children[1]
249+
250+
if (
251+
larg.output_type == dtypes.TIMEDELTA_DTYPE
252+
and rarg.output_type == dtypes.TIMEDELTA_DTYPE
253+
):
254+
larg_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
255+
rarg_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)
256+
int_result = expr.op.as_expr(larg_int, rarg_int)
257+
w_zero_handling = ops.where_op.as_expr(
258+
int_result,
259+
ops.ne_op.as_expr(rarg_int, expression.const(0)),
260+
ops.mul_op.as_expr(rarg_int, expression.const(0)),
261+
)
262+
return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(w_zero_handling)
263+
264+
if larg.output_type == dtypes.BOOL_DTYPE:
265+
larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
266+
if rarg.output_type == dtypes.BOOL_DTYPE:
267+
rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)
268+
269+
wo_bools = expr.op.as_expr(larg, rarg)
270+
271+
if og_expr.output_type == dtypes.INT_DTYPE:
272+
return ops.where_op.as_expr(
273+
wo_bools,
274+
ops.ne_op.as_expr(rarg, expression.const(0)),
275+
ops.mul_op.as_expr(rarg, expression.const(0)),
276+
)
277+
return wo_bools
57278

58279

59-
def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression):
280+
def _coerce_comparables(
281+
expr1: expression.Expression,
282+
expr2: expression.Expression,
283+
*,
284+
bools_only: bool = False
285+
):
286+
if bools_only:
287+
if (
288+
expr1.output_type != dtypes.BOOL_DTYPE
289+
and expr2.output_type != dtypes.BOOL_DTYPE
290+
):
291+
return expr1, expr2
60292

61293
target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type)
62294
if expr1.output_type != target_type:
@@ -90,7 +322,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
90322

91323
POLARS_LOWERING_RULES = (
92324
*LOWER_COMPARISONS,
325+
LowerAddRule(),
326+
LowerSubRule(),
327+
LowerMulRule(),
328+
LowerDivRule(),
93329
LowerFloorDivRule(),
330+
LowerModRule(),
94331
)
95332

96333

0 commit comments

Comments
 (0)