diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index af3b57f77b..1ed49b89eb 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -23,6 +23,7 @@ import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_ternary_op = scalar_compiler.scalar_op_compiler.register_ternary_op @register_unary_op(ops.AsTypeOp, pass_op=True) @@ -66,6 +67,18 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: return _cast(sg_expr, sg_to_type, op.safe) +@register_ternary_op(ops.clip_op) +def _( + original: TypedExpr, + lower: TypedExpr, + upper: TypedExpr, +) -> sge.Expression: + return sge.Greatest( + this=sge.Least(this=original.expr, expressions=[upper.expr]), + expressions=[lower.expr], + ) + + @register_unary_op(ops.hash_op) def _(expr: TypedExpr) -> sge.Expression: return sge.func("FARM_FINGERPRINT", expr.expr) @@ -94,6 +107,13 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.Not(this=sge.Is(this=expr.expr, expression=sge.Null())) +@register_ternary_op(ops.where_op) +def _( + original: TypedExpr, condition: TypedExpr, replacement: TypedExpr +) -> sge.Expression: + return sge.If(this=condition.expr, true=original.expr, false=replacement.expr) + + # Helper functions def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: from_type = expr.dtype diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index f252782dbd..ae7eafd347 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -314,7 +314,7 @@ def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, e assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_where_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_clip/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_clip/out.sql new file mode 100644 index 0000000000..172e1f53e7 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_clip/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `int64_too` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + GREATEST(LEAST(`bfcol_2`, `bfcol_1`), `bfcol_0`) AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_3` AS `result_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_where/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_where/out.sql new file mode 100644 index 0000000000..678208e9ba --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_where/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `float64_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + IF(`bfcol_0`, `bfcol_1`, `bfcol_2`) AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_3` AS `result_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index d9ae6ab539..261a630d3a 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -168,6 +168,22 @@ def test_astype_json_invalid( ) +def test_clip(scalar_types_df: bpd.DataFrame, snapshot): + op_expr = ops.clip_op.as_expr("rowindex", "int64_col", "int64_too") + + array_value = scalar_types_df._block.expr + result, col_ids = array_value.compute_values([op_expr]) + + # Rename columns for deterministic golden SQL results. + assert len(col_ids) == 1 + result = result.rename_columns({col_ids[0]: "result_col"}).select_columns( + ["result_col"] + ) + + sql = result.session._executor.to_sql(result, enable_cache=False) + snapshot.assert_match(sql, "out.sql") + + def test_hash(scalar_types_df: bpd.DataFrame, snapshot): col_name = "string_col" bf_df = scalar_types_df[[col_name]] @@ -202,3 +218,19 @@ def test_map(scalar_types_df: bpd.DataFrame, snapshot): ) snapshot.assert_match(sql, "out.sql") + + +def test_where(scalar_types_df: bpd.DataFrame, snapshot): + op_expr = ops.where_op.as_expr("int64_col", "bool_col", "float64_col") + + array_value = scalar_types_df._block.expr + result, col_ids = array_value.compute_values([op_expr]) + + # Rename columns for deterministic golden SQL results. + assert len(col_ids) == 1 + result = result.rename_columns({col_ids[0]: "result_col"}).select_columns( + ["result_col"] + ) + + sql = result.session._executor.to_sql(result, enable_cache=False) + snapshot.assert_match(sql, "out.sql")