diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index c022356fd3..83b29f67df 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -443,6 +443,14 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: ) +@register_binary_op(ops.unsafe_pow_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + """For internal use only - where domain and overflow checks are not needed.""" + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.Pow(this=left_expr, expression=right_expr) + + @register_unary_op(numeric_ops.isnan_op) def isnan(arg: TypedExpr) -> sge.Expression: return sge.IsNan(this=arg.expr) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_unsafe_pow_op/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_unsafe_pow_op/out.sql new file mode 100644 index 0000000000..9957a34665 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_unsafe_pow_op/out.sql @@ -0,0 +1,43 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col`, + `float64_col`, + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bool_col` AS `bfcol_3`, + `int64_col` AS `bfcol_4`, + `float64_col` AS `bfcol_5`, + ( + `int64_col` >= 0 + ) AND ( + `int64_col` <= 10 + ) AS `bfcol_6` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + * + FROM `bfcte_1` + WHERE + `bfcol_6` +), `bfcte_3` AS ( + SELECT + *, + POWER(`bfcol_4`, `bfcol_4`) AS `bfcol_14`, + POWER(`bfcol_4`, `bfcol_5`) AS `bfcol_15`, + POWER(`bfcol_5`, `bfcol_4`) AS `bfcol_16`, + POWER(`bfcol_5`, `bfcol_5`) AS `bfcol_17`, + POWER(`bfcol_4`, CAST(`bfcol_3` AS INT64)) AS `bfcol_18`, + POWER(CAST(`bfcol_3` AS INT64), `bfcol_4`) AS `bfcol_19` + FROM `bfcte_2` +) +SELECT + `bfcol_14` AS `int_pow_int`, + `bfcol_15` AS `int_pow_float`, + `bfcol_16` AS `float_pow_int`, + `bfcol_17` AS `float_pow_float`, + `bfcol_18` AS `int_pow_bool`, + `bfcol_19` AS `bool_pow_int` +FROM `bfcte_3` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index 5d3b23ebb7..0b4f8fbe70 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -438,3 +438,36 @@ def test_sub_unsupported_raises(scalar_types_df: bpd.DataFrame): with pytest.raises(TypeError): utils._apply_binary_op(scalar_types_df, ops.sub_op, "int64_col", "string_col") + + +def test_unsafe_pow_op(scalar_types_df: bpd.DataFrame, snapshot): + # Choose certain row so the sql execution won't fail even with unsafe_pow_op. + bf_df = scalar_types_df[ + (scalar_types_df["int64_col"] >= 0) & (scalar_types_df["int64_col"] <= 10) + ] + bf_df = bf_df[["int64_col", "float64_col", "bool_col"]] + + int64_col_id = bf_df["int64_col"]._value_column + float64_col_id = bf_df["float64_col"]._value_column + bool_col_id = bf_df["bool_col"]._value_column + + sql = utils._apply_ops_to_sql( + bf_df, + [ + ops.unsafe_pow_op.as_expr(int64_col_id, int64_col_id), + ops.unsafe_pow_op.as_expr(int64_col_id, float64_col_id), + ops.unsafe_pow_op.as_expr(float64_col_id, int64_col_id), + ops.unsafe_pow_op.as_expr(float64_col_id, float64_col_id), + ops.unsafe_pow_op.as_expr(int64_col_id, bool_col_id), + ops.unsafe_pow_op.as_expr(bool_col_id, int64_col_id), + ], + [ + "int_pow_int", + "int_pow_float", + "float_pow_int", + "float_pow_float", + "int_pow_bool", + "bool_pow_int", + ], + ) + snapshot.assert_match(sql, "out.sql")