Skip to content

Commit 1e5f698

Browse files
committed
refactor: add ArrayReduceOp to the sqlglot compiler
1 parent a21361b commit 1e5f698

File tree

8 files changed

+140
-4
lines changed

8 files changed

+140
-4
lines changed

bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def compile(
2727
op: agg_ops.WindowOp,
2828
column: typed_expr.TypedExpr,
2929
*,
30-
order_by: tuple[sge.Expression, ...],
30+
order_by: tuple[sge.Expression, ...] = (),
3131
) -> sge.Expression:
3232
return ORDERED_UNARY_OP_REGISTRATION[op](op, column, order_by=order_by)
3333

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ def _(
4949
return sge.func("IFNULL", result, sge.true())
5050

5151

52+
@UNARY_OP_REGISTRATION.register(agg_ops.AnyOp)
def _(
    op: agg_ops.AnyOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``AnyOp`` into a ``LOGICAL_OR`` aggregate that defaults to FALSE.

    Args:
        op: The aggregation op being compiled; it is not inspected here.
        column: Typed input whose ``expr`` is fed to ``LOGICAL_OR``.
        window: Optional window spec, forwarded unchanged to
            ``apply_window_if_present``.

    Returns:
        A ``COALESCE(<aggregate>, FALSE)`` expression.
    """
    logical_or = sge.func("LOGICAL_OR", column.expr)
    aggregated = apply_window_if_present(logical_or, window)

    # BQ will return null for empty column, result would be false in pandas.
    return sge.func("COALESCE", aggregated, sge.convert(False))
63+
64+
5265
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
5366
def _(
5467
op: agg_ops.ApproxQuartilesOp,

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import typing
1818

19-
import sqlglot
19+
import sqlglot as sg
2020
import sqlglot.expressions as sge
2121

2222
from bigframes import operations as ops
@@ -38,17 +38,45 @@ def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
3838
)
3939

4040

41+
@register_unary_op(ops.ArrayReduceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression:
    """Compile ``ArrayReduceOp`` into a subquery aggregating the unnested array.

    The array expression is unnested under the column alias
    ``bf_arr_reduce_uid`` and ``op.aggregation`` is compiled against that
    alias, typed with the array's inner element type.

    Args:
        expr: Typed array expression to reduce.
        op: The reduce op; ``op.aggregation`` selects which aggregation
            compiler is used.

    Returns:
        A subquery of the form ``(SELECT <agg> FROM UNNEST(<array>) AS ...)``.
    """
    element_id = sg.to_identifier("bf_arr_reduce_uid")
    element_type = dtypes.get_array_inner_type(expr.dtype)

    # The aggregation compilers are imported inside the function, one per branch.
    if not op.aggregation.order_independent:
        from bigframes.core.compile.sqlglot.aggregations import ordered_unary_compiler

        # Called without an explicit ``order_by``.
        agg_expr = ordered_unary_compiler.compile(
            op.aggregation, TypedExpr(element_id, element_type)
        )
    else:
        from bigframes.core.compile.sqlglot.aggregations import unary_compiler

        agg_expr = unary_compiler.compile(
            op.aggregation, TypedExpr(element_id, element_type)
        )

    # UNNEST(<array>) exposing each element under the reduce alias.
    unnested = sge.Unnest(
        expressions=[expr.expr],
        alias=sge.TableAlias(columns=[element_id]),
    )
    return sge.select(agg_expr).from_(unnested).subquery()
67+
68+
4169
@register_unary_op(ops.ArraySliceOp, pass_op=True)
4270
def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
43-
slice_idx = sqlglot.to_identifier("slice_idx")
71+
slice_idx = sg.to_identifier("slice_idx")
4472

4573
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
4674

4775
if op.stop is not None:
4876
conditions.append(slice_idx < op.stop)
4977

5078
# local name for each element in the array
51-
el = sqlglot.to_identifier("el")
79+
el = sg.to_identifier("el")
5280

5381
selected_elements = (
5482
sge.select(el)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `bool_col`
12+
FROM `bfcte_1`
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `bool_col` IS NULL
10+
THEN NULL
11+
ELSE COALESCE(LOGICAL_OR(`bool_col`) OVER (), FALSE)
12+
END AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `agg_bool`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ def test_all(scalar_types_df: bpd.DataFrame, snapshot):
8888
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
8989

9090

91+
def test_any(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``AnyOp`` as an aggregate and as a window op."""
    col_name = "bool_col"
    bf_df = scalar_types_df[[col_name]]
    agg_expr = agg_ops.AnyOp().as_expr(col_name)

    # Plain aggregation
    snapshot.assert_match(
        _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]),
        "out.sql",
    )

    # Window tests
    ordered_window = window_spec.WindowSpec(
        ordering=(ordering.ascending_over(col_name),)
    )
    snapshot.assert_match(
        _apply_unary_window_op(bf_df, agg_expr, ordered_window, "agg_bool"),
        "window_out.sql",
    )
