From cc2da36c02ed9a28e9ac99925c52e0c9313339a2 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 7 Oct 2025 18:15:45 +0000 Subject: [PATCH] refactor: make sqlglot.from_bf_dtype() a top-level function --- .../sqlglot/expressions/generic_ops.py | 4 +- bigframes/core/compile/sqlglot/sqlglot_ir.py | 4 +- .../core/compile/sqlglot/sqlglot_types.py | 109 +++++++++--------- .../compile/sqlglot/test_sqlglot_types.py | 34 +++--- 4 files changed, 73 insertions(+), 78 deletions(-) diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 6a3825309c..af3b57f77b 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -18,9 +18,9 @@ from bigframes import dtypes from bigframes import operations as ops +from bigframes.core.compile.sqlglot import sqlglot_types from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler -from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op @@ -29,7 +29,7 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: from_type = expr.dtype to_type = op.to_type - sg_to_type = SQLGlotType.from_bigframes_dtype(to_type) + sg_to_type = sqlglot_types.from_bigframes_dtype(to_type) sg_expr = expr.expr if to_type == dtypes.JSON_DTYPE: diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 98dbed4cdd..c7ee13f4e8 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -79,7 +79,7 @@ def from_pyarrow( expressions=[ sge.ColumnDef( this=sge.to_identifier(field.column, quoted=True), - kind=sgt.SQLGlotType.from_bigframes_dtype(field.dtype), + kind=sgt.from_bigframes_dtype(field.dtype), ) for field in schema.items ], @@ -620,7 +620,7 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: - sqlglot_type = sgt.SQLGlotType.from_bigframes_dtype(dtype) + sqlglot_type = sgt.from_bigframes_dtype(dtype) if value is None: return _cast(sge.Null(), sqlglot_type) elif dtype == dtypes.BYTES_DTYPE: diff --git a/bigframes/core/compile/sqlglot/sqlglot_types.py b/bigframes/core/compile/sqlglot/sqlglot_types.py index 5b0f70077d..64e4363ddf 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_types.py +++ b/bigframes/core/compile/sqlglot/sqlglot_types.py @@ -25,62 +25,57 @@ import bigframes.dtypes -class SQLGlotType: - @classmethod - def from_bigframes_dtype( - cls, - bigframes_dtype: typing.Union[ - bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any] - ], - ) -> str: - if bigframes_dtype == bigframes.dtypes.INT_DTYPE: - return "INT64" - elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE: - return "FLOAT64" - elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE: - return "STRING" - elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE: - return "BOOLEAN" - elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE: - return "DATE" - elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE: - return "TIME" - elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE: - return "DATETIME" - elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE: - return "TIMESTAMP" - elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE: - return "BYTES" - elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE: - return "NUMERIC" - elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE: - return "BIGNUMERIC" - elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE: - return "JSON" - elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE: - return "GEOGRAPHY" - elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE: - return "INT64" - elif isinstance(bigframes_dtype, pd.ArrowDtype): - if pa.types.is_list(bigframes_dtype.pyarrow_dtype): - inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype( - bigframes_dtype.pyarrow_dtype.value_type +def from_bigframes_dtype( + bigframes_dtype: typing.Union[ + bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any] + ], +) -> str: + if bigframes_dtype == bigframes.dtypes.INT_DTYPE: + return "INT64" + elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE: + return "FLOAT64" + elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE: + return "STRING" + elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE: + return "BOOLEAN" + elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE: + return "DATE" + elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE: + return "TIME" + elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE: + return "DATETIME" + elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE: + return "TIMESTAMP" + elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE: + return "BYTES" + elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE: + return "NUMERIC" + elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE: + return "BIGNUMERIC" + elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE: + return "JSON" + elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE: + return "GEOGRAPHY" + elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE: + return "INT64" + elif isinstance(bigframes_dtype, pd.ArrowDtype): + if pa.types.is_list(bigframes_dtype.pyarrow_dtype): + inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype( + bigframes_dtype.pyarrow_dtype.value_type + ) + return f"ARRAY<{from_bigframes_dtype(inner_bigframes_dtype)}>" + elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype): + struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype) + inner_fields: list[str] = [] + for i in range(struct_type.num_fields): + field = struct_type.field(i) + key = sg.to_identifier(field.name).sql("bigquery") + dtype = from_bigframes_dtype( + bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type) ) - return ( - f"ARRAY<{SQLGlotType.from_bigframes_dtype(inner_bigframes_dtype)}>" - ) - elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype): - struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype) - inner_fields: list[str] = [] - for i in range(struct_type.num_fields): - field = struct_type.field(i) - key = sg.to_identifier(field.name).sql("bigquery") - dtype = SQLGlotType.from_bigframes_dtype( - bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type) - ) - inner_fields.append(f"{key} {dtype}") - return "STRUCT<{}>".format(", ".join(inner_fields)) + inner_fields.append(f"{key} {dtype}") + return "STRUCT<{}>".format(", ".join(inner_fields)) - raise ValueError( - f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}" - ) + raise ValueError( + f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}" + ) diff --git a/tests/unit/core/compile/sqlglot/test_sqlglot_types.py b/tests/unit/core/compile/sqlglot/test_sqlglot_types.py index a9108e5daf..5c2d84383d 100644 --- a/tests/unit/core/compile/sqlglot/test_sqlglot_types.py +++ b/tests/unit/core/compile/sqlglot/test_sqlglot_types.py @@ -20,34 +20,34 @@ def test_from_bigframes_simple_dtypes(): - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.INT_DTYPE) == "INT64" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.FLOAT_DTYPE) == "FLOAT64" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.STRING_DTYPE) == "STRING" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BOOL_DTYPE) == "BOOLEAN" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.DATE_DTYPE) == "DATE" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.TIME_DTYPE) == "TIME" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.DATETIME_DTYPE) == "DATETIME" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.TIMESTAMP_DTYPE) == "TIMESTAMP" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BYTES_DTYPE) == "BYTES" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.NUMERIC_DTYPE) == "NUMERIC" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.BIGNUMERIC_DTYPE) == "BIGNUMERIC" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.JSON_DTYPE) == "JSON" - assert sgt.SQLGlotType.from_bigframes_dtype(dtypes.GEO_DTYPE) == "GEOGRAPHY" + assert sgt.from_bigframes_dtype(dtypes.INT_DTYPE) == "INT64" + assert sgt.from_bigframes_dtype(dtypes.FLOAT_DTYPE) == "FLOAT64" + assert sgt.from_bigframes_dtype(dtypes.STRING_DTYPE) == "STRING" + assert sgt.from_bigframes_dtype(dtypes.BOOL_DTYPE) == "BOOLEAN" + assert sgt.from_bigframes_dtype(dtypes.DATE_DTYPE) == "DATE" + assert sgt.from_bigframes_dtype(dtypes.TIME_DTYPE) == "TIME" + assert sgt.from_bigframes_dtype(dtypes.DATETIME_DTYPE) == "DATETIME" + assert sgt.from_bigframes_dtype(dtypes.TIMESTAMP_DTYPE) == "TIMESTAMP" + assert sgt.from_bigframes_dtype(dtypes.BYTES_DTYPE) == "BYTES" + assert sgt.from_bigframes_dtype(dtypes.NUMERIC_DTYPE) == "NUMERIC" + assert sgt.from_bigframes_dtype(dtypes.BIGNUMERIC_DTYPE) == "BIGNUMERIC" + assert sgt.from_bigframes_dtype(dtypes.JSON_DTYPE) == "JSON" + assert sgt.from_bigframes_dtype(dtypes.GEO_DTYPE) == "GEOGRAPHY" def test_from_bigframes_struct_dtypes(): fields = [pa.field("int_col", pa.int64()), pa.field("bool_col", pa.bool_())] struct_type = pd.ArrowDtype(pa.struct(fields)) expected = "STRUCT" - assert sgt.SQLGlotType.from_bigframes_dtype(struct_type) == expected + assert sgt.from_bigframes_dtype(struct_type) == expected def test_from_bigframes_array_dtypes(): int_array_type = pd.ArrowDtype(pa.list_(pa.int64())) - assert sgt.SQLGlotType.from_bigframes_dtype(int_array_type) == "ARRAY" + assert sgt.from_bigframes_dtype(int_array_type) == "ARRAY" string_array_type = pd.ArrowDtype(pa.list_(pa.string())) - assert sgt.SQLGlotType.from_bigframes_dtype(string_array_type) == "ARRAY" + assert sgt.from_bigframes_dtype(string_array_type) == "ARRAY" def test_from_bigframes_multi_nested_dtypes(): @@ -61,4 +61,4 @@ def test_from_bigframes_multi_nested_dtypes(): expected = ( "ARRAY>>" ) - assert sgt.SQLGlotType.from_bigframes_dtype(array_type) == expected + assert sgt.from_bigframes_dtype(array_type) == expected