
Commit 999dd99

refactor: add agg_ops.MinOp and MaxOp for sqlglot compiler (#2097)

* refactor: add agg_ops.MinOp and MaxOp for sqlglot compiler
* allow int timedelta to micro
* address comments

1 parent: 9dc9695

File tree: 8 files changed, +118 −30 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ repos:
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
-        exclude: "^tests/unit/core/compile/sqlglot/snapshots"
+        exclude: "^tests/unit/core/compile/sqlglot/.*snapshots"
       - id: check-yaml
   - repo: https://github.com/pycqa/isort
     rev: 5.12.0
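
To see the effect of the broadened exclude pattern, here is a small illustrative snippet (not part of the commit; the example path and the use of re.match are assumptions, since pre-commit applies its own matching rules): the old regex only covered the top-level snapshots directory, while the new one also skips nested snapshot directories such as the aggregations/snapshots tree added below.

import re

OLD = r"^tests/unit/core/compile/sqlglot/snapshots"
NEW = r"^tests/unit/core/compile/sqlglot/.*snapshots"

# Example path in the style of the snapshot files added by this commit.
path = "tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_max/out.sql"

print(bool(re.match(OLD, path)))  # False: nested snapshots dirs were not excluded before
print(bool(re.match(NEW, path)))  # True: end-of-file-fixer now skips them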

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 30 additions & 8 deletions
@@ -16,6 +16,7 @@

 import typing

+import pandas as pd
 import sqlglot.expressions as sge

 from bigframes import dtypes
@@ -46,18 +47,22 @@ def _(
     return apply_window_if_present(sge.func("COUNT", column.expr), window)


-@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
+@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
 def _(
-    op: agg_ops.SumOp,
+    op: agg_ops.MaxOp,
     column: typed_expr.TypedExpr,
     window: typing.Optional[window_spec.WindowSpec] = None,
 ) -> sge.Expression:
-    expr = column.expr
-    if column.dtype == dtypes.BOOL_DTYPE:
-        expr = sge.Cast(this=column.expr, to="INT64")
-    # Will be null if all inputs are null. Pandas defaults to zero sum though.
-    expr = apply_window_if_present(sge.func("SUM", expr), window)
-    return sge.func("IFNULL", expr, ir._literal(0, column.dtype))
+    return apply_window_if_present(sge.func("MAX", column.expr), window)
+
+
+@UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
+def _(
+    op: agg_ops.MinOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    return apply_window_if_present(sge.func("MIN", column.expr), window)


 @UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
@@ -67,3 +72,20 @@ def _(
     window: typing.Optional[window_spec.WindowSpec] = None,
 ) -> sge.Expression:
     return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
+
+
+@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
+def _(
+    op: agg_ops.SumOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    expr = column.expr
+    if column.dtype == dtypes.BOOL_DTYPE:
+        expr = sge.Cast(this=column.expr, to="INT64")
+
+    expr = apply_window_if_present(sge.func("SUM", expr), window)
+
+    # Will be null if all inputs are null. Pandas defaults to zero sum though.
+    zero = pd.to_timedelta(0) if column.dtype == dtypes.TIMEDELTA_DTYPE else 0
+    return sge.func("IFNULL", expr, ir._literal(zero, column.dtype))
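
For intuition, a rough standalone sketch (not part of the commit) of the expressions these registrations build, calling sqlglot directly rather than the module's apply_window_if_present and ir._literal helpers; the column name is made up and window handling is skipped:

import sqlglot.expressions as sge

col = sge.column("int64_col")

# MIN/MAX pass the column through untouched, so an all-NULL input stays NULL.
print(sge.func("MAX", col).sql(dialect="bigquery"))  # MAX(int64_col)
print(sge.func("MIN", col).sql(dialect="bigquery"))  # MIN(int64_col)

# SUM falls back to 0 when every input is NULL, matching the pandas default;
# the BigQuery snapshots in this commit render this as COALESCE(SUM(...), 0).
summed = sge.func("IFNULL", sge.func("SUM", col), sge.convert(0))
print(summed.sql(dialect="bigquery"))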

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_count/out.sql

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    COUNT(`bfcol_0`) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `int64_col`
+FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_max/out.sql

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    MAX(`bfcol_0`) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `int64_col`
+FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_min/out.sql

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    MIN(`bfcol_0`) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `int64_col`
+FROM `bfcte_1`
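
Roughly, these three snapshots are the compiled form of plain column aggregations like the ones sketched here; this is illustrative only, since the unit tests build the AggregateNode directly rather than going through the public DataFrame API, and the read_gbq call is an assumption:

import bigframes.pandas as bpd

# Hypothetical DataFrame-level calls corresponding to the COUNT/MAX/MIN plans above.
df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")
print(df["int64_col"].count())
print(df["int64_col"].max())
print(df["int64_col"].min())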

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size_unary/out.sql

Lines changed: 2 additions & 2 deletions
@@ -1,12 +1,12 @@
 WITH `bfcte_0` AS (
   SELECT
-    `string_col` AS `bfcol_0`
+    `float64_col` AS `bfcol_0`
   FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
 ), `bfcte_1` AS (
   SELECT
     COUNT(1) AS `bfcol_1`
   FROM `bfcte_0`
 )
 SELECT
-  `bfcol_1` AS `string_col_agg`
+  `bfcol_1` AS `float64_col`
 FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql

Lines changed: 6 additions & 3 deletions

@@ -1,12 +1,15 @@
 WITH `bfcte_0` AS (
   SELECT
-    `int64_col` AS `bfcol_0`
+    `bool_col` AS `bfcol_0`,
+    `int64_col` AS `bfcol_1`
   FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
 ), `bfcte_1` AS (
   SELECT
-    COALESCE(SUM(`bfcol_0`), 0) AS `bfcol_1`
+    COALESCE(SUM(`bfcol_1`), 0) AS `bfcol_4`,
+    COALESCE(SUM(CAST(`bfcol_0` AS INT64)), 0) AS `bfcol_5`
   FROM `bfcte_0`
 )
 SELECT
-  `bfcol_1` AS `int64_col_agg`
+  `bfcol_4` AS `int64_col`,
+  `bfcol_5` AS `bool_col`
 FROM `bfcte_1`
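
The CAST(`bfcol_0` AS INT64) in the compiled SUM mirrors how pandas sums a boolean column by counting its True values; a quick illustration (not part of the commit):

import pandas as pd

s = pd.Series([True, False, True, pd.NA], dtype="boolean")
print(s.sum())  # 2: True counts as 1, False as 0, NA is skipped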

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 43 additions & 16 deletions
@@ -12,40 +12,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import typing
+
 import pytest

-from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
+from bigframes.core import agg_expressions as agg_exprs
+from bigframes.core import array_value, identifiers, nodes
 from bigframes.operations import aggregations as agg_ops
 import bigframes.pandas as bpd

 pytest.importorskip("pytest_snapshot")


-def _apply_unary_op(obj: bpd.DataFrame, op: agg_ops.UnaryWindowOp, arg: str) -> str:
-    agg_node = nodes.AggregateNode(
-        obj._block.expr.node,
-        aggregations=(
-            (
-                agg_expressions.UnaryAggregation(op, expression.deref(arg)),
-                identifiers.ColumnId(arg + "_agg"),
-            ),
-        ),
-    )
+def _apply_unary_agg_ops(
+    obj: bpd.DataFrame,
+    ops_list: typing.Sequence[agg_exprs.UnaryAggregation],
+    new_names: typing.Sequence[str],
+) -> str:
+    aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)]
+
+    agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=tuple(aggs))
     result = array_value.ArrayValue(agg_node)

     sql = result.session._executor.to_sql(result, enable_cache=False)
     return sql


-def test_size(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, agg_ops.SizeUnaryOp(), "string_col")
+def test_count(scalar_types_df: bpd.DataFrame, snapshot):
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    agg_expr = agg_ops.CountOp().as_expr(col_name)
+    sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
+
+    snapshot.assert_match(sql, "out.sql")
+
+
+def test_max(scalar_types_df: bpd.DataFrame, snapshot):
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    agg_expr = agg_ops.MaxOp().as_expr(col_name)
+    sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
+
+    snapshot.assert_match(sql, "out.sql")
+
+
+def test_min(scalar_types_df: bpd.DataFrame, snapshot):
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    agg_expr = agg_ops.MinOp().as_expr(col_name)
+    sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])

     snapshot.assert_match(sql, "out.sql")


 def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["int64_col"]]
-    sql = _apply_unary_op(bf_df, agg_ops.SumOp(), "int64_col")
+    bf_df = scalar_types_df[["int64_col", "bool_col"]]
+    agg_ops_map = {
+        "int64_col": agg_ops.SumOp().as_expr("int64_col"),
+        "bool_col": agg_ops.SumOp().as_expr("bool_col"),
+    }
+    sql = _apply_unary_agg_ops(
+        bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
+    )

     snapshot.assert_match(sql, "out.sql")
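
Because the helper now takes a sequence of aggregations, different ops over the same column can be compiled in one AggregateNode as well. A hypothetical test along those lines (the test name and output column names are made up; it assumes the same module, imports, and fixtures as above):

def test_min_and_max_together(scalar_types_df: bpd.DataFrame, snapshot):
    bf_df = scalar_types_df[["int64_col"]]
    exprs = [
        agg_ops.MinOp().as_expr("int64_col"),
        agg_ops.MaxOp().as_expr("int64_col"),
    ]
    # One AggregateNode compiling both MIN and MAX over the same input column.
    sql = _apply_unary_agg_ops(bf_df, exprs, ["int64_min", "int64_max"])

    snapshot.assert_match(sql, "out.sql")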
