Skip to content

Commit fa4e46f

Browse files
authored
refactor: make sqlglot.from_bf_dtype() a top-level function (#2144)
1 parent 8fc051f commit fa4e46f

File tree

4 files changed, with 73 additions and 78 deletions.

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,9 @@
1818

1919
from bigframes import dtypes
2020
from bigframes import operations as ops
21+
from bigframes.core.compile.sqlglot import sqlglot_types
2122
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2223
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
23-
from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType
2424

2525
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2626

@@ -29,7 +29,7 @@
2929
def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
3030
from_type = expr.dtype
3131
to_type = op.to_type
32-
sg_to_type = SQLGlotType.from_bigframes_dtype(to_type)
32+
sg_to_type = sqlglot_types.from_bigframes_dtype(to_type)
3333
sg_expr = expr.expr
3434

3535
if to_type == dtypes.JSON_DTYPE:

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def from_pyarrow(
7979
expressions=[
8080
sge.ColumnDef(
8181
this=sge.to_identifier(field.column, quoted=True),
82-
kind=sgt.SQLGlotType.from_bigframes_dtype(field.dtype),
82+
kind=sgt.from_bigframes_dtype(field.dtype),
8383
)
8484
for field in schema.items
8585
],
@@ -620,7 +620,7 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
620620

621621

622622
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
623-
sqlglot_type = sgt.SQLGlotType.from_bigframes_dtype(dtype)
623+
sqlglot_type = sgt.from_bigframes_dtype(dtype)
624624
if value is None:
625625
return _cast(sge.Null(), sqlglot_type)
626626
elif dtype == dtypes.BYTES_DTYPE:

bigframes/core/compile/sqlglot/sqlglot_types.py

Lines changed: 52 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -25,62 +25,57 @@
2525
import bigframes.dtypes
2626

2727

28-
class SQLGlotType:
29-
@classmethod
30-
def from_bigframes_dtype(
31-
cls,
32-
bigframes_dtype: typing.Union[
33-
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
34-
],
35-
) -> str:
36-
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
37-
return "INT64"
38-
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
39-
return "FLOAT64"
40-
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
41-
return "STRING"
42-
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
43-
return "BOOLEAN"
44-
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
45-
return "DATE"
46-
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
47-
return "TIME"
48-
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
49-
return "DATETIME"
50-
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
51-
return "TIMESTAMP"
52-
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
53-
return "BYTES"
54-
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
55-
return "NUMERIC"
56-
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
57-
return "BIGNUMERIC"
58-
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
59-
return "JSON"
60-
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
61-
return "GEOGRAPHY"
62-
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
63-
return "INT64"
64-
elif isinstance(bigframes_dtype, pd.ArrowDtype):
65-
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
66-
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
67-
bigframes_dtype.pyarrow_dtype.value_type
28+
def from_bigframes_dtype(
29+
bigframes_dtype: typing.Union[
30+
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
31+
],
32+
) -> str:
33+
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
34+
return "INT64"
35+
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
36+
return "FLOAT64"
37+
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
38+
return "STRING"
39+
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
40+
return "BOOLEAN"
41+
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
42+
return "DATE"
43+
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
44+
return "TIME"
45+
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
46+
return "DATETIME"
47+
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
48+
return "TIMESTAMP"
49+
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
50+
return "BYTES"
51+
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
52+
return "NUMERIC"
53+
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
54+
return "BIGNUMERIC"
55+
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
56+
return "JSON"
57+
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
58+
return "GEOGRAPHY"
59+
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
60+
return "INT64"
61+
elif isinstance(bigframes_dtype, pd.ArrowDtype):
62+
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
63+
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
64+
bigframes_dtype.pyarrow_dtype.value_type
65+
)
66+
return f"ARRAY<{from_bigframes_dtype(inner_bigframes_dtype)}>"
67+
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
68+
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
69+
inner_fields: list[str] = []
70+
for i in range(struct_type.num_fields):
71+
field = struct_type.field(i)
72+
key = sg.to_identifier(field.name).sql("bigquery")
73+
dtype = from_bigframes_dtype(
74+
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
6875
)
69-
return (
70-
f"ARRAY<{SQLGlotType.from_bigframes_dtype(inner_bigframes_dtype)}>"
71-
)
72-
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
73-
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
74-
inner_fields: list[str] = []
75-
for i in range(struct_type.num_fields):
76-
field = struct_type.field(i)
77-
key = sg.to_identifier(field.name).sql("bigquery")
78-
dtype = SQLGlotType.from_bigframes_dtype(
79-
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
80-
)
81-
inner_fields.append(f"{key} {dtype}")
82-
return "STRUCT<{}>".format(", ".join(inner_fields))
76+
inner_fields.append(f"{key} {dtype}")
77+
return "STRUCT<{}>".format(", ".join(inner_fields))
8378

84-
raise ValueError(
85-
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
86-
)
79+
raise ValueError(
80+
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
81+
)

tests/unit/core/compile/sqlglot/test_sqlglot_types.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -20,34 +20,34 @@
2020

2121

2222
def test_from_bigframes_simple_dtypes():
23-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.INT_DTYPE) == "INT64"
24-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.FLOAT_DTYPE) == "FLOAT64"
25-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.STRING_DTYPE) == "STRING"
26-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BOOL_DTYPE) == "BOOLEAN"
27-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.DATE_DTYPE) == "DATE"
28-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.TIME_DTYPE) == "TIME"
29-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.DATETIME_DTYPE) == "DATETIME"
30-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.TIMESTAMP_DTYPE) == "TIMESTAMP"
31-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BYTES_DTYPE) == "BYTES"
32-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.NUMERIC_DTYPE) == "NUMERIC"
33-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BIGNUMERIC_DTYPE) == "BIGNUMERIC"
34-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.JSON_DTYPE) == "JSON"
35-
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.GEO_DTYPE) == "GEOGRAPHY"
23+
assert sgt.from_bigframes_dtype(dtypes.INT_DTYPE) == "INT64"
24+
assert sgt.from_bigframes_dtype(dtypes.FLOAT_DTYPE) == "FLOAT64"
25+
assert sgt.from_bigframes_dtype(dtypes.STRING_DTYPE) == "STRING"
26+
assert sgt.from_bigframes_dtype(dtypes.BOOL_DTYPE) == "BOOLEAN"
27+
assert sgt.from_bigframes_dtype(dtypes.DATE_DTYPE) == "DATE"
28+
assert sgt.from_bigframes_dtype(dtypes.TIME_DTYPE) == "TIME"
29+
assert sgt.from_bigframes_dtype(dtypes.DATETIME_DTYPE) == "DATETIME"
30+
assert sgt.from_bigframes_dtype(dtypes.TIMESTAMP_DTYPE) == "TIMESTAMP"
31+
assert sgt.from_bigframes_dtype(dtypes.BYTES_DTYPE) == "BYTES"
32+
assert sgt.from_bigframes_dtype(dtypes.NUMERIC_DTYPE) == "NUMERIC"
33+
assert sgt.from_bigframes_dtype(dtypes.BIGNUMERIC_DTYPE) == "BIGNUMERIC"
34+
assert sgt.from_bigframes_dtype(dtypes.JSON_DTYPE) == "JSON"
35+
assert sgt.from_bigframes_dtype(dtypes.GEO_DTYPE) == "GEOGRAPHY"
3636

