Skip to content

Commit 1e5918b

Browse files
authored
refactor: support agg_ops.CovOp and CorrOp in sqlglot compiler (#2116)
1 parent a3c2522 commit 1e5918b

File tree

4 files changed

+103
-0
lines changed

4 files changed

+103
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/binary_compiler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23+
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2324
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2425
from bigframes.operations import aggregations as agg_ops
2526

@@ -33,3 +34,25 @@ def compile(
3334
window: typing.Optional[window_spec.WindowSpec] = None,
3435
) -> sge.Expression:
3536
return BINARY_OP_REGISTRATION[op](op, left, right, window=window)
37+
38+
39+
@BINARY_OP_REGISTRATION.register(agg_ops.CorrOp)
def _(
    op: agg_ops.CorrOp,
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``agg_ops.CorrOp`` into a SQL ``CORR(left, right)`` call.

    Args:
        op: The operator instance supplied by the registry dispatch; it is
            not read by this handler.
        left: First operand; its ``expr`` becomes the first argument.
        right: Second operand; its ``expr`` becomes the second argument.
        window: Optional window spec, forwarded unchanged to
            ``apply_window_if_present``.

    Returns:
        Whatever ``apply_window_if_present`` returns for the ``CORR`` call
        and ``window``.
    """
    corr_call = sge.func("CORR", left.expr, right.expr)
    return apply_window_if_present(corr_call, window)
48+
49+
50+
@BINARY_OP_REGISTRATION.register(agg_ops.CovOp)
def _(
    op: agg_ops.CovOp,
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``agg_ops.CovOp`` into a SQL ``COVAR_SAMP(left, right)`` call.

    Args:
        op: The operator instance supplied by the registry dispatch; it is
            not read by this handler.
        left: First operand; its ``expr`` becomes the first argument.
        right: Second operand; its ``expr`` becomes the second argument.
        window: Optional window spec, forwarded unchanged to
            ``apply_window_if_present``.

    Returns:
        Whatever ``apply_window_if_present`` returns for the ``COVAR_SAMP``
        call and ``window``.
    """
    covar_call = sge.func("COVAR_SAMP", left.expr, right.expr)
    return apply_window_if_present(covar_call, window)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
CORR(`bfcol_0`, `bfcol_1`) AS `bfcol_2`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_2` AS `corr_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
COVAR_SAMP(`bfcol_0`, `bfcol_1`) AS `bfcol_2`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_2` AS `cov_col`
13+
FROM `bfcte_1`
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import typing
16+
17+
import pytest
18+
19+
from bigframes.core import agg_expressions as agg_exprs
20+
from bigframes.core import array_value, identifiers, nodes
21+
from bigframes.operations import aggregations as agg_ops
22+
import bigframes.pandas as bpd
23+
24+
pytest.importorskip("pytest_snapshot")
25+
26+
27+
def _apply_binary_agg_ops(
    obj: bpd.DataFrame,
    ops_list: typing.Sequence[agg_exprs.BinaryAggregation],
    new_names: typing.Sequence[str],
) -> str:
    """Return the SQL generated for an aggregate node built on top of ``obj``.

    Each aggregation in ``ops_list`` is paired positionally with the output
    column name at the same index in ``new_names``.

    Args:
        obj: DataFrame whose underlying expression node is aggregated.
        ops_list: Binary aggregation expressions to apply.
        new_names: Output column names, one per entry in ``ops_list``.

    Returns:
        The SQL string produced by the session executor's ``to_sql``.

    Raises:
        ValueError: If ``ops_list`` and ``new_names`` differ in length.
    """
    # zip() stops at the shorter input, so a mismatch would silently drop
    # aggregations or names and yield a misleading snapshot.
    if len(ops_list) != len(new_names):
        raise ValueError(
            "ops_list and new_names must have the same length, got "
            f"{len(ops_list)} and {len(new_names)}"
        )

    aggs = tuple(
        (op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)
    )
    agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=aggs)
    result = array_value.ArrayValue(agg_node)

    return result.session._executor.to_sql(result, enable_cache=False)
39+
40+
41+
def test_corr(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for ``CorrOp`` on int64_col and float64_col."""
    selected = scalar_types_df[["int64_col", "float64_col"]]
    corr_expr = agg_ops.CorrOp().as_expr("int64_col", "float64_col")
    generated_sql = _apply_binary_agg_ops(selected, [corr_expr], ["corr_col"])
    snapshot.assert_match(generated_sql, "out.sql")
47+
48+
49+
def test_cov(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for ``CovOp`` on int64_col and float64_col."""
    selected = scalar_types_df[["int64_col", "float64_col"]]
    cov_expr = agg_ops.CovOp().as_expr("int64_col", "float64_col")
    generated_sql = _apply_binary_agg_ops(selected, [cov_expr], ["cov_col"])
    snapshot.assert_match(generated_sql, "out.sql")

0 commit comments

Comments
 (0)