From b771319ab1df90933a03313e32b3c5ad632b204a Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 21 Jul 2025 22:55:07 +0000 Subject: [PATCH 1/2] refactor: provide infrastructure for SQLGlot aggregations compiler --- .../compile/sqlglot/aggregate_compiler.py | 80 +++---------------- .../compile/sqlglot/aggregations/__init__.py | 13 +++ .../sqlglot/aggregations/binary_compiler.py | 35 ++++++++ .../sqlglot/aggregations/nullary_compiler.py | 41 ++++++++++ .../sqlglot/aggregations/op_registration.py | 62 ++++++++++++++ .../aggregations/ordered_unary_compiler.py | 37 +++++++++ .../sqlglot/aggregations/unary_compiler.py | 56 +++++++++++++ .../compile/sqlglot/aggregations/utils.py | 29 +++++++ .../sqlglot/expressions/op_registration.py | 8 +- .../system/small/engines/test_aggregation.py | 4 +- .../compile/sqlglot/aggregations/__init__.py | 13 +++ .../test_unary_compiler/test_size/out.sql | 12 +++ .../test_unary_compiler/test_sum/out.sql | 12 +++ .../aggregations/test_op_registration.py | 45 +++++++++++ .../aggregations/test_unary_compiler.py | 49 ++++++++++++ 15 files changed, 423 insertions(+), 73 deletions(-) create mode 100644 bigframes/core/compile/sqlglot/aggregations/__init__.py create mode 100644 bigframes/core/compile/sqlglot/aggregations/binary_compiler.py create mode 100644 bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py create mode 100644 bigframes/core/compile/sqlglot/aggregations/op_registration.py create mode 100644 bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py create mode 100644 bigframes/core/compile/sqlglot/aggregations/unary_compiler.py create mode 100644 bigframes/core/compile/sqlglot/aggregations/utils.py create mode 100644 tests/unit/core/compile/sqlglot/aggregations/__init__.py create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql create mode 
100644 tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py create mode 100644 tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index 888b3756b5..f7abd7dc7a 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -13,16 +13,17 @@ # limitations under the License. from __future__ import annotations -import functools -import typing - import sqlglot.expressions as sge -from bigframes.core import expression, window_spec +from bigframes.core import expression +from bigframes.core.compile.sqlglot.aggregations import ( + binary_compiler, + nullary_compiler, + ordered_unary_compiler, + unary_compiler, +) from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler -import bigframes.core.compile.sqlglot.sqlglot_ir as ir -import bigframes.operations as ops def compile_aggregate( @@ -31,16 +32,18 @@ def compile_aggregate( ) -> sge.Expression: """Compiles BigFrames aggregation expression into SQLGlot expression.""" if isinstance(aggregate, expression.NullaryAggregation): - return compile_nullary_agg(aggregate.op) + return nullary_compiler.compile(aggregate.op) if isinstance(aggregate, expression.UnaryAggregation): column = typed_expr.TypedExpr( scalar_compiler.compile_scalar_expression(aggregate.arg), aggregate.arg.output_type, ) if not aggregate.op.order_independent: - return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) + return ordered_unary_compiler.compile( + aggregate.op, column, order_by=order_by + ) else: - return compile_unary_agg(aggregate.op, column) + return unary_compiler.compile(aggregate.op, column) elif isinstance(aggregate, expression.BinaryAggregation): left = typed_expr.TypedExpr( scalar_compiler.compile_scalar_expression(aggregate.left), 
@@ -50,63 +53,6 @@ def compile_aggregate( scalar_compiler.compile_scalar_expression(aggregate.right), aggregate.right.output_type, ) - return compile_binary_agg(aggregate.op, left, right) + return binary_compiler.compile(aggregate.op, left, right) else: raise ValueError(f"Unexpected aggregation: {aggregate}") - - -@functools.singledispatch -def compile_nullary_agg( - op: ops.aggregations.WindowOp, - window: typing.Optional[window_spec.WindowSpec] = None, -) -> sge.Expression: - raise ValueError(f"Can't compile unrecognized operation: {op}") - - -@functools.singledispatch -def compile_binary_agg( - op: ops.aggregations.WindowOp, - left: typed_expr.TypedExpr, - right: typed_expr.TypedExpr, - window: typing.Optional[window_spec.WindowSpec] = None, -) -> sge.Expression: - raise ValueError(f"Can't compile unrecognized operation: {op}") - - -@functools.singledispatch -def compile_unary_agg( - op: ops.aggregations.WindowOp, - column: typed_expr.TypedExpr, - window: typing.Optional[window_spec.WindowSpec] = None, -) -> sge.Expression: - raise ValueError(f"Can't compile unrecognized operation: {op}") - - -@functools.singledispatch -def compile_ordered_unary_agg( - op: ops.aggregations.WindowOp, - column: typed_expr.TypedExpr, - window: typing.Optional[window_spec.WindowSpec] = None, - order_by: typing.Sequence[sge.Expression] = [], -) -> sge.Expression: - raise ValueError(f"Can't compile unrecognized operation: {op}") - - -@compile_unary_agg.register -def _( - op: ops.aggregations.SumOp, - column: typed_expr.TypedExpr, - window: typing.Optional[window_spec.WindowSpec] = None, -) -> sge.Expression: - # Will be null if all inputs are null. Pandas defaults to zero sum though. 
- expr = _apply_window_if_present(sge.func("SUM", column.expr), window) - return sge.func("IFNULL", expr, ir._literal(0, column.dtype)) - - -def _apply_window_if_present( - value: sge.Expression, - window: typing.Optional[window_spec.WindowSpec] = None, -) -> sge.Expression: - if window is not None: - raise NotImplementedError("Can't apply window to the expression.") - return value diff --git a/bigframes/core/compile/sqlglot/aggregations/__init__.py b/bigframes/core/compile/sqlglot/aggregations/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/bigframes/core/compile/sqlglot/aggregations/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py new file mode 100644 index 0000000000..a162a9c18a --- /dev/null +++ b/bigframes/core/compile/sqlglot/aggregations/binary_compiler.py @@ -0,0 +1,35 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import typing
+
+import sqlglot.expressions as sge
+
+from bigframes.core import window_spec
+import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
+import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
+from bigframes.operations import aggregations as agg_ops
+
+# Registry of compile functions for binary aggregation operators.
+BINARY_OP_REGISTRATION = reg.OpRegistration()
+
+
+def compile(
+    op: agg_ops.WindowOp,
+    left: typed_expr.TypedExpr,
+    right: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Compile a binary aggregation by dispatching to its registered handler.
+
+    The handler is looked up in ``BINARY_OP_REGISTRATION`` using ``op`` and is
+    called with ``op``, both operands and the optional ``window``.
+    """
+    handler = BINARY_OP_REGISTRATION[op]
+    return handler(op, left, right, window=window)
diff --git a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py
new file mode 100644
index 0000000000..720ce743a6
--- /dev/null
+++ b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py
@@ -0,0 +1,41 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import typing
+
+import sqlglot.expressions as sge
+
+from bigframes.core import window_spec
+import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
+from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
+from bigframes.operations import aggregations as agg_ops
+
+# Registry of compile functions for nullary aggregation operators.
+NULLARY_OP_REGISTRATION = reg.OpRegistration()
+
+
+def compile(
+    op: agg_ops.WindowOp,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Compile a nullary aggregation via the handler registered for ``op``.
+
+    The handler is looked up in ``NULLARY_OP_REGISTRATION`` and called with
+    ``op`` and the optional ``window``.
+    """
+    handler = NULLARY_OP_REGISTRATION[op]
+    return handler(op, window=window)
+
+
+@NULLARY_OP_REGISTRATION.register(agg_ops.SizeOp)
+def _(
+    op: agg_ops.SizeOp,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Compile ``SizeOp`` to ``COUNT(1)`` passed through ``apply_window_if_present``."""
+    row_count = sge.func("COUNT", sge.convert(1))
+    return apply_window_if_present(row_count, window)
diff --git a/bigframes/core/compile/sqlglot/aggregations/op_registration.py b/bigframes/core/compile/sqlglot/aggregations/op_registration.py
new file mode 100644
index 0000000000..996bf5b362
--- /dev/null
+++ b/bigframes/core/compile/sqlglot/aggregations/op_registration.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import typing
+
+from sqlglot import expressions as sge
+
+from bigframes.operations import aggregations as agg_ops
+
+# We should've been more specific about input types. Unfortunately,
+# MyPy doesn't support more rigorous checks.
+CompilationFunc = typing.Callable[..., sge.Expression]
+
+
+class OpRegistration:
+    """Registry mapping aggregation operator names to compile functions.
+
+    Functions are added with the ``register`` decorator and retrieved by
+    indexing the registry with either an operator instance or its name.
+    """
+
+    def __init__(self) -> None:
+        self._registered_ops: dict[str, CompilationFunc] = {}
+
+    def register(
+        self, op: agg_ops.WindowOp | type[agg_ops.WindowOp]
+    ) -> typing.Callable[[CompilationFunc], CompilationFunc]:
+        """Return a decorator that registers a compile function for ``op``.
+
+        The function is stored under ``op.name`` wrapped in a check that its
+        first positional argument is a ``WindowOp``. The decorator returns
+        that same wrapper, so direct calls and registry lookups are validated
+        identically.
+
+        Raises:
+            ValueError: When the decorator is applied, if ``op`` has no
+                ``name`` attribute or that name is already registered.
+        """
+
+        def decorator(item: CompilationFunc) -> CompilationFunc:
+            def arg_checker(*args, **kwargs):
+                # A call with no positional arguments is rejected with the
+                # same ValueError instead of surfacing as an IndexError.
+                if not args or not isinstance(args[0], agg_ops.WindowOp):
+                    got = type(args[0]) if args else "no positional arguments"
+                    raise ValueError(
+                        "The first parameter must be a window operator. "
+                        f"Got {got}"
+                    )
+                return item(*args, **kwargs)
+
+            if not hasattr(op, "name"):
+                raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
+            key = typing.cast(str, op.name)
+            if key in self._registered_ops:
+                raise ValueError(f"{key} is already registered")
+            # Store the validating wrapper (not the raw function) so that
+            # lookups through the registry are checked as well.
+            self._registered_ops[key] = arg_checker
+            return arg_checker
+
+        return decorator
+
+    def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
+        """Return the compile function registered for ``op``.
+
+        ``op`` may be an operator instance (its ``name`` is used as the key)
+        or the name itself.
+
+        Raises:
+            ValueError: If ``op`` is an operator without a ``name`` attribute.
+            KeyError: If nothing is registered under the resulting key.
+        """
+        if not isinstance(op, agg_ops.WindowOp):
+            return self._registered_ops[op]
+        if not hasattr(op, "name"):
+            raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
+        return self._registered_ops[typing.cast(str, op.name)]
diff --git a/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py
new file mode 100644
index 0000000000..dea30ec206
--- /dev/null
+++ b/bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py
@@ -0,0 +1,37 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import typing
+
+import sqlglot.expressions as sge
+
+from bigframes.core import window_spec
+import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
+import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
+from bigframes.operations import aggregations as agg_ops
+
+# Registry of compile functions for order-dependent unary aggregation operators.
+ORDERED_UNARY_OP_REGISTRATION = reg.OpRegistration()
+
+
+def compile(
+    op: agg_ops.WindowOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+    order_by: typing.Sequence[sge.Expression] = (),
+) -> sge.Expression:
+    """Compile an order-dependent unary aggregation.
+
+    Looks up the handler registered for ``op`` in
+    ``ORDERED_UNARY_OP_REGISTRATION`` and calls it with ``op``, ``column``,
+    ``window`` and ``order_by``.
+
+    The ``order_by`` default is an immutable empty tuple rather than a list,
+    so a handler can never mutate a default object shared between calls.
+
+    Raises:
+        KeyError: If no handler is registered under the operator's name.
+    """
+    handler = ORDERED_UNARY_OP_REGISTRATION[op]
+    return handler(op, column, window=window, order_by=order_by)
diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
new file mode 100644
index 0000000000..75ba090bc4
--- /dev/null
+++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
@@ -0,0 +1,56 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import typing
+
+import sqlglot.expressions as sge
+
+from bigframes.core import window_spec
+import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
+from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
+import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
+import bigframes.core.compile.sqlglot.sqlglot_ir as ir
+from bigframes.operations import aggregations as agg_ops
+
+# Registry of compile functions for order-independent unary aggregation operators.
+UNARY_OP_REGISTRATION = reg.OpRegistration()
+
+
+def compile(
+    op: agg_ops.WindowOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Compile a unary aggregation via the handler registered for ``op``.
+
+    The handler is looked up in ``UNARY_OP_REGISTRATION`` and called with
+    ``op``, ``column`` and the optional ``window``.
+    """
+    handler = UNARY_OP_REGISTRATION[op]
+    return handler(op, column, window=window)
+
+
+@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
+def _(
+    op: agg_ops.SumOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Compile ``SumOp`` to ``IFNULL(SUM(column), 0)``."""
+    total = apply_window_if_present(sge.func("SUM", column.expr), window)
+    # SUM is null when every input is null, whereas pandas reports a zero
+    # sum, so the null case is replaced by a zero literal of the column dtype.
+    zero = ir._literal(0, column.dtype)
+    return sge.func("IFNULL", total, zero)
+
+
+@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
+def _(
+    op: agg_ops.SizeUnaryOp,
+    _,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Compile ``SizeUnaryOp`` to ``COUNT(1)``; the column argument is unused."""
+    row_count = sge.func("COUNT", sge.convert(1))
+    return apply_window_if_present(row_count, window)
diff --git a/bigframes/core/compile/sqlglot/aggregations/utils.py b/bigframes/core/compile/sqlglot/aggregations/utils.py
new file mode 100644
index 0000000000..57470cde5b
--- /dev/null
+++ b/bigframes/core/compile/sqlglot/aggregations/utils.py
@@ -0,0 +1,29 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import typing
+
+import sqlglot.expressions as sge
+
+from bigframes.core import window_spec
+
+
+def apply_window_if_present(
+    value: sge.Expression,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Return ``value`` unchanged when no window is requested.
+
+    Raises:
+        NotImplementedError: If ``window`` is not ``None``.
+    """
+    if window is None:
+        return value
+    raise NotImplementedError("Can't apply window to the expression.")
diff --git a/bigframes/core/compile/sqlglot/expressions/op_registration.py b/bigframes/core/compile/sqlglot/expressions/op_registration.py
index e30b58a6d2..d5e4853a45 100644
--- a/bigframes/core/compile/sqlglot/expressions/op_registration.py
+++ b/bigframes/core/compile/sqlglot/expressions/op_registration.py
@@ -48,7 +48,7 @@ def arg_checker(*args, **kwargs):
 
         return decorator
 
-    def __getitem__(self, key: str | ops.ScalarOp) -> CompilationFunc:
-        if isinstance(key, ops.ScalarOp):
-            return self._registered_ops[key.name]
-        return self._registered_ops[key]
+    def __getitem__(self, op: str | ops.ScalarOp) -> CompilationFunc:
+        if isinstance(op, ops.ScalarOp):
+            return self._registered_ops[op.name]
+        return self._registered_ops[op]
diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py
index 8530a6fefa..c2fc9ad706 100644
--- a/tests/system/small/engines/test_aggregation.py
+++ b/tests/system/small/engines/test_aggregation.py
@@ -47,7 +47,7 @@ def apply_agg_to_all_valid(
     return new_arr
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def
test_engines_aggregate_size( scalars_array_value: array_value.ArrayValue, engine, @@ -84,7 +84,7 @@ def test_engines_unary_aggregates( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) @pytest.mark.parametrize( "grouping_cols", [ diff --git a/tests/unit/core/compile/sqlglot/aggregations/__init__.py b/tests/unit/core/compile/sqlglot/aggregations/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql new file mode 100644 index 0000000000..78104eb578 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_size/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + COUNT(1) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col_agg` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql new file mode 100644 index 0000000000..e748f71278 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_sum/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + COALESCE(SUM(`bfcol_0`), 0) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col_agg` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py new file mode 100644 index 0000000000..1d76876219 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sqlglot import expressions as sge
+
+from bigframes.core.compile.sqlglot.aggregations import op_registration
+from bigframes.operations import aggregations as agg_ops
+
+
+def test_register_then_get():
+    """A registered function can be fetched by operator instance or by name."""
+    reg = op_registration.OpRegistration()
+    input = sge.to_identifier("A")
+    op = agg_ops.SizeOp()
+
+    @reg.register(agg_ops.SizeOp)
+    def test_func(op: agg_ops.SizeOp, input: sge.Expression) -> sge.Expression:
+        return input
+
+    assert reg[agg_ops.SizeOp()](op, input) == test_func(op, input)
+    assert reg[agg_ops.SizeOp.name](op, input) == test_func(op, input)
+
+
+def test_register_function_first_argument_is_not_scalar_op_raise_error():
+    reg = op_registration.OpRegistration()
+
+    @reg.register(agg_ops.SizeOp)
+    def test_func(input: sge.Expression) -> sge.Expression:
+        return input
+
+    # Calling the decorated function with a first argument that is not a
+    # window operator must be rejected with a ValueError.
+    with pytest.raises(
+        ValueError, match=r".*first parameter must be a window operator.*"
+    ):
+        test_func(sge.to_identifier("A"))
diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
new file mode 100644
index 0000000000..d30c0c2a4f
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from bigframes.core import array_value, expression, identifiers, nodes
+from bigframes.operations import aggregations as agg_ops
+import bigframes.pandas as bpd
+
+# Skip every test in this module when the pytest_snapshot plugin is missing.
+pytest.importorskip("pytest_snapshot")
+
+
+def _apply_unary_op(obj: bpd.DataFrame, op: agg_ops.UnaryWindowOp, arg: str) -> str:
+    """Return the SQL for aggregating column ``arg`` of ``obj`` with ``op``.
+
+    Builds an ``AggregateNode`` over the DataFrame's expression node with a
+    single ``UnaryAggregation`` whose output column is named ``arg + "_agg"``,
+    then asks the session executor for the SQL with caching disabled.
+    """
+    agg_node = nodes.AggregateNode(
+        obj._block.expr.node,
+        aggregations=(
+            (
+                expression.UnaryAggregation(op, expression.deref(arg)),
+                identifiers.ColumnId(arg + "_agg"),
+            ),
+        ),
+    )
+    result = array_value.ArrayValue(agg_node)
+
+    sql = result.session._executor.to_sql(result, enable_cache=False)
+    return sql
+
+
+def test_size(scalar_types_df: bpd.DataFrame, snapshot):
+    bf_df = scalar_types_df[["string_col"]]
+    sql = _apply_unary_op(bf_df, agg_ops.SizeUnaryOp(), "string_col")
+    snapshot.assert_match(sql, "out.sql")
+
+
+def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
+    bf_df = scalar_types_df[["int64_col"]]
+    sql = _apply_unary_op(bf_df, agg_ops.SumOp(), "int64_col")
+    snapshot.assert_match(sql, "out.sql")

From 9e40021f838398ebf9f2eab036929d566fcc167f Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Tue, 22 Jul 2025 03:36:15 +0000
Subject: [PATCH 2/2] address comments

---
 .../core/compile/sqlglot/aggregations/test_op_registration.py | 2 +-
 .../core/compile/sqlglot/aggregations/test_unary_compiler.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py
index 1d76876219..e3688f19df
100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py @@ -32,7 +32,7 @@ def test_func(op: agg_ops.SizeOp, input: sge.Expression) -> sge.Expression: assert reg[agg_ops.SizeOp.name](op, input) == test_func(op, input) -def test_register_function_first_argument_is_not_scalar_op_raise_error(): +def test_register_function_first_argument_is_not_agg_op_raise_error(): reg = op_registration.OpRegistration() @reg.register(agg_ops.SizeOp) diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index d30c0c2a4f..96cdceb3c6 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -40,10 +40,12 @@ def _apply_unary_op(obj: bpd.DataFrame, op: agg_ops.UnaryWindowOp, arg: str) -> def test_size(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, agg_ops.SizeUnaryOp(), "string_col") + snapshot.assert_match(sql, "out.sql") def test_sum(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col"]] sql = _apply_unary_op(bf_df, agg_ops.SumOp(), "int64_col") + snapshot.assert_match(sql, "out.sql")