3737

3838
def test_from_bigframes_struct_dtypes():
3939
fields = [pa.field("int_col", pa.int64()), pa.field("bool_col", pa.bool_())]
4040
struct_type = pd.ArrowDtype(pa.struct(fields))
4141
expected = "STRUCT<int_col INT64, bool_col BOOLEAN>"
42-
assert sgt.SQLGlotType.from_bigframes_dtype(struct_type) == expected
42+
assert sgt.from_bigframes_dtype(struct_type) == expected
4343

4444

4545
def test_from_bigframes_array_dtypes():
4646
int_array_type = pd.ArrowDtype(pa.list_(pa.int64()))
47-
assert sgt.SQLGlotType.from_bigframes_dtype(int_array_type) == "ARRAY<INT64>"
47+
assert sgt.from_bigframes_dtype(int_array_type) == "ARRAY<INT64>"
4848

4949
string_array_type = pd.ArrowDtype(pa.list_(pa.string()))
50-
assert sgt.SQLGlotType.from_bigframes_dtype(string_array_type) == "ARRAY<STRING>"
50+
assert sgt.from_bigframes_dtype(string_array_type) == "ARRAY<STRING>"
5151

5252

5353
def test_from_bigframes_multi_nested_dtypes():
@@ -61,4 +61,4 @@ def test_from_bigframes_multi_nested_dtypes():
6161
expected = (
6262
"ARRAY<STRUCT<string_col STRING, date_col DATE, array_col ARRAY<DATETIME>>>"
6363
)
64-
assert sgt.SQLGlotType.from_bigframes_dtype(array_type) == expected
64+
assert sgt.from_bigframes_dtype(array_type) == expected

0 commit comments

Comments
 (0)