From f1daafc2337302ebd8d2b2011ab72731effc9180 Mon Sep 17 00:00:00 2001
From: jialuo
Date: Tue, 25 Nov 2025 00:03:17 +0000
Subject: [PATCH 1/2] chore: Migrate pow_op operator to SQLGlot

---
 .../sqlglot/expressions/numeric_ops.py        | 135 +++++++
 .../test_numeric_ops/test_pow/out.sql         | 329 ++++++++++++++++++
 .../sqlglot/expressions/test_numeric_ops.py   |  16 +
 3 files changed, 480 insertions(+)
 create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql

diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py
index c022356fd3..55fe6424f9 100644
--- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py
+++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py
@@ -210,6 +210,141 @@ def _(expr: TypedExpr) -> sge.Expression:
     return expr.expr
 
 
+@register_binary_op(ops.pow_op)
+def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
+    left_expr = _coerce_bool_to_int(left)
+    right_expr = _coerce_bool_to_int(right)
+    if left.dtype == dtypes.INT_DTYPE and right.dtype == dtypes.INT_DTYPE:
+        return _int_pow_op(left_expr, right_expr)
+    else:
+        return _float_pow_op(left_expr, right_expr)
+
+
+def _int_pow_op(
+    left_expr: sge.Expression, right_expr: sge.Expression
+) -> sge.Expression:
+    import math
+
+    overflow_value = math.log(2**63 - 1)
+    overflow_cond = sge.and_(
+        sge.NEQ(this=left_expr, expression=sge.convert(0)),
+        sge.GT(
+            this=sge.Mul(
+                this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
+            ),
+            expression=sge.convert(overflow_value),
+        ),
+    )
+
+    return sge.Case(
+        ifs=[
+            sge.If(
+                this=overflow_cond,
+                true=sge.Null(),
+            )
+        ],
+        default=sge.Cast(
+            this=sge.Pow(
+                this=sge.Cast(
+                    this=left_expr, to=sge.DataType(this=sge.DataType.Type.DECIMAL)
+                ),
+                expression=right_expr,
+            ),
+            to="INT64",
+        ),
+    )
+
+
+def _float_pow_op(
+    left_expr: sge.Expression, right_expr: sge.Expression
+) -> sge.Expression:
+    # Most conditions here seek to prevent calling BQ POW with inputs that would generate errors.
+    # See: https://cloud.google.com/bigquery/docs/reference/standard-sql/mathematical_functions#pow
+    overflow_cond = sge.and_(
+        sge.NEQ(this=left_expr, expression=constants._ZERO),
+        sge.GT(
+            this=sge.Mul(
+                this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
+            ),
+            expression=constants._FLOAT64_EXP_BOUND,
+        ),
+    )
+
+    # Float64 lose integer precision beyond 2**53, beyond this insufficient precision to get parity
+    exp_too_big = sge.GT(this=sge.Abs(this=right_expr), expression=sge.convert(2**53))
+    # Treat very large exponents as +=INF
+    norm_exp = sge.Case(
+        ifs=[
+            sge.If(
+                this=exp_too_big,
+                true=sge.Mul(this=constants._INF, expression=sge.Sign(this=right_expr)),
+            )
+        ],
+        default=right_expr,
+    )
+
+    pow_result = sge.Pow(this=left_expr, expression=norm_exp)
+
+    # This cast is dangerous, need to only excuted where y_val has been bounds-checked
+    # Ibis needs try_cast binding to bq safe_cast
+    exponent_is_whole = sge.EQ(
+        this=sge.Cast(this=right_expr, to="INT64"), expression=right_expr
+    )
+    odd_exponent = sge.and_(
+        sge.LT(this=left_expr, expression=constants._ZERO),
+        sge.EQ(
+            this=sge.Mod(
+                this=sge.Cast(this=right_expr, to="INT64"), expression=sge.convert(2)
+            ),
+            expression=sge.convert(1),
+        ),
+    )
+    infinite_base = sge.EQ(this=sge.Abs(this=left_expr), expression=constants._INF)
+
+    return sge.Case(
+        ifs=[
+            # Might be able to do something more clever with x_val==0 case
+            sge.If(
+                this=sge.EQ(this=right_expr, expression=constants._ZERO),
+                true=sge.convert(1),
+            ),
+            sge.If(
+                this=sge.EQ(this=left_expr, expression=sge.convert(1)),
+                true=sge.convert(1),
+            ),  # Need to ignore exponent, even if it is NA
+            sge.If(
+                this=sge.and_(
+                    sge.EQ(this=left_expr, expression=constants._ZERO),
+                    sge.LT(this=right_expr, expression=constants._ZERO),
+                ),
+                true=constants._INF,
+            ),  # This case would error POW function in BQ
+            sge.If(this=infinite_base, true=pow_result),
+            sge.If(
+                this=exp_too_big, true=pow_result
+            ),  # Bigquery can actually handle the +-inf cases gracefully
+            sge.If(
+                this=sge.and_(
+                    sge.LT(this=left_expr, expression=constants._ZERO),
+                    sge.Not(this=exponent_is_whole),
+                ),
+                true=constants._NAN,
+            ),
+            sge.If(
+                this=overflow_cond,
+                true=sge.Mul(
+                    this=constants._INF,
+                    expression=sge.Case(
+                        ifs=[sge.If(this=odd_exponent, true=sge.convert(-1))],
+                        default=sge.convert(1),
+                    ),
+                ),
+            ),  # finite overflows would cause bq to error
+        ],
+        default=pow_result,
+    )
+
+
 @register_unary_op(ops.sqrt_op)
 def _(expr: TypedExpr) -> sge.Expression:
     return sge.Case(
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql
new file mode 100644
index 0000000000..05fbaa12c9
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql
@@ -0,0 +1,329 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `float64_col`,
+    `int64_col`,
+    `rowindex`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    `rowindex` AS `bfcol_6`,
+    `int64_col` AS `bfcol_7`,
+    `float64_col` AS `bfcol_8`,
+    CASE
+      WHEN `int64_col` <> 0 AND `int64_col` * LN(ABS(`int64_col`)) > 43.66827237527655
+      THEN NULL
+      ELSE CAST(POWER(CAST(`int64_col` AS NUMERIC), `int64_col`) AS INT64)
+    END AS `bfcol_9`
+  FROM `bfcte_0`
+), `bfcte_2` AS (
+  SELECT
+    *,
+    `bfcol_6` AS `bfcol_14`,
+    `bfcol_7` AS `bfcol_15`,
+    `bfcol_8` AS `bfcol_16`,
+    `bfcol_9` AS `bfcol_17`,
+    CASE
+      WHEN `bfcol_8` = CAST(0 AS INT64)
+      THEN 1
+      WHEN `bfcol_7` = 1
+      THEN 1
+      WHEN `bfcol_7` = CAST(0 AS INT64) AND `bfcol_8` < CAST(0 AS INT64)
+      THEN CAST('Infinity' AS FLOAT64)
+      WHEN ABS(`bfcol_7`) = CAST('Infinity' AS FLOAT64)
+      THEN POWER(
+        `bfcol_7`,
+        CASE
+          WHEN ABS(`bfcol_8`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_8`)
+          ELSE `bfcol_8`
+        END
+      )
+      WHEN ABS(`bfcol_8`) > 9007199254740992
+      THEN POWER(
+        `bfcol_7`,
+        CASE
+          WHEN ABS(`bfcol_8`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_8`)
+          ELSE `bfcol_8`
+        END
+      )
+      WHEN `bfcol_7` < CAST(0 AS INT64) AND NOT CAST(`bfcol_8` AS INT64) = `bfcol_8`
+      THEN CAST('NaN' AS FLOAT64)
+      WHEN `bfcol_7` <> CAST(0 AS INT64) AND `bfcol_8` * LN(ABS(`bfcol_7`)) > 709.78
+      THEN CAST('Infinity' AS FLOAT64) * CASE
+        WHEN `bfcol_7` < CAST(0 AS INT64) AND MOD(CAST(`bfcol_8` AS INT64), 2) = 1
+        THEN -1
+        ELSE 1
+      END
+      ELSE POWER(
+        `bfcol_7`,
+        CASE
+          WHEN ABS(`bfcol_8`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_8`)
+          ELSE `bfcol_8`
+        END
+      )
+    END AS `bfcol_18`
+  FROM `bfcte_1`
+), `bfcte_3` AS (
+  SELECT
+    *,
+    `bfcol_14` AS `bfcol_24`,
+    `bfcol_15` AS `bfcol_25`,
+    `bfcol_16` AS `bfcol_26`,
+    `bfcol_17` AS `bfcol_27`,
+    `bfcol_18` AS `bfcol_28`,
+    CASE
+      WHEN `bfcol_15` = CAST(0 AS INT64)
+      THEN 1
+      WHEN `bfcol_16` = 1
+      THEN 1
+      WHEN `bfcol_16` = CAST(0 AS INT64) AND `bfcol_15` < CAST(0 AS INT64)
+      THEN CAST('Infinity' AS FLOAT64)
+      WHEN ABS(`bfcol_16`) = CAST('Infinity' AS FLOAT64)
+      THEN POWER(
+        `bfcol_16`,
+        CASE
+          WHEN ABS(`bfcol_15`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_15`)
+          ELSE `bfcol_15`
+        END
+      )
+      WHEN ABS(`bfcol_15`) > 9007199254740992
+      THEN POWER(
+        `bfcol_16`,
+        CASE
+          WHEN ABS(`bfcol_15`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_15`)
+          ELSE `bfcol_15`
+        END
+      )
+      WHEN `bfcol_16` < CAST(0 AS INT64) AND NOT CAST(`bfcol_15` AS INT64) = `bfcol_15`
+      THEN CAST('NaN' AS FLOAT64)
+      WHEN `bfcol_16` <> CAST(0 AS INT64) AND `bfcol_15` * LN(ABS(`bfcol_16`)) > 709.78
+      THEN CAST('Infinity' AS FLOAT64) * CASE
+        WHEN `bfcol_16` < CAST(0 AS INT64) AND MOD(CAST(`bfcol_15` AS INT64), 2) = 1
+        THEN -1
+        ELSE 1
+      END
+      ELSE POWER(
+        `bfcol_16`,
+        CASE
+          WHEN ABS(`bfcol_15`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_15`)
+          ELSE `bfcol_15`
+        END
+      )
+    END AS `bfcol_29`
+  FROM `bfcte_2`
+), `bfcte_4` AS (
+  SELECT
+    *,
+    `bfcol_24` AS `bfcol_36`,
+    `bfcol_25` AS `bfcol_37`,
+    `bfcol_26` AS `bfcol_38`,
+    `bfcol_27` AS `bfcol_39`,
+    `bfcol_28` AS `bfcol_40`,
+    `bfcol_29` AS `bfcol_41`,
+    CASE
+      WHEN `bfcol_26` = CAST(0 AS INT64)
+      THEN 1
+      WHEN `bfcol_26` = 1
+      THEN 1
+      WHEN `bfcol_26` = CAST(0 AS INT64) AND `bfcol_26` < CAST(0 AS INT64)
+      THEN CAST('Infinity' AS FLOAT64)
+      WHEN ABS(`bfcol_26`) = CAST('Infinity' AS FLOAT64)
+      THEN POWER(
+        `bfcol_26`,
+        CASE
+          WHEN ABS(`bfcol_26`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_26`)
+          ELSE `bfcol_26`
+        END
+      )
+      WHEN ABS(`bfcol_26`) > 9007199254740992
+      THEN POWER(
+        `bfcol_26`,
+        CASE
+          WHEN ABS(`bfcol_26`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_26`)
+          ELSE `bfcol_26`
+        END
+      )
+      WHEN `bfcol_26` < CAST(0 AS INT64) AND NOT CAST(`bfcol_26` AS INT64) = `bfcol_26`
+      THEN CAST('NaN' AS FLOAT64)
+      WHEN `bfcol_26` <> CAST(0 AS INT64) AND `bfcol_26` * LN(ABS(`bfcol_26`)) > 709.78
+      THEN CAST('Infinity' AS FLOAT64) * CASE
+        WHEN `bfcol_26` < CAST(0 AS INT64) AND MOD(CAST(`bfcol_26` AS INT64), 2) = 1
+        THEN -1
+        ELSE 1
+      END
+      ELSE POWER(
+        `bfcol_26`,
+        CASE
+          WHEN ABS(`bfcol_26`) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(`bfcol_26`)
+          ELSE `bfcol_26`
+        END
+      )
+    END AS `bfcol_42`
+  FROM `bfcte_3`
+), `bfcte_5` AS (
+  SELECT
+    *,
+    `bfcol_36` AS `bfcol_50`,
+    `bfcol_37` AS `bfcol_51`,
+    `bfcol_38` AS `bfcol_52`,
+    `bfcol_39` AS `bfcol_53`,
+    `bfcol_40` AS `bfcol_54`,
+    `bfcol_41` AS `bfcol_55`,
+    `bfcol_42` AS `bfcol_56`,
+    CASE
+      WHEN `bfcol_37` <> 0 AND 0 * LN(ABS(`bfcol_37`)) > 43.66827237527655
+      THEN NULL
+      ELSE CAST(POWER(CAST(`bfcol_37` AS NUMERIC), 0) AS INT64)
+    END AS `bfcol_57`
+  FROM `bfcte_4`
+), `bfcte_6` AS (
+  SELECT
+    *,
+    `bfcol_50` AS `bfcol_66`,
+    `bfcol_51` AS `bfcol_67`,
+    `bfcol_52` AS `bfcol_68`,
+    `bfcol_53` AS `bfcol_69`,
+    `bfcol_54` AS `bfcol_70`,
+    `bfcol_55` AS `bfcol_71`,
+    `bfcol_56` AS `bfcol_72`,
+    `bfcol_57` AS `bfcol_73`,
+    CASE
+      WHEN 0 = CAST(0 AS INT64)
+      THEN 1
+      WHEN `bfcol_52` = 1
+      THEN 1
+      WHEN `bfcol_52` = CAST(0 AS INT64) AND 0 < CAST(0 AS INT64)
+      THEN CAST('Infinity' AS FLOAT64)
+      WHEN ABS(`bfcol_52`) = CAST('Infinity' AS FLOAT64)
+      THEN POWER(
+        `bfcol_52`,
+        CASE
+          WHEN ABS(0) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(0)
+          ELSE 0
+        END
+      )
+      WHEN ABS(0) > 9007199254740992
+      THEN POWER(
+        `bfcol_52`,
+        CASE
+          WHEN ABS(0) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(0)
+          ELSE 0
+        END
+      )
+      WHEN `bfcol_52` < CAST(0 AS INT64) AND NOT CAST(0 AS INT64) = 0
+      THEN CAST('NaN' AS FLOAT64)
+      WHEN `bfcol_52` <> CAST(0 AS INT64) AND 0 * LN(ABS(`bfcol_52`)) > 709.78
+      THEN CAST('Infinity' AS FLOAT64) * CASE
+        WHEN `bfcol_52` < CAST(0 AS INT64) AND MOD(CAST(0 AS INT64), 2) = 1
+        THEN -1
+        ELSE 1
+      END
+      ELSE POWER(
+        `bfcol_52`,
+        CASE
+          WHEN ABS(0) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(0)
+          ELSE 0
+        END
+      )
+    END AS `bfcol_74`
+  FROM `bfcte_5`
+), `bfcte_7` AS (
+  SELECT
+    *,
+    `bfcol_66` AS `bfcol_84`,
+    `bfcol_67` AS `bfcol_85`,
+    `bfcol_68` AS `bfcol_86`,
+    `bfcol_69` AS `bfcol_87`,
+    `bfcol_70` AS `bfcol_88`,
+    `bfcol_71` AS `bfcol_89`,
+    `bfcol_72` AS `bfcol_90`,
+    `bfcol_73` AS `bfcol_91`,
+    `bfcol_74` AS `bfcol_92`,
+    CASE
+      WHEN `bfcol_67` <> 0 AND 1 * LN(ABS(`bfcol_67`)) > 43.66827237527655
+      THEN NULL
+      ELSE CAST(POWER(CAST(`bfcol_67` AS NUMERIC), 1) AS INT64)
+    END AS `bfcol_93`
+  FROM `bfcte_6`
+), `bfcte_8` AS (
+  SELECT
+    *,
+    `bfcol_84` AS `bfcol_104`,
+    `bfcol_85` AS `bfcol_105`,
+    `bfcol_86` AS `bfcol_106`,
+    `bfcol_87` AS `bfcol_107`,
+    `bfcol_88` AS `bfcol_108`,
+    `bfcol_89` AS `bfcol_109`,
+    `bfcol_90` AS `bfcol_110`,
+    `bfcol_91` AS `bfcol_111`,
+    `bfcol_92` AS `bfcol_112`,
+    `bfcol_93` AS `bfcol_113`,
+    CASE
+      WHEN 1 = CAST(0 AS INT64)
+      THEN 1
+      WHEN `bfcol_86` = 1
+      THEN 1
+      WHEN `bfcol_86` = CAST(0 AS INT64) AND 1 < CAST(0 AS INT64)
+      THEN CAST('Infinity' AS FLOAT64)
+      WHEN ABS(`bfcol_86`) = CAST('Infinity' AS FLOAT64)
+      THEN POWER(
+        `bfcol_86`,
+        CASE
+          WHEN ABS(1) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(1)
+          ELSE 1
+        END
+      )
+      WHEN ABS(1) > 9007199254740992
+      THEN POWER(
+        `bfcol_86`,
+        CASE
+          WHEN ABS(1) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(1)
+          ELSE 1
+        END
+      )
+      WHEN `bfcol_86` < CAST(0 AS INT64) AND NOT CAST(1 AS INT64) = 1
+      THEN CAST('NaN' AS FLOAT64)
+      WHEN `bfcol_86` <> CAST(0 AS INT64) AND 1 * LN(ABS(`bfcol_86`)) > 709.78
+      THEN CAST('Infinity' AS FLOAT64) * CASE
+        WHEN `bfcol_86` < CAST(0 AS INT64) AND MOD(CAST(1 AS INT64), 2) = 1
+        THEN -1
+        ELSE 1
+      END
+      ELSE POWER(
+        `bfcol_86`,
+        CASE
+          WHEN ABS(1) > 9007199254740992
+          THEN CAST('Infinity' AS FLOAT64) * SIGN(1)
+          ELSE 1
+        END
+      )
+    END AS `bfcol_114`
+  FROM `bfcte_7`
+)
+SELECT
+  `bfcol_104` AS `rowindex`,
+  `bfcol_105` AS `int64_col`,
+  `bfcol_106` AS `float64_col`,
+  `bfcol_107` AS `int_pow_int`,
+  `bfcol_108` AS `int_pow_float`,
+  `bfcol_109` AS `float_pow_int`,
+  `bfcol_110` AS `float_pow_float`,
+  `bfcol_111` AS `int_pow_0`,
+  `bfcol_112` AS `float_pow_0`,
+  `bfcol_113` AS `int_pow_1`,
+  `bfcol_114` AS `float_pow_1`
+FROM `bfcte_8`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py
index 5d3b23ebb7..3bedb25507 100644
--- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py
+++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py
@@ -196,6 +196,22 @@ def test_pos(scalar_types_df: bpd.DataFrame, snapshot):
     snapshot.assert_match(sql, "out.sql")
 
 
