Skip to content

Commit 2bb5fdd

Browse files
committed
refactor: enable "astype" engine tests for the sqlglot compiler
1 parent ca1e44c commit 2bb5fdd

File tree

10 files changed

+395
-20
lines changed

10 files changed

+395
-20
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,85 @@
1616

1717
import sqlglot.expressions as sge
1818

19+
from bigframes import dtypes
1920
from bigframes import operations as ops
2021
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2122
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
23+
from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType
2224

2325
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2426

2527

2628
@register_unary_op(ops.AsTypeOp, pass_op=True)
def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
    """Compile ``AsTypeOp`` into a SQL expression converting ``expr`` to ``op.to_type``.

    Most conversions become a plain cast (``CAST``, or the try-cast variant
    when ``op.safe`` is set). The branches below handle the type pairs that
    need a dedicated function or an intermediate conversion.

    Args:
        expr: The operand, carrying both its SQL expression and its dtype.
        op: The cast operation; ``to_type`` is the target dtype and ``safe``
            selects the error-tolerant variant where one is emitted.

    Returns:
        The SQL expression that produces the converted value.

    Raises:
        TypeError: If a conversion to or from JSON involves a type that has
            no branch here.
    """
    from_type = expr.dtype
    to_type = op.to_type
    sg_to_type = SQLGlotType.from_bigframes_dtype(to_type)
    sg_expr = expr.expr

    # Anything -> JSON goes through PARSE_JSON rather than CAST.
    if to_type == dtypes.JSON_DTYPE:
        if from_type == dtypes.STRING_DTYPE:
            # Use the "SAFE." prefix for the safe variant, exactly as the
            # JSON -> scalar branch below does. "PARSE_JSON_IN_SAFE" is not a
            # BigQuery function name.
            func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
            return sge.func(func_name, sg_expr)
        if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
            # Scalars are rendered as text first, then parsed as a JSON
            # literal. ``op.safe`` is intentionally not consulted here.
            sg_expr = sge.Cast(this=sg_expr, to="STRING")
            return sge.func("PARSE_JSON", sg_expr)
        raise TypeError(f"Cannot cast from {from_type} to {to_type}")

    # JSON -> scalar uses the accessor function named after the target type.
    if from_type == dtypes.JSON_DTYPE:
        func_name = ""
        if to_type == dtypes.INT_DTYPE:
            func_name = "INT64"
        elif to_type == dtypes.FLOAT_DTYPE:
            func_name = "FLOAT64"
        elif to_type == dtypes.BOOL_DTYPE:
            func_name = "BOOL"
        elif to_type == dtypes.STRING_DTYPE:
            func_name = "STRING"
        if func_name:
            func_name = "SAFE." + func_name if op.safe else func_name
            return sge.func(func_name, sg_expr)
        raise TypeError(f"Cannot cast from {from_type} to {to_type}")

    if to_type == dtypes.INT_DTYPE:
        # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first.
        if from_type == dtypes.DATETIME_DTYPE:
            sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe)
            return sge.func("UNIX_MICROS", sg_expr)
        if from_type == dtypes.TIMESTAMP_DTYPE:
            return sge.func("UNIX_MICROS", sg_expr)
        if from_type == dtypes.TIME_DTYPE:
            # A TIME value becomes its microsecond offset from '00:00:00'.
            return sge.func(
                "TIME_DIFF",
                _cast(sg_expr, "TIME", op.safe),
                sge.convert("00:00:00"),
                "MICROSECOND",
            )
        if from_type in (dtypes.NUMERIC_DTYPE, dtypes.FLOAT_DTYPE):
            # TRUNC drops the fractional part before the integer cast.
            sg_expr = sge.func("TRUNC", sg_expr)
        return _cast(sg_expr, sg_to_type, op.safe)

    if to_type == dtypes.FLOAT_DTYPE and from_type == dtypes.BOOL_DTYPE:
        # BOOL is routed through INT64 before the final cast to the float type.
        sg_expr = _cast(sg_expr, "INT64", op.safe)
        return _cast(sg_expr, sg_to_type, op.safe)

    if to_type == dtypes.BOOL_DTYPE:
        if from_type == dtypes.BOOL_DTYPE:
            return sg_expr
        # Every other source type is compared against zero.
        return sge.NEQ(this=sg_expr, expression=sge.convert(0))

    if to_type == dtypes.STRING_DTYPE:
        sg_expr = _cast(sg_expr, sg_to_type, op.safe)
        if from_type == dtypes.BOOL_DTYPE:
            # Capitalise the first letter of the boolean's text form
            # (presumably to match Python's "True"/"False" spelling).
            sg_expr = sge.func("INITCAP", sg_expr)
        return sg_expr

    if dtypes.is_time_like(to_type) and from_type == dtypes.INT_DTYPE:
        # Integers are converted with TIMESTAMP_MICROS first; the resulting
        # timestamp is then cast to the requested time-like type.
        sg_expr = sge.func("TIMESTAMP_MICROS", sg_expr)
        return _cast(sg_expr, sg_to_type, op.safe)

    return _cast(sg_expr, sg_to_type, op.safe)
