Skip to content

Commit 4285b24

Browse files
committed
chore: Migrate up to 5 scalar operators to SQLGlot
Migrated cosh_op, tanh_op, arcsinh_op, arccosh_op, and arctanh_op scalar operators to SQLGlot.
1 parent 4da309e commit 4285b24

File tree

7 files changed

+144
-0
lines changed

7 files changed

+144
-0
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ def compile(op: ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
3838
return UNARY_OP_REGISTRATION[op](op, expr)
3939

4040

41+
@UNARY_OP_REGISTRATION.register(ops.arccosh_op)
42+
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
43+
return sge.Case(
44+
ifs=[
45+
sge.If(
46+
this=expr.expr < sge.convert(1),
47+
true=sge.func("IEEE_DIVIDE", sge.convert(0), sge.convert(0)),
48+
)
49+
],
50+
default=sge.func("ACOSH", expr.expr),
51+
)
52+
53+
4154
@UNARY_OP_REGISTRATION.register(ops.arccos_op)
4255
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
4356
return sge.Case(
@@ -64,11 +77,29 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
6477
)
6578

6679

80+
@UNARY_OP_REGISTRATION.register(ops.arcsinh_op)
81+
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
82+
return sge.func("ASINH", expr.expr)
83+
84+
6785
@UNARY_OP_REGISTRATION.register(ops.arctan_op)
6886
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
6987
return sge.func("ATAN", expr.expr)
7088

7189

90+
@UNARY_OP_REGISTRATION.register(ops.arctanh_op)
91+
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
92+
return sge.Case(
93+
ifs=[
94+
sge.If(
95+
this=sge.func("ABS", expr.expr) > sge.convert(1),
96+
true=sge.func("IEEE_DIVIDE", sge.convert(0), sge.convert(0)),
97+
)
98+
],
99+
default=sge.func("ATANH", expr.expr),
100+
)
101+
102+
72103
@UNARY_OP_REGISTRATION.register(ops.ArrayToStringOp)
73104
def _(op: ops.ArrayToStringOp, expr: TypedExpr) -> sge.Expression:
74105
return sge.ArrayToString(this=expr.expr, expression=f"'{op.delimiter}'")
@@ -116,6 +147,19 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
116147
return sge.func("COS", expr.expr)
117148

118149

150+
@UNARY_OP_REGISTRATION.register(ops.cosh_op)
151+
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
152+
return sge.Case(
153+
ifs=[
154+
sge.If(
155+
this=sge.func("ABS", expr.expr) > sge.convert(709.78),
156+
true=sge.func("IEEE_DIVIDE", sge.convert(1), sge.convert(0)),
157+
)
158+
],
159+
default=sge.func("COSH", expr.expr),
160+
)
161+
162+
119163
@UNARY_OP_REGISTRATION.register(ops.hash_op)
120164
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
121165
return sge.func("FARM_FINGERPRINT", expr.expr)
@@ -154,6 +198,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
154198
return sge.func("TAN", expr.expr)
155199

156200

201+
@UNARY_OP_REGISTRATION.register(ops.tanh_op)
202+
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
203+
return sge.func("TANH", expr.expr)
204+
205+
157206
# JSON Ops
158207
@UNARY_OP_REGISTRATION.register(ops.JSONExtract)
159208
def _(op: ops.JSONExtract, expr: TypedExpr) -> sge.Expression:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` < 1 THEN IEEE_DIVIDE(0, 0) ELSE ACOSH(`bfcol_0`) END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ASINH(`bfcol_0`) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN ABS(`bfcol_0`) > 1 THEN IEEE_DIVIDE(0, 0) ELSE ATANH(`bfcol_0`) END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN ABS(`bfcol_0`) > 709.78 THEN IEEE_DIVIDE(1, 0) ELSE COSH(`bfcol_0`) END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
TANH(`bfcol_0`) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def _apply_unary_op(obj: bpd.DataFrame, op: ops.UnaryOp, arg: str) -> str:
3434
return sql
3535

3636

37+
def test_arccosh(scalar_types_df: bpd.DataFrame, snapshot):
38+
bf_df = scalar_types_df[["float64_col"]]
39+
sql = _apply_unary_op(bf_df, ops.arccosh_op, "float64_col")
40+
snapshot.assert_match(sql, "out.sql")
41+
42+
3743
def test_arccos(scalar_types_df: bpd.DataFrame, snapshot):
3844
bf_df = scalar_types_df[["float64_col"]]
3945
sql = _apply_unary_op(bf_df, ops.arccos_op, "float64_col")
@@ -48,13 +54,25 @@ def test_arcsin(scalar_types_df: bpd.DataFrame, snapshot):
4854
snapshot.assert_match(sql, "out.sql")
4955

5056

57+
def test_arcsinh(scalar_types_df: bpd.DataFrame, snapshot):
58+
bf_df = scalar_types_df[["float64_col"]]
59+
sql = _apply_unary_op(bf_df, ops.arcsinh_op, "float64_col")
60+
snapshot.assert_match(sql, "out.sql")
61+
62+
5163
def test_arctan(scalar_types_df: bpd.DataFrame, snapshot):
5264
bf_df = scalar_types_df[["float64_col"]]
5365
sql = _apply_unary_op(bf_df, ops.arctan_op, "float64_col")
5466

5567
snapshot.assert_match(sql, "out.sql")
5668

5769

70+
def test_arctanh(scalar_types_df: bpd.DataFrame, snapshot):
71+
bf_df = scalar_types_df[["float64_col"]]
72+
sql = _apply_unary_op(bf_df, ops.arctanh_op, "float64_col")
73+
snapshot.assert_match(sql, "out.sql")
74+
75+
5876
def test_array_to_string(repeated_types_df: bpd.DataFrame, snapshot):
5977
bf_df = repeated_types_df[["string_list_col"]]
6078
sql = _apply_unary_op(bf_df, ops.ArrayToStringOp(delimiter="."), "string_list_col")
@@ -90,6 +108,12 @@ def test_cos(scalar_types_df: bpd.DataFrame, snapshot):
90108
snapshot.assert_match(sql, "out.sql")
91109

92110

111+
def test_cosh(scalar_types_df: bpd.DataFrame, snapshot):
112+
bf_df = scalar_types_df[["float64_col"]]
113+
sql = _apply_unary_op(bf_df, ops.cosh_op, "float64_col")
114+
snapshot.assert_match(sql, "out.sql")
115+
116+
93117
def test_hash(scalar_types_df: bpd.DataFrame, snapshot):
94118
bf_df = scalar_types_df[["string_col"]]
95119
sql = _apply_unary_op(bf_df, ops.hash_op, "string_col")
@@ -132,6 +156,12 @@ def test_tan(scalar_types_df: bpd.DataFrame, snapshot):
132156
snapshot.assert_match(sql, "out.sql")
133157

134158

159+
def test_tanh(scalar_types_df: bpd.DataFrame, snapshot):
160+
bf_df = scalar_types_df[["float64_col"]]
161+
sql = _apply_unary_op(bf_df, ops.tanh_op, "float64_col")
162+
snapshot.assert_match(sql, "out.sql")
163+
164+
135165
def test_json_extract(json_types_df: bpd.DataFrame, snapshot):
136166
bf_df = json_types_df[["json_col"]]
137167
sql = _apply_unary_op(bf_df, ops.JSONExtract(json_path="$"), "json_col")

0 commit comments

Comments
 (0)