Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bigframes/core/compile/sqlglot/expressions/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.sqlglot import sqlglot_types
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType

register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op

Expand All @@ -29,7 +29,7 @@
def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
from_type = expr.dtype
to_type = op.to_type
sg_to_type = SQLGlotType.from_bigframes_dtype(to_type)
sg_to_type = sqlglot_types.from_bigframes_dtype(to_type)
sg_expr = expr.expr

if to_type == dtypes.JSON_DTYPE:
Expand Down
4 changes: 2 additions & 2 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def from_pyarrow(
expressions=[
sge.ColumnDef(
this=sge.to_identifier(field.column, quoted=True),
kind=sgt.SQLGlotType.from_bigframes_dtype(field.dtype),
kind=sgt.from_bigframes_dtype(field.dtype),
)
for field in schema.items
],
Expand Down Expand Up @@ -620,7 +620,7 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:


def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
sqlglot_type = sgt.SQLGlotType.from_bigframes_dtype(dtype)
sqlglot_type = sgt.from_bigframes_dtype(dtype)
if value is None:
return _cast(sge.Null(), sqlglot_type)
elif dtype == dtypes.BYTES_DTYPE:
Expand Down
109 changes: 52 additions & 57 deletions bigframes/core/compile/sqlglot/sqlglot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,62 +25,57 @@
import bigframes.dtypes


class SQLGlotType:
@classmethod
def from_bigframes_dtype(
cls,
bigframes_dtype: typing.Union[
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
],
) -> str:
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
return "INT64"
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
return "FLOAT64"
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
return "STRING"
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
return "BOOLEAN"
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
return "DATE"
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
return "TIME"
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
return "DATETIME"
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
return "TIMESTAMP"
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
return "BYTES"
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
return "NUMERIC"
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
return "BIGNUMERIC"
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
return "JSON"
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
return "GEOGRAPHY"
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
return "INT64"
elif isinstance(bigframes_dtype, pd.ArrowDtype):
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
bigframes_dtype.pyarrow_dtype.value_type
def from_bigframes_dtype(
bigframes_dtype: typing.Union[
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
],
) -> str:
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
return "INT64"
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
return "FLOAT64"
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
return "STRING"
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
return "BOOLEAN"
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
return "DATE"
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
return "TIME"
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
return "DATETIME"
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
return "TIMESTAMP"
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
return "BYTES"
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
return "NUMERIC"
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
return "BIGNUMERIC"
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
return "JSON"
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
return "GEOGRAPHY"
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
return "INT64"
elif isinstance(bigframes_dtype, pd.ArrowDtype):
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
bigframes_dtype.pyarrow_dtype.value_type
)
return f"ARRAY<{from_bigframes_dtype(inner_bigframes_dtype)}>"
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
inner_fields: list[str] = []
for i in range(struct_type.num_fields):
field = struct_type.field(i)
key = sg.to_identifier(field.name).sql("bigquery")
dtype = from_bigframes_dtype(
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
)
return (
f"ARRAY<{SQLGlotType.from_bigframes_dtype(inner_bigframes_dtype)}>"
)
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
inner_fields: list[str] = []
for i in range(struct_type.num_fields):
field = struct_type.field(i)
key = sg.to_identifier(field.name).sql("bigquery")
dtype = SQLGlotType.from_bigframes_dtype(
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
)
inner_fields.append(f"{key} {dtype}")
return "STRUCT<{}>".format(", ".join(inner_fields))
inner_fields.append(f"{key} {dtype}")
return "STRUCT<{}>".format(", ".join(inner_fields))

raise ValueError(
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
)
raise ValueError(
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
)
34 changes: 17 additions & 17 deletions tests/unit/core/compile/sqlglot/test_sqlglot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,34 @@


def test_from_bigframes_simple_dtypes():
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.INT_DTYPE) == "INT64"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.FLOAT_DTYPE) == "FLOAT64"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.STRING_DTYPE) == "STRING"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BOOL_DTYPE) == "BOOLEAN"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.DATE_DTYPE) == "DATE"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.TIME_DTYPE) == "TIME"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.DATETIME_DTYPE) == "DATETIME"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.TIMESTAMP_DTYPE) == "TIMESTAMP"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BYTES_DTYPE) == "BYTES"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.NUMERIC_DTYPE) == "NUMERIC"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BIGNUMERIC_DTYPE) == "BIGNUMERIC"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.JSON_DTYPE) == "JSON"
assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.GEO_DTYPE) == "GEOGRAPHY"
assert sgt.from_bigframes_dtype(dtypes.INT_DTYPE) == "INT64"
assert sgt.from_bigframes_dtype(dtypes.FLOAT_DTYPE) == "FLOAT64"
assert sgt.from_bigframes_dtype(dtypes.STRING_DTYPE) == "STRING"
assert sgt.from_bigframes_dtype(dtypes.BOOL_DTYPE) == "BOOLEAN"
assert sgt.from_bigframes_dtype(dtypes.DATE_DTYPE) == "DATE"
assert sgt.from_bigframes_dtype(dtypes.TIME_DTYPE) == "TIME"
assert sgt.from_bigframes_dtype(dtypes.DATETIME_DTYPE) == "DATETIME"
assert sgt.from_bigframes_dtype(dtypes.TIMESTAMP_DTYPE) == "TIMESTAMP"
assert sgt.from_bigframes_dtype(dtypes.BYTES_DTYPE) == "BYTES"
assert sgt.from_bigframes_dtype(dtypes.NUMERIC_DTYPE) == "NUMERIC"
assert sgt.from_bigframes_dtype(dtypes.BIGNUMERIC_DTYPE) == "BIGNUMERIC"
assert sgt.from_bigframes_dtype(dtypes.JSON_DTYPE) == "JSON"
assert sgt.from_bigframes_dtype(dtypes.GEO_DTYPE) == "GEOGRAPHY"


def test_from_bigframes_struct_dtypes():
fields = [pa.field("int_col", pa.int64()), pa.field("bool_col", pa.bool_())]
struct_type = pd.ArrowDtype(pa.struct(fields))
expected = "STRUCT<int_col INT64, bool_col BOOLEAN>"
assert sgt.SQLGlotType.from_bigframes_dtype(struct_type) == expected
assert sgt.from_bigframes_dtype(struct_type) == expected


def test_from_bigframes_array_dtypes():
int_array_type = pd.ArrowDtype(pa.list_(pa.int64()))
assert sgt.SQLGlotType.from_bigframes_dtype(int_array_type) == "ARRAY<INT64>"
assert sgt.from_bigframes_dtype(int_array_type) == "ARRAY<INT64>"

string_array_type = pd.ArrowDtype(pa.list_(pa.string()))
assert sgt.SQLGlotType.from_bigframes_dtype(string_array_type) == "ARRAY<STRING>"
assert sgt.from_bigframes_dtype(string_array_type) == "ARRAY<STRING>"


def test_from_bigframes_multi_nested_dtypes():
Expand All @@ -61,4 +61,4 @@ def test_from_bigframes_multi_nested_dtypes():
expected = (
"ARRAY<STRUCT<string_col STRING, date_col DATE, array_col ARRAY<DATETIME>>>"
)
assert sgt.SQLGlotType.from_bigframes_dtype(array_type) == expected
assert sgt.from_bigframes_dtype(array_type) == expected