Skip to content

Commit bcf5479

Browse files
committed
use sqlglot for type parsing
1 parent d3193e8 commit bcf5479

File tree

2 files changed

+65
-115
lines changed

2 files changed

+65
-115
lines changed

bigframes/operations/output_schemas.py

Lines changed: 32 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,79 +12,52 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""This file is specifically for parsing type strings for the output_schema parameter of AI functions.
16+
Do not use it for general SQL -> PyArrow type parsing, as it does not handle all types.
17+
"""
18+
19+
from typing import cast
20+
1521
import pyarrow as pa
22+
import sqlglot
23+
from sqlglot import exp as sgexp
1624

1725

18-
def parse_sql_type(sql: str) -> pa.DataType:
26+
def parse_sql_fields(sql: str) -> tuple[pa.Field]:
1927
"""
20-
Parses a SQL type string to its PyArrow equivalence:
28+
Parses a sequence of SQL struct fields into their PyArrow equivalents.
2129
22-
For example:
23-
"STRING" -> pa.string()
24-
"ARRAY<INT64>" -> pa.list_(pa.int64())
25-
"STRUCT<x ARRAY<FLOAT64>, y BOOL>" -> pa.struct(
26-
(
27-
pa.field("x", pa.list_(pa.float64())),
28-
pa.field("y", pa.bool_()),
29-
)
30-
)
30+
Examples:
31+
"x INT64, y FLOAT64" => (pa.field("x", pa.int64()), pa.field("y", pa.float64()))
3132
"""
32-
sql = sql.strip()
33+
sg_type = sqlglot.parse_one(f"STRUCT<{sql}>", read="bigquery")
34+
pa_struct = _sg_to_pyarrow_dtype(cast(sgexp.DataType, sg_type))
35+
return tuple(cast(pa.StructType, pa_struct).fields)
3336

34-
if sql.upper() == "STRING":
35-
return pa.string()
3637

37-
if sql.upper() == "INT64":
38+
def _sg_to_pyarrow_dtype(sg_type: sgexp.DataType) -> pa.DataType:
39+
if sg_type.is_type(sgexp.DataType.Type.BIGINT):
3840
return pa.int64()
3941

40-
if sql.upper() == "FLOAT64":
42+
if sg_type.is_type(sgexp.DataType.Type.DOUBLE):
4143
return pa.float64()
4244

43-
if sql.upper() == "BOOL":
45+
if sg_type.is_type(sgexp.DataType.Type.BOOLEAN):
4446
return pa.bool_()
4547

46-
if sql.upper().startswith("ARRAY<") and sql.endswith(">"):
47-
inner_type = sql[len("ARRAY<") : -1]
48-
return pa.list_(parse_sql_type(inner_type))
49-
50-
if sql.upper().startswith("STRUCT<") and sql.endswith(">"):
51-
inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1])
52-
return pa.struct(inner_fields)
53-
54-
raise ValueError(f"Unsupported SQL type: {sql}")
55-
56-
57-
def parse_sql_fields(sql: str) -> tuple[pa.Field]:
58-
sql = sql.strip()
59-
60-
start_idx = 0
61-
nested_depth = 0
62-
fields: list[pa.field] = []
63-
64-
for end_idx in range(len(sql)):
65-
c = sql[end_idx]
66-
67-
if c == "<":
68-
nested_depth += 1
69-
elif c == ">":
70-
nested_depth -= 1
71-
elif c == "," and nested_depth == 0:
72-
field = sql[start_idx:end_idx]
73-
fields.append(parse_sql_field(field))
74-
start_idx = end_idx + 1
75-
76-
# Append the last field
77-
fields.append(parse_sql_field(sql[start_idx:]))
78-
79-
return tuple(sorted(fields, key=lambda f: f.name))
80-
81-
82-
def parse_sql_field(sql: str) -> pa.Field:
83-
sql = sql.strip()
48+
if sg_type.is_type(sgexp.DataType.Type.TEXT):
49+
return pa.string()
8450

85-
space_idx = sql.find(" ")
51+
if sg_type.is_type(sgexp.DataType.Type.ARRAY):
52+
assert len(sg_type.expressions) == 1
53+
return pa.list_(_sg_to_pyarrow_dtype(sg_type.expressions[0]))
8654

87-
if space_idx == -1:
88-
raise ValueError(f"Invalid struct field: {sql}")
55+
if sg_type.is_type(sgexp.DataType.Type.STRUCT):
56+
fields = (
57+
pa.field(col.name, _sg_to_pyarrow_dtype(col.kind))
58+
for col in sg_type.expressions
59+
)
60+
# Sort the fields by name to align with server behavior.
61+
return pa.struct(sorted(fields, key=lambda f: f.name))
8962

90-
return pa.field(sql[:space_idx].strip(), parse_sql_type(sql[space_idx:]))
63+
raise ValueError(f"Unsupported type: {sg_type}")

tests/unit/operations/test_output_schemas.py

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,67 +18,13 @@
1818
from bigframes.operations import output_schemas
1919

2020

21-
@pytest.mark.parametrize(
22-
("sql", "expected"),
23-
[
24-
("INT64", pa.int64()),
25-
(" INT64 ", pa.int64()),
26-
("int64", pa.int64()),
27-
("FLOAT64", pa.float64()),
28-
("STRING", pa.string()),
29-
("BOOL", pa.bool_()),
30-
("ARRAY<INT64>", pa.list_(pa.int64())),
31-
(
32-
"STRUCT<x INT64, y FLOAT64>",
33-
pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))),
34-
),
35-
(
36-
"STRUCT< x INT64, y FLOAT64>",
37-
pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))),
38-
),
39-
(
40-
"STRUCT<y INT64, x FLOAT64>",
41-
pa.struct((pa.field("x", pa.float64()), pa.field("y", pa.int64()))),
42-
),
43-
(
44-
"ARRAY<STRUCT<y INT64, x INT64>>",
45-
pa.list_(pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.int64())))),
46-
),
47-
(
48-
"STRUCT<y STRUCT<b STRING, a BOOL>, x ARRAY<FLOAT64>>",
49-
pa.struct(
50-
(
51-
pa.field("x", pa.list_(pa.float64())),
52-
pa.field(
53-
"y",
54-
pa.struct(
55-
(pa.field("a", pa.bool_()), pa.field("b", pa.string()))
56-
),
57-
),
58-
)
59-
),
60-
),
61-
],
62-
)
63-
def test_parse_sql_to_pyarrow_dtype(sql, expected):
64-
assert output_schemas.parse_sql_type(sql) == expected
65-
66-
6721
@pytest.mark.parametrize(
6822
"sql",
69-
[
70-
"a INT64",
71-
"ARRAY<>",
72-
"ARRAY<INT64",
73-
"ARRAY<x INT64>" "ARRAY<int64>" "STRUCT<>",
74-
"DATE",
75-
"STRUCT<INT64, FLOAT64>",
76-
"ARRAY<ARRAY<>>",
77-
],
23+
["x TIMESTAMP", "x INT64, y DATETIME"],
7824
)
7925
def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql):
8026
with pytest.raises(ValueError):
81-
output_schemas.parse_sql_type(sql)
27+
output_schemas.parse_sql_fields(sql)
8228

8329

8430
@pytest.mark.parametrize(
@@ -93,6 +39,37 @@ def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql):
9339
"y FLOAT64, x INT64",
9440
(pa.field("x", pa.int64()), pa.field("y", pa.float64())),
9541
),
42+
(
43+
"a STRUCT<y FLOAT64, x INT64>",
44+
(
45+
pa.field(
46+
"a",
47+
pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))),
48+
),
49+
),
50+
),
51+
(
52+
"a STRUCT<y STRUCT<b STRING, a BOOL>, x ARRAY<FLOAT64>>",
53+
(
54+
pa.field(
55+
"a",
56+
pa.struct(
57+
(
58+
pa.field("x", pa.list_(pa.float64())),
59+
pa.field(
60+
"y",
61+
pa.struct(
62+
(
63+
pa.field("a", pa.bool_()),
64+
pa.field("b", pa.string()),
65+
)
66+
),
67+
),
68+
)
69+
),
70+
),
71+
),
72+
),
9673
],
9774
)
9875
def test_parse_sql_fields(sql, expected):

0 commit comments

Comments
 (0)