Skip to content

Commit 36b5dfb

Browse files
committed
support case-insensitive type parsing
1 parent acd11f9 commit 36b5dfb

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

bigframes/operations/output_schemas.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,23 @@ def parse_sql_type(sql: str) -> pa.DataType:
3131
"""
3232
sql = sql.strip()
3333

34-
if sql == "STRING":
34+
if sql.upper() == "STRING":
3535
return pa.string()
3636

37-
if sql == "INT64":
37+
if sql.upper() == "INT64":
3838
return pa.int64()
3939

40-
if sql == "FLOAT64":
40+
if sql.upper() == "FLOAT64":
4141
return pa.float64()
4242

43-
if sql == "BOOL":
43+
if sql.upper() == "BOOL":
4444
return pa.bool_()
4545

46-
if sql.startswith("ARRAY<") and sql.endswith(">"):
46+
if sql.upper().startswith("ARRAY<") and sql.endswith(">"):
4747
inner_type = sql[len("ARRAY<") : -1]
4848
return pa.list_(parse_sql_type(inner_type))
4949

50-
if sql.startswith("STRUCT<") and sql.endswith(">"):
50+
if sql.upper().startswith("STRUCT<") and sql.endswith(">"):
5151
inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1])
5252
return pa.struct(inner_fields)
5353

tests/system/small/bigquery/test_ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_ai_generate_with_output_schema(session):
9494
result = bbq.ai.generate(
9595
prompt,
9696
endpoint="gemini-2.5-flash",
97-
output_schema={"population": "INT64", "is_in_north_america": "BOOL"},
97+
output_schema={"population": "INT64", "is_in_north_america": "bool"},
9898
)
9999

100100
assert _contains_no_nulls(result)

tests/unit/operations/test_output_schemas.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
[
2424
("INT64", pa.int64()),
2525
(" INT64 ", pa.int64()),
26+
("int64", pa.int64()),
2627
("FLOAT64", pa.float64()),
2728
("STRING", pa.string()),
2829
("BOOL", pa.bool_()),
@@ -44,7 +45,7 @@
4445
pa.list_(pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.int64())))),
4546
),
4647
(
47-
"STRUCT<x ARRAY<FLOAT64>, y STRUCT<a BOOL, b STRING>>",
48+
"STRUCT<y STRUCT<b STRING, a BOOL>, x ARRAY<FLOAT64>>",
4849
pa.struct(
4950
(
5051
pa.field("x", pa.list_(pa.float64())),
@@ -66,14 +67,13 @@ def test_parse_sql_to_pyarrow_dtype(sql, expected):
6667
@pytest.mark.parametrize(
6768
"sql",
6869
[
69-
"int64",
7070
"a INT64",
7171
"ARRAY<>",
7272
"ARRAY<INT64",
7373
"ARRAY<x INT64>" "ARRAY<int64>" "STRUCT<>",
7474
"DATE",
7575
"STRUCT<INT64, FLOAT64>",
76-
"STRUCT<x int64>",
76+
"ARRAY<ARRAY<>>",
7777
],
7878
)
7979
def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql):
@@ -89,6 +89,10 @@ def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql):
8989
"x INT64, y FLOAT64",
9090
(pa.field("x", pa.int64()), pa.field("y", pa.float64())),
9191
),
92+
(
93+
"y FLOAT64, x INT64",
94+
(pa.field("x", pa.int64()), pa.field("y", pa.float64())),
95+
),
9296
],
9397
)
9498
def test_parse_sql_fields(sql, expected):

0 commit comments

Comments
 (0)