Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import sqlglot.expressions as sge

_ZERO = sge.Cast(this=sge.convert(0), to="INT64")
Expand All @@ -23,3 +25,13 @@
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
_FLOAT64_EXP_BOUND = sge.convert(709.78)

# The natural logarithm of the maximum value for a signed 64-bit integer.
# This is used to check for potential overflows in power operations involving integers
# by checking if `exponent * log(base)` exceeds this value.
_INT64_LOG_BOUND = math.log(2**63 - 1)

# Represents the largest integer N where all integers from -N to N can be
# represented exactly as a float64. Float64 types have a 53-bit significand precision,
# so integers beyond this value may lose precision.
_FLOAT64_MAX_INT_PRECISION = 2**53
135 changes: 135 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/numeric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,141 @@ def _(expr: TypedExpr) -> sge.Expression:
return expr.expr


@register_binary_op(ops.pow_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
if left.dtype == dtypes.INT_DTYPE and right.dtype == dtypes.INT_DTYPE:
return _int_pow_op(left_expr, right_expr)
else:
return _float_pow_op(left_expr, right_expr)


def _int_pow_op(
left_expr: sge.Expression, right_expr: sge.Expression
) -> sge.Expression:
overflow_cond = sge.and_(
sge.NEQ(this=left_expr, expression=sge.convert(0)),
sge.GT(
this=sge.Mul(
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
),
expression=sge.convert(constants._INT64_LOG_BOUND),
),
)

return sge.Case(
ifs=[
sge.If(
this=overflow_cond,
true=sge.Null(),
)
],
default=sge.Cast(
this=sge.Pow(
this=sge.Cast(
this=left_expr, to=sge.DataType(this=sge.DataType.Type.DECIMAL)
),
expression=right_expr,
),
to="INT64",
),
)


def _float_pow_op(
left_expr: sge.Expression, right_expr: sge.Expression
) -> sge.Expression:
# Most conditions here seek to prevent calling BQ POW with inputs that would generate errors.
# See: https://cloud.google.com/bigquery/docs/reference/standard-sql/mathematical_functions#pow
overflow_cond = sge.and_(
sge.NEQ(this=left_expr, expression=constants._ZERO),
sge.GT(
this=sge.Mul(
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
),
expression=constants._FLOAT64_EXP_BOUND,
),
)

# Float64 lose integer precision beyond 2**53, beyond this insufficient precision to get parity
exp_too_big = sge.GT(
this=sge.Abs(this=right_expr),
expression=sge.convert(constants._FLOAT64_MAX_INT_PRECISION),
)
# Treat very large exponents as +=INF
norm_exp = sge.Case(
ifs=[
sge.If(
this=exp_too_big,
true=sge.Mul(this=constants._INF, expression=sge.Sign(this=right_expr)),
)
],
default=right_expr,
)

pow_result = sge.Pow(this=left_expr, expression=norm_exp)

# This cast is dangerous, need to only excuted where y_val has been bounds-checked
# Ibis needs try_cast binding to bq safe_cast
exponent_is_whole = sge.EQ(
this=sge.Cast(this=right_expr, to="INT64"), expression=right_expr
)
odd_exponent = sge.and_(
sge.LT(this=left_expr, expression=constants._ZERO),
sge.EQ(
this=sge.Mod(
this=sge.Cast(this=right_expr, to="INT64"), expression=sge.convert(2)
),
expression=sge.convert(1),
),
)
infinite_base = sge.EQ(this=sge.Abs(this=left_expr), expression=constants._INF)

return sge.Case(
ifs=[
# Might be able to do something more clever with x_val==0 case
sge.If(
this=sge.EQ(this=right_expr, expression=constants._ZERO),
true=sge.convert(1),
),
sge.If(
this=sge.EQ(this=left_expr, expression=sge.convert(1)),
true=sge.convert(1),
), # Need to ignore exponent, even if it is NA
sge.If(
this=sge.and_(
sge.EQ(this=left_expr, expression=constants._ZERO),
sge.LT(this=right_expr, expression=constants._ZERO),
),
true=constants._INF,
), # This case would error POW function in BQ
sge.If(this=infinite_base, true=pow_result),
sge.If(
this=exp_too_big, true=pow_result
), # Bigquery can actually handle the +-inf cases gracefully
sge.If(
this=sge.and_(
sge.LT(this=left_expr, expression=constants._ZERO),
sge.Not(this=exponent_is_whole),
),
true=constants._NAN,
),
sge.If(
this=overflow_cond,
true=sge.Mul(
this=constants._INF,
expression=sge.Case(
ifs=[sge.If(this=odd_exponent, true=sge.convert(-1))],
default=sge.convert(1),
),
),
), # finite overflows would cause bq to error
],
default=pow_result,
)


@register_unary_op(ops.sqrt_op)
def _(expr: TypedExpr) -> sge.Expression:
return sge.Case(
Expand Down
Loading