diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 2bd19e1967..7ff09ab3f6 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -134,6 +134,11 @@ def _( ) +@register_binary_op(ops.fillna_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + return sge.Coalesce(this=left.expr, expressions=[right.expr]) + + @register_nary_op(ops.case_when_op) def _(*cases_and_outputs: TypedExpr) -> sge.Expression: # Need to upcast BOOL to INT if any output is numeric diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index 01d4dad849..c0469ed97a 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -343,7 +343,7 @@ def test_engines_coalesce_op(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_fillna/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_fillna/out.sql new file mode 100644 index 0000000000..07f2877e74 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_fillna/out.sql @@ -0,0 +1,14 @@ +WITH `bfcte_0` AS ( + SELECT + `float64_col`, + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + COALESCE(`int64_col`, `float64_col`) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 693f8dc34c..11daf6813a 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -239,6 +239,12 @@ def test_clip(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_fillna(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "float64_col"]] + sql = utils._apply_binary_op(bf_df, ops.fillna_op, "int64_col", "float64_col") + snapshot.assert_match(sql, "out.sql") + + def test_hash(scalar_types_df: bpd.DataFrame, snapshot): col_name = "string_col" bf_df = scalar_types_df[[col_name]]