diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 90335cb8b9..b697d2324b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: hooks: - id: trailing-whitespace - id: end-of-file-fixer - exclude: "^tests/unit/core/compile/sqlglot/snapshots" + exclude: "^tests/unit/core/compile/sqlglot/.*snapshots" - id: check-yaml - repo: https://github.com/pycqa/isort rev: 5.12.0 diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index c7eb84cba6..542bb10670 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -16,6 +16,7 @@ import typing +import pandas as pd import sqlglot.expressions as sge from bigframes import dtypes @@ -46,18 +47,22 @@ def _( return apply_window_if_present(sge.func("COUNT", column.expr), window) -@UNARY_OP_REGISTRATION.register(agg_ops.SumOp) +@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) def _( - op: agg_ops.SumOp, + op: agg_ops.MaxOp, column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - expr = column.expr - if column.dtype == dtypes.BOOL_DTYPE: - expr = sge.Cast(this=column.expr, to="INT64") - # Will be null if all inputs are null. Pandas defaults to zero sum though. - expr = apply_window_if_present(sge.func("SUM", expr), window) - return sge.func("IFNULL", expr, ir._literal(0, column.dtype)) + return apply_window_if_present(sge.func("MAX", column.expr), window) + + +@UNARY_OP_REGISTRATION.register(agg_ops.MinOp) +def _( + op: agg_ops.MinOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present(sge.func("MIN", column.expr), window) @UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp) @@ -67,3 +72,20 @@ def _( window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window) + + +@UNARY_OP_REGISTRATION.register(agg_ops.SumOp) +def _( + op: agg_ops.SumOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + expr = column.expr + if column.dtype == dtypes.BOOL_DTYPE: + expr = sge.Cast(this=column.expr, to="INT64") + + expr = apply_window_if_present(sge.func("SUM", expr), window) + + # Will be null if all inputs are null. Pandas defaults to zero sum though. + zero = pd.to_timedelta(0) if column.dtype == dtypes.TIMEDELTA_DTYPE else 0 + return sge.func("IFNULL", expr, ir._literal(zero, column.dtype)) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_count/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_count/out.sql new file mode 100644 index 0000000000..01684b4af6 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_count/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + COUNT(`bfcol_0`) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_max/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_max/out.sql new file mode 100644 index 0000000000..c88fa58d0f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_max/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + MAX(`bfcol_0`) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_min/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_min/out.sql new file mode 100644 index 0000000000..b067817218 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_min/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + MIN(`bfcol_0`) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size_unary/out.sql similarity index 73% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size_unary/out.sql index 78104eb578..fffb4831b9 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size_unary/out.sql @@ -1,6 +1,6 @@ WITH `bfcte_0` AS ( SELECT - `string_col` AS `bfcol_0` + `float64_col` AS `bfcol_0` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT @@ -8,5 +8,5 @@ WITH `bfcte_0` AS ( FROM `bfcte_0` ) SELECT - `bfcol_1` AS `string_col_agg` + `bfcol_1` AS `float64_col` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql index e748f71278..be684f6768 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql @@ -1,12 +1,15 @@ WITH `bfcte_0` AS ( SELECT - `int64_col` AS `bfcol_0` + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - COALESCE(SUM(`bfcol_0`), 0) AS `bfcol_1` + COALESCE(SUM(`bfcol_1`), 0) AS `bfcol_4`, + COALESCE(SUM(CAST(`bfcol_0` AS INT64)), 0) AS `bfcol_5` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `int64_col_agg` + `bfcol_4` AS `int64_col`, + `bfcol_5` AS `bool_col` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index d12b4dda17..311c039e11 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -12,40 +12,67 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import pytest -from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes +from bigframes.core import agg_expressions as agg_exprs +from bigframes.core import array_value, identifiers, nodes from bigframes.operations import aggregations as agg_ops import bigframes.pandas as bpd pytest.importorskip("pytest_snapshot") -def _apply_unary_op(obj: bpd.DataFrame, op: agg_ops.UnaryWindowOp, arg: str) -> str: - agg_node = nodes.AggregateNode( - obj._block.expr.node, - aggregations=( - ( - agg_expressions.UnaryAggregation(op, expression.deref(arg)), - identifiers.ColumnId(arg + "_agg"), - ), - ), - ) +def _apply_unary_agg_ops( + obj: bpd.DataFrame, + ops_list: typing.Sequence[agg_exprs.UnaryAggregation], + new_names: typing.Sequence[str], +) -> str: + aggs = [(op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)] + + agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=tuple(aggs)) result = array_value.ArrayValue(agg_node) sql = result.session._executor.to_sql(result, enable_cache=False) return sql -def test_size(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, agg_ops.SizeUnaryOp(), "string_col") +def test_count(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.CountOp().as_expr(col_name) + sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) + + snapshot.assert_match(sql, "out.sql") + + +def test_max(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.MaxOp().as_expr(col_name) + sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) + + snapshot.assert_match(sql, "out.sql") + + +def test_min(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.MinOp().as_expr(col_name) + sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) snapshot.assert_match(sql, "out.sql") def test_sum(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col"]] - sql = _apply_unary_op(bf_df, agg_ops.SumOp(), "int64_col") + bf_df = scalar_types_df[["int64_col", "bool_col"]] + agg_ops_map = { + "int64_col": agg_ops.SumOp().as_expr("int64_col"), + "bool_col": agg_ops.SumOp().as_expr("bool_col"), + } + sql = _apply_unary_agg_ops( + bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) + ) snapshot.assert_match(sql, "out.sql")