Skip to content

Commit 9163cb6

Browse files
committed
Chore: Migrate unsafe_pow_op operator to SQLGlot
1 parent 8f490e6 commit 9163cb6

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,14 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
448448
)
449449

450450

451+
@register_binary_op(ops.unsafe_pow_op)
452+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
453+
"""For internal use only - where domain and overflow checks are not needed."""
454+
left_expr = _coerce_bool_to_int(left)
455+
right_expr = _coerce_bool_to_int(right)
456+
return sge.Pow(this=left_expr, expression=right_expr)
457+
458+
451459
@register_unary_op(numeric_ops.isnan_op)
452460
def isnan(arg: TypedExpr) -> sge.Expression:
453461
return sge.IsNan(this=arg.expr)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`,
4+
`float64_col`,
5+
`int64_col`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bool_col` AS `bfcol_3`,
11+
`int64_col` AS `bfcol_4`,
12+
`float64_col` AS `bfcol_5`,
13+
(
14+
`int64_col` >= 0
15+
) AND (
16+
`int64_col` <= 10
17+
) AS `bfcol_6`
18+
FROM `bfcte_0`
19+
), `bfcte_2` AS (
20+
SELECT
21+
*
22+
FROM `bfcte_1`
23+
WHERE
24+
`bfcol_6`
25+
), `bfcte_3` AS (
26+
SELECT
27+
*,
28+
POWER(`bfcol_4`, `bfcol_4`) AS `bfcol_14`,
29+
POWER(`bfcol_4`, `bfcol_5`) AS `bfcol_15`,
30+
POWER(`bfcol_5`, `bfcol_4`) AS `bfcol_16`,
31+
POWER(`bfcol_5`, `bfcol_5`) AS `bfcol_17`,
32+
POWER(`bfcol_4`, CAST(`bfcol_3` AS INT64)) AS `bfcol_18`,
33+
POWER(CAST(`bfcol_3` AS INT64), `bfcol_4`) AS `bfcol_19`
34+
FROM `bfcte_2`
35+
)
36+
SELECT
37+
`bfcol_14` AS `int_pow_int`,
38+
`bfcol_15` AS `int_pow_float`,
39+
`bfcol_16` AS `float_pow_int`,
40+
`bfcol_17` AS `float_pow_float`,
41+
`bfcol_18` AS `int_pow_bool`,
42+
`bfcol_19` AS `bool_pow_int`
43+
FROM `bfcte_3`

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,36 @@ def test_sub_unsupported_raises(scalar_types_df: bpd.DataFrame):
423423

424424
with pytest.raises(TypeError):
425425
utils._apply_binary_op(scalar_types_df, ops.sub_op, "int64_col", "string_col")
426+
427+
428+
def test_unsafe_pow_op(scalar_types_df: bpd.DataFrame, snapshot):
429+
# Choose certain row so the sql execution won't fail even with unsafe_pow_op.
430+
bf_df = scalar_types_df[
431+
(scalar_types_df["int64_col"] >= 0) & (scalar_types_df["int64_col"] <= 10)
432+
]
433+
bf_df = bf_df[["int64_col", "float64_col", "bool_col"]]
434+
435+
int64_col_id = bf_df["int64_col"]._value_column
436+
float64_col_id = bf_df["float64_col"]._value_column
437+
bool_col_id = bf_df["bool_col"]._value_column
438+
439+
sql = utils._apply_ops_to_sql(
440+
bf_df,
441+
[
442+
ops.unsafe_pow_op.as_expr(int64_col_id, int64_col_id),
443+
ops.unsafe_pow_op.as_expr(int64_col_id, float64_col_id),
444+
ops.unsafe_pow_op.as_expr(float64_col_id, int64_col_id),
445+
ops.unsafe_pow_op.as_expr(float64_col_id, float64_col_id),
446+
ops.unsafe_pow_op.as_expr(int64_col_id, bool_col_id),
447+
ops.unsafe_pow_op.as_expr(bool_col_id, int64_col_id),
448+
],
449+
[
450+
"int_pow_int",
451+
"int_pow_float",
452+
"float_pow_int",
453+
"float_pow_float",
454+
"int_pow_bool",
455+
"bool_pow_int",
456+
],
457+
)
458+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)