3098

3199

32100
@register_unary_op(ops.hash_op)
@@ -53,3 +121,11 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
53121
@register_unary_op(ops.notnull_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``notnull_op`` as the negation of an ``IS NULL`` test on ``expr``."""
    is_null = sge.Is(this=expr.expr, expression=sge.Null())
    return sge.Not(this=is_null)
124+
125+
126+
# Helper functions
127+
def _cast(expr: sge.Expression, to: str, safe: bool):
    """Wrap ``expr`` in a cast to the type ``to``.

    Builds an ``sge.TryCast`` node when ``safe`` is truthy and an ``sge.Cast``
    node otherwise.
    """
    cast_node = sge.TryCast if safe else sge.Cast
    return cast_node(this=expr, to=to)

tests/system/small/engines/test_generic_ops.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def apply_op(
5252
return new_arr
5353

5454

55-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
55+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5656
def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine):
5757
arr = apply_op(
5858
scalars_array_value,
@@ -63,7 +63,7 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine)
6363
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
6464

6565

66-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
66+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
6767
def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, engine):
6868
vals = ["1", "100", "-3"]
6969
arr, _ = scalars_array_value.compute_values(
@@ -78,7 +78,7 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue,
7878
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
7979

8080

81-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
81+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
8282
def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engine):
8383
arr = apply_op(
8484
scalars_array_value,
@@ -89,7 +89,7 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin
8989
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
9090

9191

92-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
92+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
9393
def test_engines_astype_string_float(
9494
scalars_array_value: array_value.ArrayValue, engine
9595
):
@@ -106,7 +106,7 @@ def test_engines_astype_string_float(
106106
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
107107

108108

109-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
109+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
110110
def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine):
111111
arr = apply_op(
112112
scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE)
@@ -115,7 +115,7 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine
115115
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
116116

117117

118-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
118+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
119119
def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engine):
120120
# floats work slightly different with trailing zeroes rn
121121
arr = apply_op(
@@ -127,7 +127,7 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi
127127
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
128128

129129

130-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
130+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
131131
def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, engine):
132132
arr = apply_op(
133133
scalars_array_value,
@@ -138,7 +138,7 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng
138138
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
139139

140140

141-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
141+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
142142
def test_engines_astype_string_numeric(
143143
scalars_array_value: array_value.ArrayValue, engine
144144
):
@@ -155,7 +155,7 @@ def test_engines_astype_string_numeric(
155155
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
156156

157157

158-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
158+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
159159
def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine):
160160
arr = apply_op(
161161
scalars_array_value,
@@ -166,7 +166,7 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine
166166
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
167167

168168

169-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
169+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
170170
def test_engines_astype_string_date(
171171
scalars_array_value: array_value.ArrayValue, engine
172172
):
@@ -183,7 +183,7 @@ def test_engines_astype_string_date(
183183
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
184184

185185

186-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
186+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
187187
def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, engine):
188188
arr = apply_op(
189189
scalars_array_value,
@@ -194,7 +194,7 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en
194194
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
195195

196196

197-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
197+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
198198
def test_engines_astype_string_datetime(
199199
scalars_array_value: array_value.ArrayValue, engine
200200
):
@@ -211,7 +211,7 @@ def test_engines_astype_string_datetime(
211211
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
212212

213213

214-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
214+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
215215
def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, engine):
216216
arr = apply_op(
217217
scalars_array_value,
@@ -222,7 +222,7 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e
222222
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
223223

224224

225-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
225+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
226226
def test_engines_astype_string_timestamp(
227227
scalars_array_value: array_value.ArrayValue, engine
228228
):
@@ -243,7 +243,7 @@ def test_engines_astype_string_timestamp(
243243
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
244244

245245

246-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
246+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
247247
def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine):
248248
arr = apply_op(
249249
scalars_array_value,
@@ -254,7 +254,7 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine
254254
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
255255

256256

257-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
257+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
258258
def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine):
259259
exprs = [
260260
ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr(
@@ -275,7 +275,7 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e
275275
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
276276

277277

278-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
278+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
279279
def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine):
280280
exprs = [
281281
ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr(
@@ -298,7 +298,7 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng
298298
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
299299

300300

301-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
301+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
302302
def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine):
303303
arr = apply_op(
304304
scalars_array_value,
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
`bfcol_0` AS `bfcol_2`,
10+
`bfcol_1` <> 0 AS `bfcol_3`,
11+
`bfcol_1` <> 0 AS `bfcol_4`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_2` AS `bool_col`,
16+
`bfcol_3` AS `float64_col`,
17+
`bfcol_4` AS `float64_w_safe`
18+
FROM `bfcte_1`
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CAST(CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_1`,
9+
CAST('1.34235e4' AS FLOAT64) AS `bfcol_2`,
10+
SAFE_CAST(SAFE_CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_1` AS `bool_col`,
15+
`bfcol_2` AS `str_const`,
16+
`bfcol_3` AS `bool_w_safe`
17+
FROM `bfcte_1`
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`json_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`json_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
INT64(`bfcol_0`) AS `bfcol_1`,
9+
FLOAT64(`bfcol_0`) AS `bfcol_2`,
10+
BOOL(`bfcol_0`) AS `bfcol_3`,
11+
STRING(`bfcol_0`) AS `bfcol_4`,
12+
SAFE.INT64(`bfcol_0`) AS `bfcol_5`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `int64_col`,
17+
`bfcol_2` AS `float64_col`,
18+
`bfcol_3` AS `bool_col`,
19+
`bfcol_4` AS `string_col`,
20+
`bfcol_5` AS `int64_w_safe`
21+
FROM `bfcte_1`
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`datetime_col` AS `bfcol_0`,
4+
`numeric_col` AS `bfcol_1`,
5+
`float64_col` AS `bfcol_2`,
6+
`time_col` AS `bfcol_3`,
7+
`timestamp_col` AS `bfcol_4`
8+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
9+
), `bfcte_1` AS (
10+
SELECT
11+
*,
12+
UNIX_MICROS(CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_5`,
13+
UNIX_MICROS(SAFE_CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_6`,
14+
TIME_DIFF(CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_7`,
15+
TIME_DIFF(SAFE_CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_8`,
16+
UNIX_MICROS(`bfcol_4`) AS `bfcol_9`,
17+
CAST(TRUNC(`bfcol_1`) AS INT64) AS `bfcol_10`,
18+
CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_11`,
19+
SAFE_CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_12`,
20+
CAST('100' AS INT64) AS `bfcol_13`
21+
FROM `bfcte_0`
22+
)
23+
SELECT
24+
`bfcol_5` AS `datetime_col`,
25+
`bfcol_6` AS `datetime_w_safe`,
26+
`bfcol_7` AS `time_col`,
27+
`bfcol_8` AS `time_w_safe`,
28+
`bfcol_9` AS `timestamp_col`,
29+
`bfcol_10` AS `numeric_col`,
30+
`bfcol_11` AS `float64_col`,
31+
`bfcol_12` AS `float64_w_safe`,
32+
`bfcol_13` AS `str_const`
33+
FROM `bfcte_1`
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
WITH `bfcte_0` AS (
  SELECT
    `bool_col` AS `bfcol_0`,
    `int64_col` AS `bfcol_1`,
    `float64_col` AS `bfcol_2`,
    `string_col` AS `bfcol_3`
  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
  SELECT
    *,
    PARSE_JSON(CAST(`bfcol_1` AS STRING)) AS `bfcol_4`,
    PARSE_JSON(CAST(`bfcol_2` AS STRING)) AS `bfcol_5`,
    PARSE_JSON(CAST(`bfcol_0` AS STRING)) AS `bfcol_6`,
    PARSE_JSON(`bfcol_3`) AS `bfcol_7`,
    PARSE_JSON(CAST(`bfcol_0` AS STRING)) AS `bfcol_8`,
    SAFE.PARSE_JSON(`bfcol_3`) AS `bfcol_9`
  FROM `bfcte_0`
)
SELECT
  `bfcol_4` AS `int64_col`,
  `bfcol_5` AS `float64_col`,
  `bfcol_6` AS `bool_col`,
  `bfcol_7` AS `string_col`,
  `bfcol_8` AS `bool_w_safe`,
  `bfcol_9` AS `string_w_safe`
FROM `bfcte_1`

0 commit comments

Comments
 (0)