Skip to content

Commit f1daafc

Browse files
committed
chore: Migrate pow_op operator to SQLGlot
1 parent 6e73d77 commit f1daafc

File tree

3 files changed

+480
-0
lines changed

3 files changed

+480
-0
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,141 @@ def _(expr: TypedExpr) -> sge.Expression:
210210
return expr.expr
211211

212212

213+
@register_binary_op(ops.pow_op)
214+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
215+
left_expr = _coerce_bool_to_int(left)
216+
right_expr = _coerce_bool_to_int(right)
217+
if left.dtype == dtypes.INT_DTYPE and right.dtype == dtypes.INT_DTYPE:
218+
return _int_pow_op(left_expr, right_expr)
219+
else:
220+
return _float_pow_op(left_expr, right_expr)
221+
222+
223+
def _int_pow_op(
224+
left_expr: sge.Expression, right_expr: sge.Expression
225+
) -> sge.Expression:
226+
import math
227+
228+
overflow_value = math.log(2**63 - 1)
229+
overflow_cond = sge.and_(
230+
sge.NEQ(this=left_expr, expression=sge.convert(0)),
231+
sge.GT(
232+
this=sge.Mul(
233+
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
234+
),
235+
expression=sge.convert(overflow_value),
236+
),
237+
)
238+
239+
return sge.Case(
240+
ifs=[
241+
sge.If(
242+
this=overflow_cond,
243+
true=sge.Null(),
244+
)
245+
],
246+
default=sge.Cast(
247+
this=sge.Pow(
248+
this=sge.Cast(
249+
this=left_expr, to=sge.DataType(this=sge.DataType.Type.DECIMAL)
250+
),
251+
expression=right_expr,
252+
),
253+
to="INT64",
254+
),
255+
)
256+
257+
258+
def _float_pow_op(
259+
left_expr: sge.Expression, right_expr: sge.Expression
260+
) -> sge.Expression:
261+
# Most conditions here seek to prevent calling BQ POW with inputs that would generate errors.
262+
# See: https://cloud.google.com/bigquery/docs/reference/standard-sql/mathematical_functions#pow
263+
overflow_cond = sge.and_(
264+
sge.NEQ(this=left_expr, expression=constants._ZERO),
265+
sge.GT(
266+
this=sge.Mul(
267+
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
268+
),
269+
expression=constants._FLOAT64_EXP_BOUND,
270+
),
271+
)
272+
273+
# Float64 lose integer precision beyond 2**53, beyond this insufficient precision to get parity
274+
exp_too_big = sge.GT(this=sge.Abs(this=right_expr), expression=sge.convert(2**53))
275+
# Treat very large exponents as +=INF
276+
norm_exp = sge.Case(
277+
ifs=[
278+
sge.If(
279+
this=exp_too_big,
280+
true=sge.Mul(this=constants._INF, expression=sge.Sign(this=right_expr)),
281+
)
282+
],
283+
default=right_expr,
284+
)
285+
286+
pow_result = sge.Pow(this=left_expr, expression=norm_exp)
287+
288+
# This cast is dangerous, need to only excuted where y_val has been bounds-checked
289+
# Ibis needs try_cast binding to bq safe_cast
290+
exponent_is_whole = sge.EQ(
291+
this=sge.Cast(this=right_expr, to="INT64"), expression=right_expr
292+
)
293+
odd_exponent = sge.and_(
294+
sge.LT(this=left_expr, expression=constants._ZERO),
295+
sge.EQ(
296+
this=sge.Mod(
297+
this=sge.Cast(this=right_expr, to="INT64"), expression=sge.convert(2)
298+
),
299+
expression=sge.convert(1),
300+
),
301+
)
302+
infinite_base = sge.EQ(this=sge.Abs(this=left_expr), expression=constants._INF)
303+
304+
return sge.Case(
305+
ifs=[
306+
# Might be able to do something more clever with x_val==0 case
307+
sge.If(
308+
this=sge.EQ(this=right_expr, expression=constants._ZERO),
309+
true=sge.convert(1),
310+
),
311+
sge.If(
312+
this=sge.EQ(this=left_expr, expression=sge.convert(1)),
313+
true=sge.convert(1),
314+
), # Need to ignore exponent, even if it is NA
315+
sge.If(
316+
this=sge.and_(
317+
sge.EQ(this=left_expr, expression=constants._ZERO),
318+
sge.LT(this=right_expr, expression=constants._ZERO),
319+
),
320+
true=constants._INF,
321+
), # This case would error POW function in BQ
322+
sge.If(this=infinite_base, true=pow_result),
323+
sge.If(
324+
this=exp_too_big, true=pow_result
325+
), # Bigquery can actually handle the +-inf cases gracefully
326+
sge.If(
327+
this=sge.and_(
328+
sge.LT(this=left_expr, expression=constants._ZERO),
329+
sge.Not(this=exponent_is_whole),
330+
),
331+
true=constants._NAN,
332+
),
333+
sge.If(
334+
this=overflow_cond,
335+
true=sge.Mul(
336+
this=constants._INF,
337+
expression=sge.Case(
338+
ifs=[sge.If(this=odd_exponent, true=sge.convert(-1))],
339+
default=sge.convert(1),
340+
),
341+
),
342+
), # finite overflows would cause bq to error
343+
],
344+
default=pow_result,
345+
)
346+
347+
213348
@register_unary_op(ops.sqrt_op)
214349
def _(expr: TypedExpr) -> sge.Expression:
215350
return sge.Case(

0 commit comments

Comments
 (0)