diff --git a/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py index a162a9c18a..856b5e2f3a 100644 --- a/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py @@ -20,6 +20,7 @@ from bigframes.core import window_spec import bigframes.core.compile.sqlglot.aggregations.op_registration as reg +from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr from bigframes.operations import aggregations as agg_ops @@ -33,3 +34,25 @@ def compile( window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: return BINARY_OP_REGISTRATION[op](op, left, right, window=window) + + +@BINARY_OP_REGISTRATION.register(agg_ops.CorrOp) +def _( + op: agg_ops.CorrOp, + left: typed_expr.TypedExpr, + right: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + result = sge.func("CORR", left.expr, right.expr) + return apply_window_if_present(result, window) + + +@BINARY_OP_REGISTRATION.register(agg_ops.CovOp) +def _( + op: agg_ops.CovOp, + left: typed_expr.TypedExpr, + right: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + result = sge.func("COVAR_SAMP", left.expr, right.expr) + return apply_window_if_present(result, window) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_corr/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_corr/out.sql new file mode 100644 index 0000000000..8922a71de4 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_corr/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `float64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + CORR(`bfcol_0`, `bfcol_1`) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `corr_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_cov/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_cov/out.sql new file mode 100644 index 0000000000..6cf189da31 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_cov/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `float64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + COVAR_SAMP(`bfcol_0`, `bfcol_1`) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `cov_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_binary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_binary_compiler.py new file mode 100644 index 0000000000..0897b535be --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/test_binary_compiler.py @@ -0,0 +1,54 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import pytest + +from bigframes.core import agg_expressions as agg_exprs +from bigframes.core import array_value, identifiers, nodes +from bigframes.operations import aggregations as agg_ops +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def _apply_binary_agg_ops( + obj: bpd.DataFrame, + ops_list: typing.Sequence[agg_exprs.BinaryAggregation], + new_names: typing.Sequence[str], +) -> str: + aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)] + + agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=tuple(aggs)) + result = array_value.ArrayValue(agg_node) + + sql = result.session._executor.to_sql(result, enable_cache=False) + return sql + + +def test_corr(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "float64_col"]] + agg_expr = agg_ops.CorrOp().as_expr("int64_col", "float64_col") + sql = _apply_binary_agg_ops(bf_df, [agg_expr], ["corr_col"]) + + snapshot.assert_match(sql, "out.sql") + + +def test_cov(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "float64_col"]] + agg_expr = agg_ops.CovOp().as_expr("int64_col", "float64_col") + sql = _apply_binary_agg_ops(bf_df, [agg_expr], ["cov_col"]) + + snapshot.assert_match(sql, "out.sql")