Skip to content

Commit c6c3330

Browse files
committed
refactor: fix test_list_apply_callable on agg_ops.AllOp and agg_ops.AnyOp
1 parent 17b8c86 commit c6c3330

File tree

7 files changed

+45
-41
lines changed

7 files changed

+45
-41
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,13 @@ def _(
4545
column: typed_expr.TypedExpr,
4646
window: typing.Optional[window_spec.WindowSpec] = None,
4747
) -> sge.Expression:
48-
# BQ will return null for empty column, result would be false in pandas.
49-
result = apply_window_if_present(sge.func("LOGICAL_AND", column.expr), window)
50-
return sge.func("IFNULL", result, sge.true())
48+
expr = column.expr
49+
if column.dtype != dtypes.BOOL_DTYPE:
50+
expr = sge.NEQ(this=expr, expression=sge.convert(0))
51+
expr = apply_window_if_present(sge.func("LOGICAL_AND", expr), window)
52+
53+
# BQ will return null for empty column, result would be true in pandas.
54+
return sge.func("COALESCE", expr, sge.convert(True))
5155

5256

5357
@UNARY_OP_REGISTRATION.register(agg_ops.AnyOp)
@@ -57,6 +61,8 @@ def _(
5761
window: typing.Optional[window_spec.WindowSpec] = None,
5862
) -> sge.Expression:
5963
expr = column.expr
64+
if column.dtype != dtypes.BOOL_DTYPE:
65+
expr = sge.NEQ(this=expr, expression=sge.convert(0))
6066
expr = apply_window_if_present(sge.func("LOGICAL_OR", expr), window)
6167

6268
# BQ will return null for empty column, result would be false in pandas.
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`bool_col`
3+
`bool_col`,
4+
`int64_col`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_1`
8+
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_2`,
9+
COALESCE(LOGICAL_AND(`int64_col` <> 0), TRUE) AS `bfcol_3`
810
FROM `bfcte_0`
911
)
1012
SELECT
11-
`bfcol_1` AS `bool_col`
13+
`bfcol_2` AS `bool_col`,
14+
`bfcol_3` AS `int64_col`
1215
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_partition_out.sql

Lines changed: 0 additions & 14 deletions
This file was deleted.

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all_w_window/out.sql

File renamed without changes.
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`bool_col`
3+
`bool_col`,
4+
`int64_col`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_1`
8+
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_2`,
9+
COALESCE(LOGICAL_OR(`int64_col` <> 0), FALSE) AS `bfcol_3`
810
FROM `bfcte_0`
911
)
1012
SELECT
11-
`bfcol_1` AS `bool_col`
13+
`bfcol_2` AS `bool_col`,
14+
`bfcol_3` AS `int64_col`
1215
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/window_out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any_w_window/out.sql

File renamed without changes.

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,41 +63,47 @@ def _apply_unary_window_op(
6363

6464

6565
def test_all(scalar_types_df: bpd.DataFrame, snapshot):
66+
bf_df = scalar_types_df[["bool_col", "int64_col"]]
67+
ops_map = {
68+
"bool_col": agg_ops.AllOp().as_expr("bool_col"),
69+
"int64_col": agg_ops.AllOp().as_expr("int64_col"),
70+
}
71+
sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
72+
73+
snapshot.assert_match(sql, "out.sql")
74+
75+
76+
def test_all_w_window(scalar_types_df: bpd.DataFrame, snapshot):
6677
col_name = "bool_col"
6778
bf_df = scalar_types_df[[col_name]]
6879
agg_expr = agg_ops.AllOp().as_expr(col_name)
69-
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
70-
71-
snapshot.assert_match(sql, "out.sql")
7280

7381
# Window tests
7482
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
7583
sql_window = _apply_unary_window_op(bf_df, agg_expr, window, "agg_bool")
76-
snapshot.assert_match(sql_window, "window_out.sql")
77-
78-
bf_df_str = scalar_types_df[[col_name, "string_col"]]
79-
window_partition = window_spec.WindowSpec(
80-
grouping_keys=(expression.deref("string_col"),),
81-
ordering=(ordering.descending_over(col_name),),
82-
)
83-
sql_window_partition = _apply_unary_window_op(
84-
bf_df_str, agg_expr, window_partition, "agg_bool"
85-
)
86-
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
84+
snapshot.assert_match(sql_window, "out.sql")
8785

8886

8987
def test_any(scalar_types_df: bpd.DataFrame, snapshot):
88+
bf_df = scalar_types_df[["bool_col", "int64_col"]]
89+
ops_map = {
90+
"bool_col": agg_ops.AnyOp().as_expr("bool_col"),
91+
"int64_col": agg_ops.AnyOp().as_expr("int64_col"),
92+
}
93+
sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
94+
95+
snapshot.assert_match(sql, "out.sql")
96+
97+
98+
def test_any_w_window(scalar_types_df: bpd.DataFrame, snapshot):
9099
col_name = "bool_col"
91100
bf_df = scalar_types_df[[col_name]]
92101
agg_expr = agg_ops.AnyOp().as_expr(col_name)
93-
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
94-
95-
snapshot.assert_match(sql, "out.sql")
96102

97103
# Window tests
98104
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
99105
sql_window = _apply_unary_window_op(bf_df, agg_expr, window, "agg_bool")
100-
snapshot.assert_match(sql_window, "window_out.sql")
106+
snapshot.assert_match(sql_window, "out.sql")
101107

102108

103109
def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):

0 commit comments

Comments
 (0)