+def test_pow(scalar_types_df: bpd.DataFrame, snapshot):
+    bf_df = scalar_types_df[["int64_col", "float64_col"]]
+
+    bf_df["int_pow_int"] = bf_df["int64_col"] ** bf_df["int64_col"]
+    bf_df["int_pow_float"] = bf_df["int64_col"] ** bf_df["float64_col"]
+    bf_df["float_pow_int"] = bf_df["float64_col"] ** bf_df["int64_col"]
+    bf_df["float_pow_float"] = bf_df["float64_col"] ** bf_df["float64_col"]
+
+    bf_df["int_pow_0"] = bf_df["int64_col"] ** 0
+    bf_df["float_pow_0"] = bf_df["float64_col"] ** 0
+    bf_df["int_pow_1"] = bf_df["int64_col"] ** 1
+    bf_df["float_pow_1"] = bf_df["float64_col"] ** 1
+
+    snapshot.assert_match(bf_df.sql, "out.sql")
+
+
 def test_round(scalar_types_df: bpd.DataFrame, snapshot):
     bf_df = scalar_types_df[["int64_col", "float64_col"]]
 

From f7679b8e978ea29356b3731eafc9c8995b145e4c Mon Sep 17 00:00:00 2001
From: jialuo
Date: Wed, 26 Nov 2025 01:09:38 +0000
Subject: [PATCH 2/2] resolve the comments

