From 552871d3b9c7b707c6aa6225f293e046e057ade7 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 30 Jul 2025 22:31:19 +0000 Subject: [PATCH 1/3] refactor: add compile_window to the sqlglot compiler --- .../compile/sqlglot/aggregate_compiler.py | 20 +++- .../sqlglot/aggregations/unary_compiler.py | 6 +- bigframes/core/compile/sqlglot/compiler.py | 66 +++++++++++ .../sqlglot/expressions/unary_compiler.py | 5 + .../core/compile/sqlglot/scalar_compiler.py | 2 +- bigframes/core/compile/sqlglot/sqlglot_ir.py | 7 ++ bigframes/core/rewrite/schema_binding.py | 103 ++++++++++++------ bigframes/operations/aggregations.py | 1 + tests/system/small/engines/test_windowing.py | 31 +++++- .../out.sql | 76 +++++++++++++ .../test_compile_window_w_min_periods/out.sql | 30 +++++ .../out.sql | 30 +++++ .../test_compile_window_w_rolling/out.sql | 30 +++++ .../compile/sqlglot/test_compile_window.py | 58 ++++++++++ 14 files changed, 424 insertions(+), 41 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_rolling/out.sql create mode 100644 tests/unit/core/compile/sqlglot/test_compile_window.py diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index f7abd7dc7a..52ef4cc26c 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -15,7 +15,7 @@ import sqlglot.expressions as sge -from bigframes.core import expression +from bigframes.core import expression, window_spec from bigframes.core.compile.sqlglot.aggregations import ( binary_compiler, nullary_compiler, @@ -56,3 +56,21 @@ def compile_aggregate( return binary_compiler.compile(aggregate.op, left, right) else: raise ValueError(f"Unexpected aggregation: {aggregate}") + + +def compile_analytic( + aggregate: expression.Aggregation, + window: window_spec.WindowSpec, +) -> sge.Expression: + if isinstance(aggregate, expression.NullaryAggregation): + return nullary_compiler.compile(aggregate.op) + if isinstance(aggregate, expression.UnaryAggregation): + column = typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(aggregate.arg), + aggregate.arg.output_type, + ) + return unary_compiler.compile(aggregate.op, column, window) + elif isinstance(aggregate, expression.BinaryAggregation): + raise NotImplementedError("binary analytic operations not yet supported") + else: + raise ValueError(f"Unexpected analytic operation: {aggregate}") diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index eddf7f56d2..c65c971bfa 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -18,6 +18,7 @@ import sqlglot.expressions as sge +from bigframes import dtypes from bigframes.core import window_spec import bigframes.core.compile.sqlglot.aggregations.op_registration as reg from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present @@ -42,8 +43,11 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: + expr = column.expr + if column.dtype == dtypes.BOOL_DTYPE: + expr = sge.Cast(this=column.expr, to="INT64") # Will be null if all inputs are null. Pandas defaults to zero sum though. - expr = apply_window_if_present(sge.func("SUM", column.expr), window) + expr = apply_window_if_present(sge.func("SUM", expr), window) return sge.func("IFNULL", expr, ir._literal(0, column.dtype)) diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 2ae6b4bb9c..1a8455176a 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -298,6 +298,72 @@ def compile_aggregate( return child.aggregate(aggregations, by_cols, tuple(dropna_cols)) + @_compile_node.register + def compile_window( + self, node: nodes.WindowOpNode, child: ir.SQLGlotIR + ) -> ir.SQLGlotIR: + window_spec = node.window_spec + if node.expression.op.order_independent and window_spec.is_unbounded: + # notably percentile_cont does not support ordering clause + window_spec = window_spec.without_order() + + window_op = aggregate_compiler.compile_analytic(node.expression, window_spec) + + inputs: tuple[sge.Expression, ...] = tuple( + scalar_compiler.compile_scalar_expression(expression.DerefOp(column)) + for column in node.expression.column_references + ) + + clauses: list[tuple[sge.Expression, sge.Expression]] = [] + if node.expression.op.skips_nulls and not node.never_skip_nulls: + for column in inputs: + clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null())) + + if window_spec.min_periods and len(inputs) > 0: + if node.expression.op.skips_nulls: + # Most operations do not count NULL values towards min_periods + not_null_columns = [ + sge.Not(this=sge.Is(this=column, expression=sge.Null())) + for column in inputs + ] + # All inputs must be non-null for observation to count + if not not_null_columns: + is_observation_expr: sge.Expression = sge.convert(True) + else: + is_observation_expr = not_null_columns[0] + for expr in not_null_columns[1:]: + is_observation_expr = sge.And( + this=is_observation_expr, expression=expr + ) + is_observation = ir._cast(is_observation_expr, "INT64") + else: + # Operations like count treat even NULLs as valid observations + # for the sake of min_periods notnull is just used to convert + # null values to non-null (FALSE) values to be counted. + is_observation = ir._cast( + sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())), + "INT64", + ) + + observation_count = windows.apply_window_if_present( + sge.func("SUM", is_observation), window_spec + ) + clauses.append( + ( + observation_count < sge.convert(window_spec.min_periods), + sge.Null(), + ) + ) + if clauses: + when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses] + window_op = sge.Case(ifs=when_expressions, default=window_op) + + # TODO: check if we can directly window the expression. + return child.window( + window_op=window_op, + output_column_id=node.output_name.sql, + ) + def _replace_unsupported_ops(node: nodes.BigFrameNode): node = nodes.bottom_up(node, rewrite.rewrite_slice) diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index 125c60bbf4..273eb08421 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -680,3 +680,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: @UNARY_OP_REGISTRATION.register(ops.year_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr) + + +@UNARY_OP_REGISTRATION.register(ops.UnixMicros) +def _(op: ops.UnixMicros, expr: TypedExpr) -> sge.Expression: + return sge.func("UNIX_MICROS", expr.expr) diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py index 683dd38c9a..65c2501b71 100644 --- a/bigframes/core/compile/sqlglot/scalar_compiler.py +++ b/bigframes/core/compile/sqlglot/scalar_compiler.py @@ -31,7 +31,7 @@ @functools.singledispatch def compile_scalar_expression( - expression: expression.Expression, + expr: expression.Expression, ) -> sge.Expression: """Compiles BigFrames scalar expression into SQLGlot expression.""" raise ValueError(f"Can't compile unrecognized node: {expression}") diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index b194fe9e5d..1a00cd0a93 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -409,6 +409,13 @@ def aggregate( new_expr = new_expr.where(condition, append=False) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def window( + self, + window_op: sge.Expression, + output_column_id: str, + ) -> SQLGlotIR: + return self.project(((output_column_id, window_op),)) + def insert( self, destination: bigquery.TableReference, diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py index f7f2ca8c59..cbecf83035 100644 --- a/bigframes/core/rewrite/schema_binding.py +++ b/bigframes/core/rewrite/schema_binding.py @@ -17,7 +17,7 @@ from bigframes.core import bigframe_node from bigframes.core import expression as ex -from bigframes.core import nodes +from bigframes.core import nodes, ordering def bind_schema_to_tree( @@ -79,46 +79,77 @@ def bind_schema_to_node( if isinstance(node, nodes.AggregateNode): aggregations = [] for aggregation, id in node.aggregations: - if isinstance(aggregation, ex.UnaryAggregation): - replaced = typing.cast( - ex.Aggregation, - dataclasses.replace( - aggregation, - arg=typing.cast( - ex.RefOrConstant, - ex.bind_schema_fields( - aggregation.arg, node.child.field_by_id - ), - ), - ), + aggregations.append( + (_bind_schema_to_aggregation_expr(aggregation, node.child), id) + ) + + return dataclasses.replace( + node, + aggregations=tuple(aggregations), + ) + + if isinstance(node, nodes.WindowOpNode): + window_spec = dataclasses.replace( + node.window_spec, + grouping_keys=tuple( + typing.cast( + ex.DerefOp, ex.bind_schema_fields(expr, node.child.field_by_id) ) - aggregations.append((replaced, id)) - elif isinstance(aggregation, ex.BinaryAggregation): - replaced = typing.cast( - ex.Aggregation, - dataclasses.replace( - aggregation, - left=typing.cast( - ex.RefOrConstant, - ex.bind_schema_fields( - aggregation.left, node.child.field_by_id - ), - ), - right=typing.cast( - ex.RefOrConstant, - ex.bind_schema_fields( - aggregation.right, node.child.field_by_id - ), - ), + for expr in node.window_spec.grouping_keys + ), + ordering=tuple( + ordering.OrderingExpression( + scalar_expression=ex.bind_schema_fields( + expr.scalar_expression, node.child.field_by_id ), + direction=expr.direction, + na_last=expr.na_last, ) - aggregations.append((replaced, id)) - else: - aggregations.append((aggregation, id)) - + for expr in node.window_spec.ordering + ), + ) return dataclasses.replace( node, - aggregations=tuple(aggregations), + expression=_bind_schema_to_aggregation_expr(node.expression, node.child), + window_spec=window_spec, ) return node + + +def _bind_schema_to_aggregation_expr( + aggregation: ex.Aggregation, + child: bigframe_node.BigFrameNode, +) -> ex.Aggregation: + assert isinstance( + aggregation, ex.Aggregation + ), f"Expected Aggregation, got {type(aggregation)}" + + if isinstance(aggregation, ex.UnaryAggregation): + return typing.cast( + ex.Aggregation, + dataclasses.replace( + aggregation, + arg=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields(aggregation.arg, child.field_by_id), + ), + ), + ) + elif isinstance(aggregation, ex.BinaryAggregation): + return typing.cast( + ex.Aggregation, + dataclasses.replace( + aggregation, + left=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields(aggregation.left, child.field_by_id), + ), + right=typing.cast( + ex.RefOrConstant, + ex.bind_schema_fields(aggregation.right, child.field_by_id), + ), + ), + ) + else: + return aggregation diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index 984f7d3798..6889997a10 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -517,6 +517,7 @@ def skips_nulls(self): @dataclasses.dataclass(frozen=True) class DiffOp(UnaryWindowOp): + name: ClassVar[str] = "diff" periods: int @property diff --git a/tests/system/small/engines/test_windowing.py b/tests/system/small/engines/test_windowing.py index f4c2b61e6f..3712e4c047 100644 --- a/tests/system/small/engines/test_windowing.py +++ b/tests/system/small/engines/test_windowing.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud import bigquery import pytest -from bigframes.core import array_value -from bigframes.session import polars_executor +from bigframes.core import array_value, expression, identifiers, nodes, window_spec +import bigframes.operations.aggregations as agg_ops +from bigframes.session import direct_gbq_execution, polars_executor from bigframes.testing.engine_utils import assert_equivalence_execution pytest.importorskip("polars") @@ -31,3 +33,28 @@ def test_engines_with_offsets( ): result, _ = scalars_array_value.promote_offsets() assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) + + +def test_engines_with_rows_window( + scalars_array_value: array_value.ArrayValue, + bigquery_client: bigquery.Client, +): + window = window_spec.WindowSpec( + bounds=window_spec.RowsWindowBounds.from_window_size(3, "left"), + ) + window_node = nodes.WindowOpNode( + child=scalars_array_value.node, + expression=expression.UnaryAggregation( + agg_ops.sum_op, expression.deref("int64_too") + ), + window_spec=window, + output_name=identifiers.ColumnId("sum_int64"), + never_skip_nulls=False, + skip_reproject_unsafe=False, + ) + + bq_executor = direct_gbq_execution.DirectGbqExecutor(bigquery_client) + bq_sqlgot_executor = direct_gbq_execution.DirectGbqExecutor( + bigquery_client, compiler="sqlglot" + ) + assert_equivalence_execution(window_node, bq_executor, bq_sqlgot_executor) diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql new file mode 100644 index 0000000000..beb3caa073 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql @@ -0,0 +1,76 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_2` AS `bfcol_6`, + `bfcol_0` AS `bfcol_7`, + `bfcol_1` AS `bfcol_8`, + `bfcol_0` AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + * + FROM `bfcte_1` + WHERE + NOT `bfcol_9` IS NULL +), `bfcte_3` AS ( + SELECT + *, + CASE + WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER ( + PARTITION BY `bfcol_9` + ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ) < 3 + THEN NULL + ELSE COALESCE( + SUM(CAST(`bfcol_7` AS INT64)) OVER ( + PARTITION BY `bfcol_9` + ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ), + 0 + ) + END AS `bfcol_15` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + * + FROM `bfcte_3` + WHERE + NOT `bfcol_9` IS NULL +), `bfcte_5` AS ( + SELECT + *, + CASE + WHEN SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER ( + PARTITION BY `bfcol_9` + ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ) < 3 + THEN NULL + ELSE COALESCE( + SUM(`bfcol_8`) OVER ( + PARTITION BY `bfcol_9` + ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ), + 0 + ) + END AS `bfcol_21` + FROM `bfcte_4` +) +SELECT + `bfcol_9` AS `bool_col`, + `bfcol_6` AS `rowindex`, + `bfcol_15` AS `bool_col_1`, + `bfcol_21` AS `int64_col` +FROM `bfcte_5` +ORDER BY + `bfcol_9` ASC NULLS LAST, + `bfcol_2` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql new file mode 100644 index 0000000000..5885f5ab3c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql @@ -0,0 +1,30 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN SUM(CAST(NOT `bfcol_0` IS NULL AS INT64)) OVER ( + ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) < 3 + THEN NULL + ELSE COALESCE( + SUM(`bfcol_0`) OVER ( + ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ), + 0 + ) + END AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `rowindex`, + `bfcol_4` AS `int64_col` +FROM `bfcte_1` +ORDER BY + `bfcol_1` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql new file mode 100644 index 0000000000..581c81c6b4 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql @@ -0,0 +1,30 @@ +WITH `bfcte_0` AS ( + SELECT + * + FROM UNNEST(ARRAY>[STRUCT(CAST('2025-01-01T00:00:00+00:00' AS TIMESTAMP), 0, 0), STRUCT(CAST('2025-01-01T00:00:01+00:00' AS TIMESTAMP), 1, 1), STRUCT(CAST('2025-01-01T00:00:02+00:00' AS TIMESTAMP), 2, 2), STRUCT(CAST('2025-01-01T00:00:03+00:00' AS TIMESTAMP), 3, 3), STRUCT(CAST('2025-01-01T00:00:04+00:00' AS TIMESTAMP), 0, 4), STRUCT(CAST('2025-01-01T00:00:05+00:00' AS TIMESTAMP), 1, 5), STRUCT(CAST('2025-01-01T00:00:06+00:00' AS TIMESTAMP), 2, 6), STRUCT(CAST('2025-01-01T00:00:07+00:00' AS TIMESTAMP), 3, 7), STRUCT(CAST('2025-01-01T00:00:08+00:00' AS TIMESTAMP), 0, 8), STRUCT(CAST('2025-01-01T00:00:09+00:00' AS TIMESTAMP), 1, 9), STRUCT(CAST('2025-01-01T00:00:10+00:00' AS TIMESTAMP), 2, 10), STRUCT(CAST('2025-01-01T00:00:11+00:00' AS TIMESTAMP), 3, 11), STRUCT(CAST('2025-01-01T00:00:12+00:00' AS TIMESTAMP), 0, 12), STRUCT(CAST('2025-01-01T00:00:13+00:00' AS TIMESTAMP), 1, 13), STRUCT(CAST('2025-01-01T00:00:14+00:00' AS TIMESTAMP), 2, 14), STRUCT(CAST('2025-01-01T00:00:15+00:00' AS TIMESTAMP), 3, 15), STRUCT(CAST('2025-01-01T00:00:16+00:00' AS TIMESTAMP), 0, 16), STRUCT(CAST('2025-01-01T00:00:17+00:00' AS TIMESTAMP), 1, 17), STRUCT(CAST('2025-01-01T00:00:18+00:00' AS TIMESTAMP), 2, 18), STRUCT(CAST('2025-01-01T00:00:19+00:00' AS TIMESTAMP), 3, 19)]) +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN SUM(CAST(NOT `bfcol_1` IS NULL AS INT64)) OVER ( + ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST + RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW + ) < 1 + THEN NULL + ELSE COALESCE( + SUM(`bfcol_1`) OVER ( + ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST + RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW + ), + 0 + ) + END AS `bfcol_6` + FROM `bfcte_0` +) +SELECT + `bfcol_0` AS `ts_col`, + `bfcol_6` AS `int_col` +FROM `bfcte_1` +ORDER BY + `bfcol_0` ASC NULLS LAST, + `bfcol_2` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_rolling/out.sql new file mode 100644 index 0000000000..6d779a40ac --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_rolling/out.sql @@ -0,0 +1,30 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN SUM(CAST(NOT `bfcol_0` IS NULL AS INT64)) OVER ( + ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST + ROWS BETWEEN 2 PRECEDING AND CURRENT ROW + ) < 3 + THEN NULL + ELSE COALESCE( + SUM(`bfcol_0`) OVER ( + ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST + ROWS BETWEEN 2 PRECEDING AND CURRENT ROW + ), + 0 + ) + END AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `rowindex`, + `bfcol_4` AS `int64_col` +FROM `bfcte_1` +ORDER BY + `bfcol_1` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_window.py b/tests/unit/core/compile/sqlglot/test_compile_window.py new file mode 100644 index 0000000000..718fb7fbfc --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_window.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd +import pytest + +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def test_compile_window_w_rolling(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col"]].sort_index() + result = bf_df.rolling(window=3).sum() + snapshot.assert_match(result.sql, "out.sql") + + +def test_compile_window_w_groupby_rolling(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["bool_col", "int64_col"]].sort_index() + result = ( + bf_df.groupby(scalar_types_df["bool_col"]) + .rolling(window=3, closed="both") + .sum() + ) + snapshot.assert_match(result.sql, "out.sql") + + +def test_compile_window_w_min_periods(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col"]].sort_index() + result = bf_df.expanding(min_periods=3).sum() + snapshot.assert_match(result.sql, "out.sql") + + +def test_compile_window_w_range_rolling(compiler_session, snapshot): + values = np.arange(20) + pd_df = pd.DataFrame( + { + "ts_col": pd.Timestamp("20250101", tz="UTC") + pd.to_timedelta(values, "s"), + "int_col": values % 4, + "float_col": values / 2, + } + ) + bf_df = compiler_session.read_pandas(pd_df) + bf_series = bf_df.set_index("ts_col")["int_col"].sort_index() + result = bf_series.rolling(window="3s").sum() + snapshot.assert_match(result.to_frame().sql, "out.sql") From f90371caaefcc30914feb12a3a1ca48a28abb741 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 30 Jul 2025 23:58:57 +0000 Subject: [PATCH 2/3] skip tests for sqlglot sql formatting issues --- .../core/compile/sqlglot/expressions/unary_compiler.py | 5 ----- tests/unit/core/compile/sqlglot/test_compile_window.py | 9 +++++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index 273eb08421..125c60bbf4 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -680,8 +680,3 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: @UNARY_OP_REGISTRATION.register(ops.year_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr) - - -@UNARY_OP_REGISTRATION.register(ops.UnixMicros) -def _(op: ops.UnixMicros, expr: TypedExpr) -> sge.Expression: - return sge.func("UNIX_MICROS", expr.expr) diff --git a/tests/unit/core/compile/sqlglot/test_compile_window.py b/tests/unit/core/compile/sqlglot/test_compile_window.py index 718fb7fbfc..20258d9972 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_window.py +++ b/tests/unit/core/compile/sqlglot/test_compile_window.py @@ -13,14 +13,23 @@ # limitations under the License. import numpy as np +from packaging import version import pandas as pd import pytest +import sqlglot import bigframes.pandas as bpd pytest.importorskip("pytest_snapshot") +if version.Version(sqlglot.__version__) < version.Version("25.0.0"): + pytest.skip( + "Skip tests for sqlglot < 25.0.0 due to SQL formatting issues", + allow_module_level=True, + ) + + def test_compile_window_w_rolling(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col"]].sort_index() result = bf_df.rolling(window=3).sum() From 4e4b15ab5e24013926ca3c273a96a2750153b2d7 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 13 Aug 2025 18:00:01 +0000 Subject: [PATCH 3/3] skip tests on Python < 3.11 instead --- tests/unit/core/compile/sqlglot/test_compile_window.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/core/compile/sqlglot/test_compile_window.py b/tests/unit/core/compile/sqlglot/test_compile_window.py index 20258d9972..5a6e3e5322 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_window.py +++ b/tests/unit/core/compile/sqlglot/test_compile_window.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys + import numpy as np -from packaging import version import pandas as pd import pytest -import sqlglot import bigframes.pandas as bpd pytest.importorskip("pytest_snapshot") -if version.Version(sqlglot.__version__) < version.Version("25.0.0"): +if sys.version_info < (3, 12): pytest.skip( - "Skip tests for sqlglot < 25.0.0 due to SQL formatting issues", + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", allow_module_level=True, )