From bab9576fb09ffcaf732754c6c81386b5d1f67f0a Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 8 Oct 2025 21:44:14 +0000 Subject: [PATCH 1/3] refactor: support agg_ops.FirstOp, FirstNonNullOp in the sqlglot compiler --- .../sqlglot/aggregations/nullary_compiler.py | 2 +- .../sqlglot/aggregations/unary_compiler.py | 25 +++++++++++++++++-- .../compile/sqlglot/aggregations/windows.py | 7 +++--- .../test_unary_compiler/test_first/out.sql | 20 +++++++++++++++ .../test_first_non_null/out.sql | 16 ++++++++++++ .../aggregations/test_unary_compiler.py | 22 ++++++++++++++++ 6 files changed, 86 insertions(+), 6 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first_non_null/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py index c6418591ba..95dad4ff3b 100644 --- a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py @@ -50,4 +50,4 @@ def _( if window is None: # ROW_NUMBER always needs an OVER clause. return sge.Window(this=result) - return apply_window_if_present(result, window) + return apply_window_if_present(result, window, include_framing_clauses=False) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 1e87fd1fc5..04bc29def8 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -104,7 +104,28 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - return apply_window_if_present(sge.func("DENSE_RANK"), window) + return apply_window_if_present(sge.func("DENSE_RANK"), window, include_framing_clauses=False) + + +@UNARY_OP_REGISTRATION.register(agg_ops.FirstOp) +def _( + op: agg_ops.FirstOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # FIRST_VALUE in BQ respects nulls by default. + return apply_window_if_present(sge.FirstValue(this=column.expr), window) + + +@UNARY_OP_REGISTRATION.register(agg_ops.FirstNonNullOp) +def _( + op: agg_ops.FirstNonNullOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present( + sge.IgnoreNulls(this=sge.FirstValue(this=column.expr)), window + ) @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) @@ -182,7 +203,7 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - return apply_window_if_present(sge.func("RANK"), window) + return apply_window_if_present(sge.func("RANK"), window, include_framing_clauses=False) @UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp) diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index 5e38bf120e..41b4c674f9 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -25,6 +25,7 @@ def apply_window_if_present( value: sge.Expression, window: typing.Optional[window_spec.WindowSpec] = None, + include_framing_clauses: bool = True, ) -> sge.Expression: if window is None: return value @@ -64,11 +65,11 @@ def apply_window_if_present( if not window.bounds and not order: return sge.Window(this=value, partition_by=group_by) - if not window.bounds: + if not window.bounds and not include_framing_clauses: return sge.Window(this=value, partition_by=group_by, order=order) kind = ( - "ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE" + "RANGE" if isinstance(window.bounds, window_spec.RangeWindowBounds) else "ROWS" ) start: typing.Union[int, float, None] = None @@ -125,7 +126,7 @@ def get_window_order_by( nulls_first=nulls_first, ) ) - elif not nulls_first and not desc: + elif (not nulls_first) and (not desc): order_by.append( sge.Ordered( this=is_null_expr, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first/out.sql new file mode 100644 index 0000000000..6c7d39c24a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first/out.sql @@ -0,0 +1,20 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN `bfcol_0` IS NULL + THEN NULL + ELSE FIRST_VALUE(`bfcol_0`) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) + END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first_non_null/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first_non_null/out.sql new file mode 100644 index 0000000000..ff90c6fcd9 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first_non_null/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + FIRST_VALUE(`bfcol_0` IGNORE NULLS) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index ea7faca7fb..8a9bf08a14 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -126,6 +126,28 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_first(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation(agg_ops.FirstOp(), expression.deref(col_name)) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + +def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation( + agg_ops.FirstNonNullOp(), expression.deref(col_name) + ) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + def test_max(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From 4cab482cc334767704420c3935f60a03eb56ad9d Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 8 Oct 2025 21:51:48 +0000 Subject: [PATCH 2/3] refactor: support agg_ops.LastOp, LastNonNullOp in the sqlglot compiler --- .../sqlglot/aggregations/unary_compiler.py | 29 +++++++++++++++++-- .../test_unary_compiler/test_last/out.sql | 20 +++++++++++++ .../test_last_non_null/out.sql | 16 ++++++++++ .../aggregations/test_unary_compiler.py | 22 ++++++++++++++ 4 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last_non_null/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 04bc29def8..16bd3ef099 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -104,7 +104,9 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - return apply_window_if_present(sge.func("DENSE_RANK"), window, include_framing_clauses=False) + return apply_window_if_present( + sge.func("DENSE_RANK"), window, include_framing_clauses=False + ) @UNARY_OP_REGISTRATION.register(agg_ops.FirstOp) @@ -128,6 +130,27 @@ def _( ) +@UNARY_OP_REGISTRATION.register(agg_ops.LastOp) +def _( + op: agg_ops.LastOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # LAST_VALUE in BQ respects nulls by default. + return apply_window_if_present(sge.LastValue(this=column.expr), window) + + +@UNARY_OP_REGISTRATION.register(agg_ops.LastNonNullOp) +def _( + op: agg_ops.LastNonNullOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present( + sge.IgnoreNulls(this=sge.LastValue(this=column.expr)), window + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) def _( op: agg_ops.MaxOp, @@ -203,7 +226,9 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - return apply_window_if_present(sge.func("RANK"), window, include_framing_clauses=False) + return apply_window_if_present( + sge.func("RANK"), window, include_framing_clauses=False + ) @UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last/out.sql new file mode 100644 index 0000000000..788c5ba466 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last/out.sql @@ -0,0 +1,20 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN `bfcol_0` IS NULL + THEN NULL + ELSE LAST_VALUE(`bfcol_0`) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) + END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last_non_null/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last_non_null/out.sql new file mode 100644 index 0000000000..17e7dbd446 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last_non_null/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + LAST_VALUE(`bfcol_0` IGNORE NULLS) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 8a9bf08a14..2b563c2463 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -148,6 +148,28 @@ def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_last(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation(agg_ops.LastOp(), expression.deref(col_name)) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + +def test_last_non_null(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation( + agg_ops.LastNonNullOp(), expression.deref(col_name) + ) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + def test_max(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From 53337198e6e481b1a4e7dab300d7e77870312a7d Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 8 Oct 2025 23:22:51 +0000 Subject: [PATCH 3/3] exclude tests for lower version --- .../sqlglot/aggregations/test_unary_compiler.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 2b563c2463..ea15f155ad 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import typing import pytest @@ -127,6 +128,10 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): def test_first(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation(agg_ops.FirstOp(), expression.deref(col_name)) @@ -137,6 +142,10 @@ def test_first(scalar_types_df: bpd.DataFrame, snapshot): def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation( @@ -149,6 +158,10 @@ def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot): def test_last(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation(agg_ops.LastOp(), expression.deref(col_name)) @@ -159,6 +172,10 @@ def test_last(scalar_types_df: bpd.DataFrame, snapshot): def test_last_non_null(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) col_name = "int64_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_exprs.UnaryAggregation(