Skip to content

Commit dd94c9e

Browse files
committed
chore: Migrate up to 5 scalar operators to SQLGlot
Migrated operators: dayofweek_op, dayofyear_op, exp_op, expm1_op, floor_op
1 parent 7bb8d1e commit dd94c9e

File tree

7 files changed

+136
-0
lines changed

7 files changed

+136
-0
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,47 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
185185
return sge.Extract(this=sge.Identifier(this="DAY"), expression=expr.expr)
186186

187187

188+
@UNARY_OP_REGISTRATION.register(ops.dayofweek_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``dayofweek_op`` into ``EXTRACT(DAYOFWEEK FROM <operand>)``.

    Args:
        op: The operator being compiled; not consulted by this handler.
        expr: Typed operand whose ``expr`` attribute is the SQLGlot expression
            to extract from.

    Returns:
        An ``sge.Extract`` node with the ``DAYOFWEEK`` date part.

    NOTE(review): BigQuery documents ``DAYOFWEEK`` as 1-7 with Sunday as the
    first day. Confirm this is the numbering callers of ``dayofweek_op``
    expect; no offset is applied here.
    """
    date_part = sge.Identifier(this="DAYOFWEEK")
    return sge.Extract(this=date_part, expression=expr.expr)
191+
192+
193+
@UNARY_OP_REGISTRATION.register(ops.dayofyear_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``dayofyear_op`` into ``EXTRACT(DAYOFYEAR FROM <operand>)``.

    Args:
        op: The operator being compiled; not consulted by this handler.
        expr: Typed operand whose ``expr`` attribute is the SQLGlot expression
            to extract from.

    Returns:
        An ``sge.Extract`` node with the ``DAYOFYEAR`` date part.
    """
    date_part = sge.Identifier(this="DAYOFYEAR")
    return sge.Extract(this=date_part, expression=expr.expr)
196+
197+
198+
# Inputs above this threshold are short-circuited to +inf instead of being
# passed to EXP. Presumably chosen as ~ln(max float64) (about 709.78), the
# point past which e**x no longer fits in a FLOAT64 -- TODO confirm.
_FLOAT64_EXP_BOUND = 709.78


def _exp_with_overflow_guard(operand: sge.Expression) -> sge.Expression:
    """Build ``CASE WHEN x > 709.78 THEN IEEE_DIVIDE(1, 0) ELSE EXP(x) END``.

    Shared by the ``exp_op`` and ``expm1_op`` handlers so both emit exactly
    the same guarded exponential.

    Args:
        operand: SQLGlot expression for the value being exponentiated.

    Returns:
        An ``sge.Case`` node that yields ``IEEE_DIVIDE(1, 0)`` when the
        operand exceeds ``_FLOAT64_EXP_BOUND`` and ``EXP(operand)`` otherwise.
    """
    overflow_branch = sge.If(
        this=operand > sge.convert(_FLOAT64_EXP_BOUND),
        # IEEE_DIVIDE(1, 0) is used to produce an infinite value rather than
        # letting the query fail on a too-large EXP argument.
        true=sge.func("IEEE_DIVIDE", sge.convert(1), sge.convert(0)),
    )
    return sge.Case(ifs=[overflow_branch], default=sge.func("EXP", operand))


@UNARY_OP_REGISTRATION.register(ops.exp_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``exp_op`` into an overflow-guarded ``EXP`` call.

    Args:
        op: The operator being compiled; not consulted by this handler.
        expr: Typed operand whose ``expr`` attribute is exponentiated.

    Returns:
        The guarded ``CASE ... END`` expression.
    """
    return _exp_with_overflow_guard(expr.expr)


@UNARY_OP_REGISTRATION.register(ops.expm1_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``expm1_op`` as the overflow-guarded ``EXP`` minus one.

    Args:
        op: The operator being compiled; not consulted by this handler.
        expr: Typed operand whose ``expr`` attribute is exponentiated.

    Returns:
        ``CASE ... END - 1`` as a SQLGlot expression.

    NOTE(review): computing ``EXP(x) - 1`` directly loses precision for ``x``
    close to zero compared with a dedicated expm1 routine; confirm this is
    acceptable for callers.
    """
    return _exp_with_overflow_guard(expr.expr) - sge.convert(1)
222+
223+
224+
@UNARY_OP_REGISTRATION.register(ops.floor_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``floor_op`` by wrapping the operand in an ``sge.Floor`` node.

    Args:
        op: The operator being compiled; not consulted by this handler.
        expr: Typed operand whose ``expr`` attribute is floored.

    Returns:
        An ``sge.Floor`` expression over the operand.
    """
    operand = expr.expr
    return sge.Floor(this=operand)
227+
228+
188229
@UNARY_OP_REGISTRATION.register(ops.hash_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``hash_op`` into a ``FARM_FINGERPRINT(<operand>)`` call.

    Args:
        op: The operator being compiled; not consulted by this handler.
        expr: Typed operand whose ``expr`` attribute is passed to the function.

    Returns:
        The SQLGlot function-call expression.
    """
    operand = expr.expr
    return sge.func("FARM_FINGERPRINT", operand)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`timestamp_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
EXTRACT(DAYOFWEEK FROM `bfcol_0`) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `timestamp_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`timestamp_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
EXTRACT(DAYOFYEAR FROM `bfcol_0`) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `timestamp_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` > 709.78 THEN IEEE_DIVIDE(1, 0) ELSE EXP(`bfcol_0`) END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` > 709.78 THEN IEEE_DIVIDE(1, 0) ELSE EXP(`bfcol_0`) END - 1 AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
FLOOR(`bfcol_0`) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `float64_col`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,36 @@ def test_day(scalar_types_df: bpd.DataFrame, snapshot):
103103
snapshot.assert_match(sql, "out.sql")
104104

105105

106+
def test_dayofweek(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``dayofweek_op`` on ``timestamp_col``."""
    column = "timestamp_col"
    frame = scalar_types_df[[column]]
    compiled_sql = _apply_unary_op(frame, ops.dayofweek_op, column)
    snapshot.assert_match(compiled_sql, "out.sql")
110+
111+
112+
def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``dayofyear_op`` on ``timestamp_col``."""
    column = "timestamp_col"
    frame = scalar_types_df[[column]]
    compiled_sql = _apply_unary_op(frame, ops.dayofyear_op, column)
    snapshot.assert_match(compiled_sql, "out.sql")
116+
117+
118+
def test_exp(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``exp_op`` on ``float64_col``."""
    column = "float64_col"
    frame = scalar_types_df[[column]]
    compiled_sql = _apply_unary_op(frame, ops.exp_op, column)
    snapshot.assert_match(compiled_sql, "out.sql")
122+
123+
124+
def test_expm1(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``expm1_op`` on ``float64_col``."""
    column = "float64_col"
    frame = scalar_types_df[[column]]
    compiled_sql = _apply_unary_op(frame, ops.expm1_op, column)
    snapshot.assert_match(compiled_sql, "out.sql")
128+
129+
130+
def test_floor(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``floor_op`` on ``float64_col``."""
    column = "float64_col"
    frame = scalar_types_df[[column]]
    compiled_sql = _apply_unary_op(frame, ops.floor_op, column)
    snapshot.assert_match(compiled_sql, "out.sql")
134+
135+
106136
def test_array_to_string(repeated_types_df: bpd.DataFrame, snapshot):
107137
bf_df = repeated_types_df[["string_list_col"]]
108138
sql = _apply_unary_op(bf_df, ops.ArrayToStringOp(delimiter="."), "string_list_col")

0 commit comments

Comments
 (0)