103+
104+
91105
def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):
92106
col_name = "int64_col"
93107
bf_df = scalar_types_df[[col_name]]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_list_col`,
4+
`float_list_col`,
5+
`string_list_col`
6+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
(
11+
SELECT
12+
COALESCE(SUM(bf_arr_reduce_uid), 0)
13+
FROM UNNEST(`float_list_col`) AS bf_arr_reduce_uid
14+
) AS `bfcol_3`,
15+
(
16+
SELECT
17+
STDDEV(bf_arr_reduce_uid)
18+
FROM UNNEST(`float_list_col`) AS bf_arr_reduce_uid
19+
) AS `bfcol_4`,
20+
(
21+
SELECT
22+
COUNT(bf_arr_reduce_uid)
23+
FROM UNNEST(`string_list_col`) AS bf_arr_reduce_uid
24+
) AS `bfcol_5`,
25+
(
26+
SELECT
27+
COALESCE(LOGICAL_OR(bf_arr_reduce_uid), FALSE)
28+
FROM UNNEST(`bool_list_col`) AS bf_arr_reduce_uid
29+
) AS `bfcol_6`
30+
FROM `bfcte_0`
31+
)
32+
SELECT
33+
`bfcol_3` AS `sum_float`,
34+
`bfcol_4` AS `std_float`,
35+
`bfcol_5` AS `count_str`,
36+
`bfcol_6` AS `any_bool`
37+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_array_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from bigframes import operations as ops
1818
from bigframes.core import expression
1919
from bigframes.operations._op_converters import convert_index, convert_slice
20+
import bigframes.operations.aggregations as agg_ops
2021
import bigframes.pandas as bpd
2122
from bigframes.testing import utils
2223

@@ -43,6 +44,20 @@ def test_array_index(repeated_types_df: bpd.DataFrame, snapshot):
4344
snapshot.assert_match(sql, "out.sql")
4445

4546

47+
def test_array_reduce_op(repeated_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``ArrayReduceOp`` over several aggregations."""
    # (output column name, aggregation wrapped by the reduce op, source column)
    cases = [
        ("sum_float", agg_ops.SumOp(), "float_list_col"),
        ("std_float", agg_ops.StdOp(), "float_list_col"),
        ("count_str", agg_ops.CountOp(), "string_list_col"),
        ("any_bool", agg_ops.AnyOp(), "bool_list_col"),
    ]
    reduce_exprs = [
        ops.ArrayReduceOp(aggregation).as_expr(source_col)
        for _, aggregation, source_col in cases
    ]
    output_names = [name for name, _, _ in cases]

    sql = utils._apply_ops_to_sql(repeated_types_df, reduce_exprs, output_names)
    snapshot.assert_match(sql, "out.sql")
59+
60+
4661
def test_array_slice_with_only_start(repeated_types_df: bpd.DataFrame, snapshot):
4762
col_name = "string_list_col"
4863
bf_df = repeated_types_df[[col_name]]

0 commit comments

Comments
 (0)