From 2bb5fddb9afa76e38b966e66f873119a9e59f7db Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 22 Sep 2025 21:19:42 +0000 Subject: [PATCH 1/2] refactor: enable "astype" engine tests for the sqlglot compiler --- .../sqlglot/expressions/generic_ops.py | 80 +++++++++- .../system/small/engines/test_generic_ops.py | 36 ++--- .../test_generic_ops/test_astype_bool/out.sql | 18 +++ .../test_astype_float/out.sql | 17 ++ .../test_astype_from_json/out.sql | 21 +++ .../test_generic_ops/test_astype_int/out.sql | 33 ++++ .../test_generic_ops/test_astype_json/out.sql | 26 ++++ .../test_astype_string/out.sql | 18 +++ .../test_astype_time_like/out.sql | 19 +++ .../sqlglot/expressions/test_generic_ops.py | 147 ++++++++++++++++++ 10 files changed, 395 insertions(+), 20 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_bool/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_from_json/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_int/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_json/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_time_like/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 5ee4ede94a..9337e9b333 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -16,17 +16,85 @@ import sqlglot.expressions as sge +from bigframes import dtypes from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler +from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op @register_unary_op(ops.AsTypeOp, pass_op=True) def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: - # TODO: Support more types for casting, such as JSON, etc. - return sge.Cast(this=expr.expr, to=op.to_type) + from_type = expr.dtype + to_type = op.to_type + sg_to_type = SQLGlotType.from_bigframes_dtype(to_type) + sg_expr = expr.expr + + if to_type == dtypes.JSON_DTYPE: + if from_type == dtypes.STRING_DTYPE: + func_name = "PARSE_JSON_IN_SAFE" if op.safe else "PARSE_JSON" + return sge.func(func_name, sg_expr) + if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE): + sg_expr = sge.Cast(this=sg_expr, to="STRING") + return sge.func("PARSE_JSON", sg_expr) + raise TypeError(f"Cannot cast from {from_type} to {to_type}") + + if from_type == dtypes.JSON_DTYPE: + func_name = "" + if to_type == dtypes.INT_DTYPE: + func_name = "INT64" + elif to_type == dtypes.FLOAT_DTYPE: + func_name = "FLOAT64" + elif to_type == dtypes.BOOL_DTYPE: + func_name = "BOOL" + elif to_type == dtypes.STRING_DTYPE: + func_name = "STRING" + if func_name: + func_name = "SAFE." + func_name if op.safe else func_name + return sge.func(func_name, sg_expr) + raise TypeError(f"Cannot cast from {from_type} to {to_type}") + + if to_type == dtypes.INT_DTYPE: + # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first. + if from_type == dtypes.DATETIME_DTYPE: + sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe) + return sge.func("UNIX_MICROS", sg_expr) + if from_type == dtypes.TIMESTAMP_DTYPE: + return sge.func("UNIX_MICROS", sg_expr) + if from_type == dtypes.TIME_DTYPE: + return sge.func( + "TIME_DIFF", + _cast(sg_expr, "TIME", op.safe), + sge.convert("00:00:00"), + "MICROSECOND", + ) + if from_type == dtypes.NUMERIC_DTYPE or from_type == dtypes.FLOAT_DTYPE: + sg_expr = sge.func("TRUNC", sg_expr) + return _cast(sg_expr, sg_to_type, op.safe) + + if to_type == dtypes.FLOAT_DTYPE and from_type == dtypes.BOOL_DTYPE: + sg_expr = _cast(sg_expr, "INT64", op.safe) + return _cast(sg_expr, sg_to_type, op.safe) + + if to_type == dtypes.BOOL_DTYPE: + if from_type == dtypes.BOOL_DTYPE: + return sg_expr + else: + return sge.NEQ(this=sg_expr, expression=sge.convert(0)) + + if to_type == dtypes.STRING_DTYPE: + sg_expr = _cast(sg_expr, sg_to_type, op.safe) + if from_type == dtypes.BOOL_DTYPE: + sg_expr = sge.func("INITCAP", sg_expr) + return sg_expr + + if dtypes.is_time_like(to_type) and from_type == dtypes.INT_DTYPE: + sg_expr = sge.func("TIMESTAMP_MICROS", sg_expr) + return _cast(sg_expr, sg_to_type, op.safe) + + return _cast(sg_expr, sg_to_type, op.safe) @register_unary_op(ops.hash_op) @@ -53,3 +121,11 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression: @register_unary_op(ops.notnull_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Not(this=sge.Is(this=expr.expr, expression=sge.Null())) + + +# Helper functions +def _cast(expr: sge.Expression, to: str, safe: bool): + if safe: + return sge.TryCast(this=expr, to=to) + else: + return sge.Cast(this=expr, to=to) diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index fc40b7e59d..fc491d358b 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -52,7 +52,7 @@ def apply_op( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -63,7 +63,7 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, engine): vals = ["1", "100", "-3"] arr, _ = scalars_array_value.compute_values( @@ -78,7 +78,7 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -89,7 +89,7 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_string_float( scalars_array_value: array_value.ArrayValue, engine ): @@ -106,7 +106,7 @@ def test_engines_astype_string_float( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE) @@ -115,7 +115,7 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engine): # floats work slightly different with trailing zeroes rn arr = apply_op( @@ -127,7 +127,7 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -138,7 +138,7 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_string_numeric( scalars_array_value: array_value.ArrayValue, engine ): @@ -155,7 +155,7 @@ def test_engines_astype_string_numeric( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -166,7 +166,7 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_string_date( scalars_array_value: array_value.ArrayValue, engine ): @@ -183,7 +183,7 @@ def test_engines_astype_string_date( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -194,7 +194,7 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_string_datetime( scalars_array_value: array_value.ArrayValue, engine ): @@ -211,7 +211,7 @@ def test_engines_astype_string_datetime( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -222,7 +222,7 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_string_timestamp( scalars_array_value: array_value.ArrayValue, engine ): @@ -243,7 +243,7 @@ def test_engines_astype_string_timestamp( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -254,7 +254,7 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr( @@ -275,7 +275,7 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr( @@ -298,7 +298,7 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_bool/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_bool/out.sql new file mode 100644 index 0000000000..440aea9161 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_bool/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `float64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_0` AS `bfcol_2`, + `bfcol_1` <> 0 AS `bfcol_3`, + `bfcol_1` <> 0 AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `bool_col`, + `bfcol_3` AS `float64_col`, + `bfcol_4` AS `float64_w_safe` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql new file mode 100644 index 0000000000..81a8805f47 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_float/out.sql @@ -0,0 +1,17 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_1`, + CAST('1.34235e4' AS FLOAT64) AS `bfcol_2`, + SAFE_CAST(SAFE_CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `bool_col`, + `bfcol_2` AS `str_const`, + `bfcol_3` AS `bool_w_safe` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_from_json/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_from_json/out.sql new file mode 100644 index 0000000000..25d51b26b3 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_from_json/out.sql @@ -0,0 +1,21 @@ +WITH `bfcte_0` AS ( + SELECT + `json_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`json_types` +), `bfcte_1` AS ( + SELECT + *, + INT64(`bfcol_0`) AS `bfcol_1`, + FLOAT64(`bfcol_0`) AS `bfcol_2`, + BOOL(`bfcol_0`) AS `bfcol_3`, + STRING(`bfcol_0`) AS `bfcol_4`, + SAFE.INT64(`bfcol_0`) AS `bfcol_5` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col`, + `bfcol_2` AS `float64_col`, + `bfcol_3` AS `bool_col`, + `bfcol_4` AS `string_col`, + `bfcol_5` AS `int64_w_safe` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_int/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_int/out.sql new file mode 100644 index 0000000000..22aa2cf91a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_int/out.sql @@ -0,0 +1,33 @@ +WITH `bfcte_0` AS ( + SELECT + `datetime_col` AS `bfcol_0`, + `numeric_col` AS `bfcol_1`, + `float64_col` AS `bfcol_2`, + `time_col` AS `bfcol_3`, + `timestamp_col` AS `bfcol_4` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + UNIX_MICROS(CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_5`, + UNIX_MICROS(SAFE_CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_6`, + TIME_DIFF(CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_7`, + TIME_DIFF(SAFE_CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_8`, + UNIX_MICROS(`bfcol_4`) AS `bfcol_9`, + CAST(TRUNC(`bfcol_1`) AS INT64) AS `bfcol_10`, + CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_11`, + SAFE_CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_12`, + CAST('100' AS INT64) AS `bfcol_13` + FROM `bfcte_0` +) +SELECT + `bfcol_5` AS `datetime_col`, + `bfcol_6` AS `datetime_w_safe`, + `bfcol_7` AS `time_col`, + `bfcol_8` AS `time_w_safe`, + `bfcol_9` AS `timestamp_col`, + `bfcol_10` AS `numeric_col`, + `bfcol_11` AS `float64_col`, + `bfcol_12` AS `float64_w_safe`, + `bfcol_13` AS `str_const` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_json/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_json/out.sql new file mode 100644 index 0000000000..8230b4a60b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_json/out.sql @@ -0,0 +1,26 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `float64_col` AS `bfcol_2`, + `string_col` AS `bfcol_3` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + PARSE_JSON(CAST(`bfcol_1` AS STRING)) AS `bfcol_4`, + PARSE_JSON(CAST(`bfcol_2` AS STRING)) AS `bfcol_5`, + PARSE_JSON(CAST(`bfcol_0` AS STRING)) AS `bfcol_6`, + PARSE_JSON(`bfcol_3`) AS `bfcol_7`, + PARSE_JSON(CAST(`bfcol_0` AS STRING)) AS `bfcol_8`, + PARSE_JSON_IN_SAFE(`bfcol_3`) AS `bfcol_9` + FROM `bfcte_0` +) +SELECT + `bfcol_4` AS `int64_col`, + `bfcol_5` AS `float64_col`, + `bfcol_6` AS `bool_col`, + `bfcol_7` AS `string_col`, + `bfcol_8` AS `bool_w_safe`, + `bfcol_9` AS `string_w_safe` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql new file mode 100644 index 0000000000..f230a3799e --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_string/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(`bfcol_1` AS STRING) AS `bfcol_2`, + INITCAP(CAST(`bfcol_0` AS STRING)) AS `bfcol_3`, + INITCAP(SAFE_CAST(`bfcol_0` AS STRING)) AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `int64_col`, + `bfcol_3` AS `bool_col`, + `bfcol_4` AS `bool_w_safe` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_time_like/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_time_like/out.sql new file mode 100644 index 0000000000..141b7ffa9a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_astype_time_like/out.sql @@ -0,0 +1,19 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(TIMESTAMP_MICROS(`bfcol_0`) AS DATETIME) AS `bfcol_1`, + CAST(TIMESTAMP_MICROS(`bfcol_0`) AS TIME) AS `bfcol_2`, + CAST(TIMESTAMP_MICROS(`bfcol_0`) AS TIMESTAMP) AS `bfcol_3`, + SAFE_CAST(TIMESTAMP_MICROS(`bfcol_0`) AS TIME) AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_to_datetime`, + `bfcol_2` AS `int64_to_time`, + `bfcol_3` AS `int64_to_timestamp`, + `bfcol_4` AS `int64_to_time_safe` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 130d34a2fa..d9ae6ab539 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -14,13 +14,160 @@ import pytest +from bigframes import dtypes from bigframes import operations as ops +from bigframes.core import expression as ex import bigframes.pandas as bpd from bigframes.testing import utils pytest.importorskip("pytest_snapshot") +def test_astype_int(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + to_type = dtypes.INT_DTYPE + + ops_map = { + "datetime_col": ops.AsTypeOp(to_type=to_type).as_expr("datetime_col"), + "datetime_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr( + "datetime_col" + ), + "time_col": ops.AsTypeOp(to_type=to_type).as_expr("time_col"), + "time_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr("time_col"), + "timestamp_col": ops.AsTypeOp(to_type=to_type).as_expr("timestamp_col"), + "numeric_col": ops.AsTypeOp(to_type=to_type).as_expr("numeric_col"), + "float64_col": ops.AsTypeOp(to_type=to_type).as_expr("float64_col"), + "float64_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr( + "float64_col" + ), + "str_const": ops.AsTypeOp(to_type=to_type).as_expr(ex.const("100")), + } + + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_astype_float(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + to_type = dtypes.FLOAT_DTYPE + + ops_map = { + "bool_col": ops.AsTypeOp(to_type=to_type).as_expr("bool_col"), + "str_const": ops.AsTypeOp(to_type=to_type).as_expr(ex.const("1.34235e4")), + "bool_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr("bool_col"), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_astype_bool(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + to_type = dtypes.BOOL_DTYPE + + ops_map = { + "bool_col": ops.AsTypeOp(to_type=to_type).as_expr("bool_col"), + "float64_col": ops.AsTypeOp(to_type=to_type).as_expr("float64_col"), + "float64_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr( + "float64_col" + ), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_astype_time_like(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + + ops_map = { + "int64_to_datetime": ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr( + "int64_col" + ), + "int64_to_time": ops.AsTypeOp(to_type=dtypes.TIME_DTYPE).as_expr("int64_col"), + "int64_to_timestamp": ops.AsTypeOp(to_type=dtypes.TIMESTAMP_DTYPE).as_expr( + "int64_col" + ), + "int64_to_time_safe": ops.AsTypeOp( + to_type=dtypes.TIME_DTYPE, safe=True + ).as_expr("int64_col"), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_astype_string(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + to_type = dtypes.STRING_DTYPE + + ops_map = { + "int64_col": ops.AsTypeOp(to_type=to_type).as_expr("int64_col"), + "bool_col": ops.AsTypeOp(to_type=to_type).as_expr("bool_col"), + "bool_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr("bool_col"), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_astype_json(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + + ops_map = { + "int64_col": ops.AsTypeOp(to_type=dtypes.JSON_DTYPE).as_expr("int64_col"), + "float64_col": ops.AsTypeOp(to_type=dtypes.JSON_DTYPE).as_expr("float64_col"), + "bool_col": ops.AsTypeOp(to_type=dtypes.JSON_DTYPE).as_expr("bool_col"), + "string_col": ops.AsTypeOp(to_type=dtypes.JSON_DTYPE).as_expr("string_col"), + "bool_w_safe": ops.AsTypeOp(to_type=dtypes.JSON_DTYPE, safe=True).as_expr( + "bool_col" + ), + "string_w_safe": ops.AsTypeOp(to_type=dtypes.JSON_DTYPE, safe=True).as_expr( + "string_col" + ), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_astype_from_json(json_types_df: bpd.DataFrame, snapshot): + bf_df = json_types_df + + ops_map = { + "int64_col": ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr("json_col"), + "float64_col": ops.AsTypeOp(to_type=dtypes.FLOAT_DTYPE).as_expr("json_col"), + "bool_col": ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr("json_col"), + "string_col": ops.AsTypeOp(to_type=dtypes.STRING_DTYPE).as_expr("json_col"), + "int64_w_safe": ops.AsTypeOp(to_type=dtypes.INT_DTYPE, safe=True).as_expr( + "json_col" + ), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_astype_json_invalid( + scalar_types_df: bpd.DataFrame, json_types_df: bpd.DataFrame +): + # Test invalid cast to JSON + with pytest.raises(TypeError, match="Cannot cast timestamp.* to .*json.*"): + ops_map_to = { + "datetime_to_json": ops.AsTypeOp(to_type=dtypes.JSON_DTYPE).as_expr( + "datetime_col" + ), + } + utils._apply_unary_ops( + scalar_types_df, list(ops_map_to.values()), list(ops_map_to.keys()) + ) + + # Test invalid cast from JSON + with pytest.raises(TypeError, match="Cannot cast .*json.* to timestamp.*"): + ops_map_from = { + "json_to_datetime": ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr( + "json_col" + ), + } + utils._apply_unary_ops( + json_types_df, list(ops_map_from.values()), list(ops_map_from.keys()) + ) + + def test_hash(scalar_types_df: bpd.DataFrame, snapshot): col_name = "string_col" bf_df = scalar_types_df[[col_name]] From 26dddb2fc1ccf29ffc07aa6e4e8f4e68e8ff10f0 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 23 Sep 2025 23:06:21 +0000 Subject: [PATCH 2/2] address comments --- .../sqlglot/expressions/generic_ops.py | 94 ++++++++++++------- 1 file changed, 58 insertions(+), 36 deletions(-) diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 9337e9b333..8a792c0753 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -33,46 +33,15 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: sg_expr = expr.expr if to_type == dtypes.JSON_DTYPE: - if from_type == dtypes.STRING_DTYPE: - func_name = "PARSE_JSON_IN_SAFE" if op.safe else "PARSE_JSON" - return sge.func(func_name, sg_expr) - if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE): - sg_expr = sge.Cast(this=sg_expr, to="STRING") - return sge.func("PARSE_JSON", sg_expr) - raise TypeError(f"Cannot cast from {from_type} to {to_type}") + return _cast_to_json(expr, op) if from_type == dtypes.JSON_DTYPE: - func_name = "" - if to_type == dtypes.INT_DTYPE: - func_name = "INT64" - elif to_type == dtypes.FLOAT_DTYPE: - func_name = "FLOAT64" - elif to_type == dtypes.BOOL_DTYPE: - func_name = "BOOL" - elif to_type == dtypes.STRING_DTYPE: - func_name = "STRING" - if func_name: - func_name = "SAFE." + func_name if op.safe else func_name - return sge.func(func_name, sg_expr) - raise TypeError(f"Cannot cast from {from_type} to {to_type}") + return _cast_from_json(expr, op) if to_type == dtypes.INT_DTYPE: - # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first. - if from_type == dtypes.DATETIME_DTYPE: - sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe) - return sge.func("UNIX_MICROS", sg_expr) - if from_type == dtypes.TIMESTAMP_DTYPE: - return sge.func("UNIX_MICROS", sg_expr) - if from_type == dtypes.TIME_DTYPE: - return sge.func( - "TIME_DIFF", - _cast(sg_expr, "TIME", op.safe), - sge.convert("00:00:00"), - "MICROSECOND", - ) - if from_type == dtypes.NUMERIC_DTYPE or from_type == dtypes.FLOAT_DTYPE: - sg_expr = sge.func("TRUNC", sg_expr) - return _cast(sg_expr, sg_to_type, op.safe) + result = _cast_to_int(expr, op) + if result is not None: + return result if to_type == dtypes.FLOAT_DTYPE and from_type == dtypes.BOOL_DTYPE: sg_expr = _cast(sg_expr, "INT64", op.safe) @@ -124,6 +93,59 @@ def _(expr: TypedExpr) -> sge.Expression: # Helper functions +def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: + from_type = expr.dtype + sg_expr = expr.expr + + if from_type == dtypes.STRING_DTYPE: + func_name = "PARSE_JSON_IN_SAFE" if op.safe else "PARSE_JSON" + return sge.func(func_name, sg_expr) + if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE): + sg_expr = sge.Cast(this=sg_expr, to="STRING") + return sge.func("PARSE_JSON", sg_expr) + raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}") + + +def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: + to_type = op.to_type + sg_expr = expr.expr + func_name = "" + if to_type == dtypes.INT_DTYPE: + func_name = "INT64" + elif to_type == dtypes.FLOAT_DTYPE: + func_name = "FLOAT64" + elif to_type == dtypes.BOOL_DTYPE: + func_name = "BOOL" + elif to_type == dtypes.STRING_DTYPE: + func_name = "STRING" + if func_name: + func_name = "SAFE." + func_name if op.safe else func_name + return sge.func(func_name, sg_expr) + raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}") + + +def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None: + from_type = expr.dtype + sg_expr = expr.expr + # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first. + if from_type == dtypes.DATETIME_DTYPE: + sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe) + return sge.func("UNIX_MICROS", sg_expr) + if from_type == dtypes.TIMESTAMP_DTYPE: + return sge.func("UNIX_MICROS", sg_expr) + if from_type == dtypes.TIME_DTYPE: + return sge.func( + "TIME_DIFF", + _cast(sg_expr, "TIME", op.safe), + sge.convert("00:00:00"), + "MICROSECOND", + ) + if from_type == dtypes.NUMERIC_DTYPE or from_type == dtypes.FLOAT_DTYPE: + sg_expr = sge.func("TRUNC", sg_expr) + return _cast(sg_expr, "INT64", op.safe) + return None + + def _cast(expr: sge.Expression, to: str, safe: bool): if safe: return sge.TryCast(this=expr, to=to)