Skip to content

Commit 8035e01

Browse files
authored
refactor: support agg_ops.AllOp and AnyValueOp in sqlglot compiler (#2127)
1 parent 9130a61 commit 8035e01

File tree

4 files changed

+62
-0
lines changed

4 files changed

+62
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ def compile(
3838
return UNARY_OP_REGISTRATION[op](op, column, window=window)
3939

4040

41+
@UNARY_OP_REGISTRATION.register(agg_ops.AllOp)
42+
def _(
43+
op: agg_ops.AllOp,
44+
column: typed_expr.TypedExpr,
45+
window: typing.Optional[window_spec.WindowSpec] = None,
46+
) -> sge.Expression:
47+
# BQ returns NULL for an empty column, whereas pandas `all()` returns True, so NULL is filled with TRUE below.
48+
result = apply_window_if_present(sge.func("LOGICAL_AND", column.expr), window)
49+
return sge.func("IFNULL", result, sge.true())
50+
51+
4152
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
4253
def _(
4354
op: agg_ops.ApproxQuartilesOp,
@@ -69,6 +80,15 @@ def _(
6980
return sge.func("APPROX_TOP_COUNT", column.expr, sge.convert(op.number))
7081

7182

83+
@UNARY_OP_REGISTRATION.register(agg_ops.AnyValueOp)
84+
def _(
85+
op: agg_ops.AnyValueOp,
86+
column: typed_expr.TypedExpr,
87+
window: typing.Optional[window_spec.WindowSpec] = None,
88+
) -> sge.Expression:
89+
return apply_window_if_present(sge.func("ANY_VALUE", column.expr), window)
90+
91+
7292
@UNARY_OP_REGISTRATION.register(agg_ops.CountOp)
7393
def _(
7494
op: agg_ops.CountOp,
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
COALESCE(LOGICAL_AND(`bfcol_0`), TRUE) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `bool_col`
12+
FROM `bfcte_1`
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
ANY_VALUE(`bfcol_0`) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ def _apply_unary_window_op(
6363
return sql
6464

6565

66+
def test_all(scalar_types_df: bpd.DataFrame, snapshot):
67+
col_name = "bool_col"
68+
bf_df = scalar_types_df[[col_name]]
69+
agg_expr = agg_ops.AllOp().as_expr(col_name)
70+
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
71+
72+
snapshot.assert_match(sql, "out.sql")
73+
74+
6675
def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):
6776
col_name = "int64_col"
6877
bf_df = scalar_types_df[[col_name]]
@@ -87,6 +96,15 @@ def test_approx_top_count(scalar_types_df: bpd.DataFrame, snapshot):
8796
snapshot.assert_match(sql, "out.sql")
8897

8998

99+
def test_any_value(scalar_types_df: bpd.DataFrame, snapshot):
100+
col_name = "int64_col"
101+
bf_df = scalar_types_df[[col_name]]
102+
agg_expr = agg_ops.AnyValueOp().as_expr(col_name)
103+
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
104+
105+
snapshot.assert_match(sql, "out.sql")
106+
107+
90108
def test_count(scalar_types_df: bpd.DataFrame, snapshot):
91109
col_name = "int64_col"
92110
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)