Skip to content

Commit 0776b87

Browse files
committed
refactor: add agg_ops.MinOp and MaxOp for sqlglot compiler
1 parent 81dbf9a commit 0776b87

File tree

9 files changed

+116
-30
lines changed

9 files changed

+116
-30
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repos:
2020
hooks:
2121
- id: trailing-whitespace
2222
- id: end-of-file-fixer
23-
exclude: "^tests/unit/core/compile/sqlglot/snapshots"
23+
exclude: "^tests/unit/core/compile/sqlglot/.*snapshots"
2424
- id: check-yaml
2525
- repo: https://github.com/pycqa/isort
2626
rev: 5.12.0

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,22 @@ def _(
4646
return apply_window_if_present(sge.func("COUNT", column.expr), window)
4747

4848

49-
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
49+
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
5050
def _(
51-
op: agg_ops.SumOp,
51+
op: agg_ops.MaxOp,
5252
column: typed_expr.TypedExpr,
5353
window: typing.Optional[window_spec.WindowSpec] = None,
5454
) -> sge.Expression:
55-
expr = column.expr
56-
if column.dtype == dtypes.BOOL_DTYPE:
57-
expr = sge.Cast(this=column.expr, to="INT64")
58-
# Will be null if all inputs are null. Pandas defaults to zero sum though.
59-
expr = apply_window_if_present(sge.func("SUM", expr), window)
60-
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))
55+
return apply_window_if_present(sge.func("MAX", column.expr), window)
56+
57+
58+
@UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
59+
def _(
60+
op: agg_ops.MinOp,
61+
column: typed_expr.TypedExpr,
62+
window: typing.Optional[window_spec.WindowSpec] = None,
63+
) -> sge.Expression:
64+
return apply_window_if_present(sge.func("MIN", column.expr), window)
6165

6266

6367
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
@@ -67,3 +71,17 @@ def _(
6771
window: typing.Optional[window_spec.WindowSpec] = None,
6872
) -> sge.Expression:
6973
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
74+
75+
76+
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
77+
def _(
78+
op: agg_ops.SumOp,
79+
column: typed_expr.TypedExpr,
80+
window: typing.Optional[window_spec.WindowSpec] = None,
81+
) -> sge.Expression:
82+
expr = column.expr
83+
if column.dtype == dtypes.BOOL_DTYPE:
84+
expr = sge.Cast(this=column.expr, to="INT64")
85+
# Will be null if all inputs are null. Pandas defaults to zero sum though.
86+
expr = apply_window_if_present(sge.func("SUM", expr), window)
87+
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
633633
elif dtype == dtypes.JSON_DTYPE:
634634
return sge.ParseJSON(this=sge.convert(str(value)))
635635
elif dtype == dtypes.TIMEDELTA_DTYPE:
636+
if isinstance(value, int):
637+
return sge.convert(value)
636638
return sge.convert(utils.timedelta_to_micros(value))
637639
elif dtypes.is_struct_like(dtype):
638640
items = [
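tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_count/out.sql (file header missing from capture; path inferred from the test name)

Lines changed: 12 additions & 0 deletions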
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
COUNT(`bfcol_0`) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`
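tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_max/out.sql (file header missing from capture; path inferred from the test name)

Lines changed: 12 additions & 0 deletions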
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
MAX(`bfcol_0`) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`
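tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_min/out.sql (file header missing from capture; path inferred from the test name)

Lines changed: 12 additions & 0 deletions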
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
MIN(`bfcol_0`) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size_unary/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`string_col` AS `bfcol_0`
3+
`float64_col` AS `bfcol_0`
44
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
55
), `bfcte_1` AS (
66
SELECT
77
COUNT(1) AS `bfcol_1`
88
FROM `bfcte_0`
99
)
1010
SELECT
11-
`bfcol_1` AS `string_col_agg`
11+
`bfcol_1` AS `float64_col`
1212
FROM `bfcte_1`
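tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql (file header missing from capture; path inferred from the test name)

Lines changed: 6 additions & 3 deletions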
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`int64_col` AS `bfcol_0`
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
COALESCE(SUM(`bfcol_0`), 0) AS `bfcol_1`
8+
COALESCE(SUM(`bfcol_1`), 0) AS `bfcol_4`,
9+
COALESCE(SUM(CAST(`bfcol_0` AS INT64)), 0) AS `bfcol_5`
810
FROM `bfcte_0`
911
)
1012
SELECT
11-
`bfcol_1` AS `int64_col_agg`
13+
`bfcol_4` AS `int64_col`,
14+
`bfcol_5` AS `bool_col`
1215
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,67 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import typing
16+
1517
import pytest
1618

17-
from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
19+
from bigframes.core import agg_expressions as agg_exprs
20+
from bigframes.core import array_value, identifiers, nodes
1821
from bigframes.operations import aggregations as agg_ops
1922
import bigframes.pandas as bpd
2023

2124
pytest.importorskip("pytest_snapshot")
2225

2326

24-
def _apply_unary_op(obj: bpd.DataFrame, op: agg_ops.UnaryWindowOp, arg: str) -> str:
25-
agg_node = nodes.AggregateNode(
26-
obj._block.expr.node,
27-
aggregations=(
28-
(
29-
agg_expressions.UnaryAggregation(op, expression.deref(arg)),
30-
identifiers.ColumnId(arg + "_agg"),
31-
),
32-
),
33-
)
27+
def _apply_unary_agg_ops(
28+
obj: bpd.DataFrame,
29+
ops_list: typing.Sequence[agg_exprs.UnaryAggregation],
30+
new_names: typing.Sequence[str],
31+
) -> str:
32+
aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)]
33+
34+
agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=tuple(aggs))
3435
result = array_value.ArrayValue(agg_node)
3536

3637
sql = result.session._executor.to_sql(result, enable_cache=False)
3738
return sql
3839

3940

40-
def test_size(scalar_types_df: bpd.DataFrame, snapshot):
41-
bf_df = scalar_types_df[["string_col"]]
42-
sql = _apply_unary_op(bf_df, agg_ops.SizeUnaryOp(), "string_col")
41+
def test_count(scalar_types_df: bpd.DataFrame, snapshot):
42+
col_name = "int64_col"
43+
bf_df = scalar_types_df[[col_name]]
44+
agg_expr = agg_ops.CountOp().as_expr(col_name)
45+
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
46+
47+
snapshot.assert_match(sql, "out.sql")
48+
49+
50+
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
51+
col_name = "int64_col"
52+
bf_df = scalar_types_df[[col_name]]
53+
agg_expr = agg_ops.MaxOp().as_expr(col_name)
54+
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
55+
56+
snapshot.assert_match(sql, "out.sql")
57+
58+
59+
def test_min(scalar_types_df: bpd.DataFrame, snapshot):
60+
col_name = "int64_col"
61+
bf_df = scalar_types_df[[col_name]]
62+
agg_expr = agg_ops.MinOp().as_expr(col_name)
63+
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
4364

4465
snapshot.assert_match(sql, "out.sql")
4566

4667

4768
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
48-
bf_df = scalar_types_df[["int64_col"]]
49-
sql = _apply_unary_op(bf_df, agg_ops.SumOp(), "int64_col")
69+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
70+
agg_ops_map = {
71+
"int64_col": agg_ops.SumOp().as_expr("int64_col"),
72+
"bool_col": agg_ops.SumOp().as_expr("bool_col"),
73+
}
74+
sql = _apply_unary_agg_ops(
75+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
76+
)
5077

5178
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)