102 changes: 100 additions & 2 deletions bigframes/core/compile/sqlglot/expressions/generic_ops.py
@@ -16,17 +16,54 @@

import sqlglot.expressions as sge

from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType

register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op


@register_unary_op(ops.AsTypeOp, pass_op=True)
def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
-    # TODO: Support more types for casting, such as JSON, etc.
-    return sge.Cast(this=expr.expr, to=op.to_type)
    from_type = expr.dtype
    to_type = op.to_type
    sg_to_type = SQLGlotType.from_bigframes_dtype(to_type)
    sg_expr = expr.expr

    if to_type == dtypes.JSON_DTYPE:
        return _cast_to_json(expr, op)

    if from_type == dtypes.JSON_DTYPE:
        return _cast_from_json(expr, op)

    if to_type == dtypes.INT_DTYPE:
        result = _cast_to_int(expr, op)
        if result is not None:
            return result

    if to_type == dtypes.FLOAT_DTYPE and from_type == dtypes.BOOL_DTYPE:
        sg_expr = _cast(sg_expr, "INT64", op.safe)
        return _cast(sg_expr, sg_to_type, op.safe)

    if to_type == dtypes.BOOL_DTYPE:
        if from_type == dtypes.BOOL_DTYPE:
            return sg_expr
        else:
            return sge.NEQ(this=sg_expr, expression=sge.convert(0))

    if to_type == dtypes.STRING_DTYPE:
        sg_expr = _cast(sg_expr, sg_to_type, op.safe)
        if from_type == dtypes.BOOL_DTYPE:
            sg_expr = sge.func("INITCAP", sg_expr)
        return sg_expr

    if dtypes.is_time_like(to_type) and from_type == dtypes.INT_DTYPE:
        sg_expr = sge.func("TIMESTAMP_MICROS", sg_expr)
        return _cast(sg_expr, sg_to_type, op.safe)

    return _cast(sg_expr, sg_to_type, op.safe)


@register_unary_op(ops.hash_op)
@@ -53,3 +53,64 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
@register_unary_op(ops.notnull_op)
def _(expr: TypedExpr) -> sge.Expression:
    return sge.Not(this=sge.Is(this=expr.expr, expression=sge.Null()))


# Helper functions
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
    from_type = expr.dtype
    sg_expr = expr.expr

    if from_type == dtypes.STRING_DTYPE:
        func_name = "PARSE_JSON_IN_SAFE" if op.safe else "PARSE_JSON"
        return sge.func(func_name, sg_expr)
    if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
        sg_expr = sge.Cast(this=sg_expr, to="STRING")
        return sge.func("PARSE_JSON", sg_expr)
    raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}")


def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
    to_type = op.to_type
    sg_expr = expr.expr
    func_name = ""
    if to_type == dtypes.INT_DTYPE:
        func_name = "INT64"
    elif to_type == dtypes.FLOAT_DTYPE:
        func_name = "FLOAT64"
    elif to_type == dtypes.BOOL_DTYPE:
        func_name = "BOOL"
    elif to_type == dtypes.STRING_DTYPE:
        func_name = "STRING"
    if func_name:
        func_name = "SAFE." + func_name if op.safe else func_name
        return sge.func(func_name, sg_expr)
    raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")


def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None:
    from_type = expr.dtype
    sg_expr = expr.expr
    # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first.
    if from_type == dtypes.DATETIME_DTYPE:
        sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe)
        return sge.func("UNIX_MICROS", sg_expr)
    if from_type == dtypes.TIMESTAMP_DTYPE:
        return sge.func("UNIX_MICROS", sg_expr)
    if from_type == dtypes.TIME_DTYPE:
        return sge.func(
            "TIME_DIFF",
            _cast(sg_expr, "TIME", op.safe),
            sge.convert("00:00:00"),
            "MICROSECOND",
        )
    if from_type == dtypes.NUMERIC_DTYPE or from_type == dtypes.FLOAT_DTYPE:
        sg_expr = sge.func("TRUNC", sg_expr)
        return _cast(sg_expr, "INT64", op.safe)
    return None


def _cast(expr: sge.Expression, to: str, safe: bool):
    if safe:
        return sge.TryCast(this=expr, to=to)
    else:
        return sge.Cast(this=expr, to=to)
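Side note on the `_cast` helper: the safe flag just switches between sqlglot's `Cast` and `TryCast` nodes, and the BigQuery dialect spells the latter as `SAFE_CAST` (as the snapshots below show). A minimal standalone sketch, not part of the PR:

```python
import sqlglot.expressions as sge

col = sge.column("float64_col")

# safe=False -> plain CAST
print(sge.Cast(this=col, to=sge.DataType.build("INT64", dialect="bigquery")).sql(dialect="bigquery"))
# CAST(float64_col AS INT64)

# safe=True -> SAFE_CAST under the BigQuery dialect
print(sge.TryCast(this=col, to=sge.DataType.build("INT64", dialect="bigquery")).sql(dialect="bigquery"))
# SAFE_CAST(float64_col AS INT64)
```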
36 changes: 18 additions & 18 deletions tests/system/small/engines/test_generic_ops.py
@@ -52,7 +52,7 @@ def apply_op(
    return new_arr


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -63,7 +63,7 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine)
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, engine):
    vals = ["1", "100", "-3"]
    arr, _ = scalars_array_value.compute_values(
@@ -78,7 +78,7 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue,
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -89,7 +89,7 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_string_float(
    scalars_array_value: array_value.ArrayValue, engine
):
@@ -106,7 +106,7 @@ def test_engines_astype_string_float(
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE)
@@ -115,7 +115,7 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engine):
    # floats work slightly different with trailing zeroes rn
    arr = apply_op(
@@ -127,7 +127,7 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -138,7 +138,7 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_string_numeric(
    scalars_array_value: array_value.ArrayValue, engine
):
@@ -155,7 +155,7 @@ def test_engines_astype_string_numeric(
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -166,7 +166,7 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_string_date(
    scalars_array_value: array_value.ArrayValue, engine
):
@@ -183,7 +183,7 @@ def test_engines_astype_string_date(
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -194,7 +194,7 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_string_datetime(
    scalars_array_value: array_value.ArrayValue, engine
):
@@ -211,7 +211,7 @@ def test_engines_astype_string_datetime(
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -222,7 +222,7 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_string_timestamp(
    scalars_array_value: array_value.ArrayValue, engine
):
@@ -243,7 +243,7 @@ def test_engines_astype_string_timestamp(
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -254,7 +254,7 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine):
    exprs = [
        ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr(
@@ -275,7 +275,7 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine):
    exprs = [
        ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr(
@@ -298,7 +298,7 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng
    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine):
    arr = apply_op(
        scalars_array_value,
@@ -0,0 +1,18 @@
WITH `bfcte_0` AS (
  SELECT
    `bool_col` AS `bfcol_0`,
    `float64_col` AS `bfcol_1`
  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
  SELECT
    *,
    `bfcol_0` AS `bfcol_2`,
    `bfcol_1` <> 0 AS `bfcol_3`,
    `bfcol_1` <> 0 AS `bfcol_4`
  FROM `bfcte_0`
)
SELECT
  `bfcol_2` AS `bool_col`,
  `bfcol_3` AS `float64_col`,
  `bfcol_4` AS `float64_w_safe`
FROM `bfcte_1`
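Worth noting in this snapshot: `bfcol_3` and `bfcol_4` (the safe variant) compile to the same `<> 0` comparison, because the BOOL branch of the new `AsTypeOp` compiler builds an `sge.NEQ` node rather than a cast, so `op.safe` has no effect there. A small illustration, not taken from the PR:

```python
import sqlglot.expressions as sge

# Non-bool -> BOOL is compiled as a "<> 0" comparison instead of a CAST.
not_zero = sge.NEQ(this=sge.column("float64_col"), expression=sge.convert(0))
print(not_zero.sql(dialect="bigquery"))
# float64_col <> 0
```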
@@ -0,0 +1,17 @@
WITH `bfcte_0` AS (
  SELECT
    `bool_col` AS `bfcol_0`
  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
  SELECT
    *,
    CAST(CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_1`,
    CAST('1.34235e4' AS FLOAT64) AS `bfcol_2`,
    SAFE_CAST(SAFE_CAST(`bfcol_0` AS INT64) AS FLOAT64) AS `bfcol_3`
  FROM `bfcte_0`
)
SELECT
  `bfcol_1` AS `bool_col`,
  `bfcol_2` AS `str_const`,
  `bfcol_3` AS `bool_w_safe`
FROM `bfcte_1`
@@ -0,0 +1,21 @@
WITH `bfcte_0` AS (
  SELECT
    `json_col` AS `bfcol_0`
  FROM `bigframes-dev`.`sqlglot_test`.`json_types`
), `bfcte_1` AS (
  SELECT
    *,
    INT64(`bfcol_0`) AS `bfcol_1`,
    FLOAT64(`bfcol_0`) AS `bfcol_2`,
    BOOL(`bfcol_0`) AS `bfcol_3`,
    STRING(`bfcol_0`) AS `bfcol_4`,
    SAFE.INT64(`bfcol_0`) AS `bfcol_5`
  FROM `bfcte_0`
)
SELECT
  `bfcol_1` AS `int64_col`,
  `bfcol_2` AS `float64_col`,
  `bfcol_3` AS `bool_col`,
  `bfcol_4` AS `string_col`,
  `bfcol_5` AS `int64_w_safe`
FROM `bfcte_1`
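These accessors come straight from `_cast_from_json`, which prepends `SAFE.` to the extraction function when `op.safe` is set. A rough sketch of the rendering, assuming `sge.func` falls back to an anonymous function call for names it does not recognize:

```python
import sqlglot.expressions as sge

json_col = sge.column("json_col")
for safe in (False, True):
    # Mirrors the "SAFE." + func_name pattern used in _cast_from_json.
    func_name = "SAFE.INT64" if safe else "INT64"
    print(sge.func(func_name, json_col).sql(dialect="bigquery"))
# INT64(json_col)
# SAFE.INT64(json_col)
```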
@@ -0,0 +1,33 @@
WITH `bfcte_0` AS (
  SELECT
    `datetime_col` AS `bfcol_0`,
    `numeric_col` AS `bfcol_1`,
    `float64_col` AS `bfcol_2`,
    `time_col` AS `bfcol_3`,
    `timestamp_col` AS `bfcol_4`
  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
  SELECT
    *,
    UNIX_MICROS(CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_5`,
    UNIX_MICROS(SAFE_CAST(`bfcol_0` AS TIMESTAMP)) AS `bfcol_6`,
    TIME_DIFF(CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_7`,
    TIME_DIFF(SAFE_CAST(`bfcol_3` AS TIME), '00:00:00', MICROSECOND) AS `bfcol_8`,
    UNIX_MICROS(`bfcol_4`) AS `bfcol_9`,
    CAST(TRUNC(`bfcol_1`) AS INT64) AS `bfcol_10`,
    CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_11`,
    SAFE_CAST(TRUNC(`bfcol_2`) AS INT64) AS `bfcol_12`,
    CAST('100' AS INT64) AS `bfcol_13`
  FROM `bfcte_0`
)
SELECT
  `bfcol_5` AS `datetime_col`,
  `bfcol_6` AS `datetime_w_safe`,
  `bfcol_7` AS `time_col`,
  `bfcol_8` AS `time_w_safe`,
  `bfcol_9` AS `timestamp_col`,
  `bfcol_10` AS `numeric_col`,
  `bfcol_11` AS `float64_col`,
  `bfcol_12` AS `float64_w_safe`,
  `bfcol_13` AS `str_const`
FROM `bfcte_1`
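The DATETIME column in this snapshot goes through the `_cast_to_int` path: re-cast to TIMESTAMP first, then `UNIX_MICROS`, which is what yields `UNIX_MICROS(CAST(datetime_col AS TIMESTAMP))` above. A standalone sketch of that composition (illustrative only, not lifted from the PR):

```python
import sqlglot.expressions as sge

datetime_col = sge.column("datetime_col")

# DATETIME -> INT64: convert to TIMESTAMP, then microseconds since the epoch.
as_timestamp = sge.Cast(
    this=datetime_col, to=sge.DataType.build("TIMESTAMP", dialect="bigquery")
)
print(sge.func("UNIX_MICROS", as_timestamp).sql(dialect="bigquery"))
# UNIX_MICROS(CAST(datetime_col AS TIMESTAMP))
```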