---
 .../core/compile/sqlglot/expressions/constants.py   | 12 ++++++++++++
 .../core/compile/sqlglot/expressions/numeric_ops.py | 10 +++++-----
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/bigframes/core/compile/sqlglot/expressions/constants.py b/bigframes/core/compile/sqlglot/expressions/constants.py
index 20857f6291..e005a1ed78 100644
--- a/bigframes/core/compile/sqlglot/expressions/constants.py
+++ b/bigframes/core/compile/sqlglot/expressions/constants.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
+
 import sqlglot.expressions as sge
 
 _ZERO = sge.Cast(this=sge.convert(0), to="INT64")
@@ -23,3 +25,13 @@
 # FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
 # ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
 _FLOAT64_EXP_BOUND = sge.convert(709.78)
+
+# The natural logarithm of the maximum value for a signed 64-bit integer.
+# This is used to check for potential overflows in power operations involving integers
+# by checking if `exponent * log(base)` exceeds this value.
+_INT64_LOG_BOUND = math.log(2**63 - 1)
+
+# Represents the largest integer N where all integers from -N to N can be
+# represented exactly as a float64. Float64 types have a 53-bit significand precision,
+# so integers beyond this value may lose precision.
+_FLOAT64_MAX_INT_PRECISION = 2**53
diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py
index 55fe6424f9..4d56fbc236 100644
--- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py
+++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py
@@ -223,16 +223,13 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
 def _int_pow_op(
     left_expr: sge.Expression, right_expr: sge.Expression
 ) -> sge.Expression:
-    import math
-
-    overflow_value = math.log(2**63 - 1)
     overflow_cond = sge.and_(
         sge.NEQ(this=left_expr, expression=sge.convert(0)),
         sge.GT(
             this=sge.Mul(
                 this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
             ),
-            expression=sge.convert(overflow_value),
+            expression=sge.convert(constants._INT64_LOG_BOUND),
         ),
     )
 
@@ -271,7 +268,10 @@ def _float_pow_op(
     )
 
     # Float64 lose integer precision beyond 2**53, beyond this insufficient precision to get parity
-    exp_too_big = sge.GT(this=sge.Abs(this=right_expr), expression=sge.convert(2**53))
+    exp_too_big = sge.GT(
+        this=sge.Abs(this=right_expr),
+        expression=sge.convert(constants._FLOAT64_MAX_INT_PRECISION),
+    )
     # Treat very large exponents as +=INF
     norm_exp = sge.Case(
         ifs=[