diff --git a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py index c6418591ba..95dad4ff3b 100644 --- a/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py @@ -50,4 +50,4 @@ def _( if window is None: # ROW_NUMBER always needs an OVER clause. return sge.Window(this=result) - return apply_window_if_present(result, window) + return apply_window_if_present(result, window, include_framing_clauses=False) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 1e87fd1fc5..16bd3ef099 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -104,7 +104,51 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - return apply_window_if_present(sge.func("DENSE_RANK"), window) + return apply_window_if_present( + sge.func("DENSE_RANK"), window, include_framing_clauses=False + ) + + +@UNARY_OP_REGISTRATION.register(agg_ops.FirstOp) +def _( + op: agg_ops.FirstOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # FIRST_VALUE in BQ respects nulls by default. + return apply_window_if_present(sge.FirstValue(this=column.expr), window) + + +@UNARY_OP_REGISTRATION.register(agg_ops.FirstNonNullOp) +def _( + op: agg_ops.FirstNonNullOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present( + sge.IgnoreNulls(this=sge.FirstValue(this=column.expr)), window + ) + + +@UNARY_OP_REGISTRATION.register(agg_ops.LastOp) +def _( + op: agg_ops.LastOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # LAST_VALUE in BQ respects nulls by default. + return apply_window_if_present(sge.LastValue(this=column.expr), window) + + +@UNARY_OP_REGISTRATION.register(agg_ops.LastNonNullOp) +def _( + op: agg_ops.LastNonNullOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present( + sge.IgnoreNulls(this=sge.LastValue(this=column.expr)), window + ) @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) @@ -182,7 +226,9 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - return apply_window_if_present(sge.func("RANK"), window) + return apply_window_if_present( + sge.func("RANK"), window, include_framing_clauses=False + ) @UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp) diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index 5e38bf120e..41b4c674f9 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -25,6 +25,7 @@ def apply_window_if_present( value: sge.Expression, window: typing.Optional[window_spec.WindowSpec] = None, + include_framing_clauses: bool = True, ) -> sge.Expression: if window is None: return value @@ -64,11 +65,11 @@ def apply_window_if_present( if not window.bounds and not order: return sge.Window(this=value, partition_by=group_by) - if not window.bounds: + if not window.bounds and not include_framing_clauses: return sge.Window(this=value, partition_by=group_by, order=order) kind = ( - "ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE" + "RANGE" if isinstance(window.bounds, window_spec.RangeWindowBounds) else "ROWS" ) start: typing.Union[int, float, None] = None @@ -125,7 +126,7 @@ def get_window_order_by( nulls_first=nulls_first, ) ) - elif not nulls_first and not desc: + elif (not nulls_first) and (not desc): order_by.append( sge.Ordered( this=is_null_expr, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first/out.sql new file mode 100644 index 0000000000..6c7d39c24a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first/out.sql @@ -0,0 +1,20 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN `bfcol_0` IS NULL + THEN NULL + ELSE FIRST_VALUE(`bfcol_0`) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) + END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first_non_null/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first_non_null/out.sql new file mode 100644 index 0000000000..ff90c6fcd9 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_first_non_null/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + FIRST_VALUE(`bfcol_0` IGNORE NULLS) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last/out.sql new file mode 100644 index 0000000000..788c5ba466 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last/out.sql @@ -0,0 +1,20 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN `bfcol_0` IS NULL + THEN NULL + ELSE LAST_VALUE(`bfcol_0`) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) + END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last_non_null/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last_non_null/out.sql new file mode 100644 index 0000000000..17e7dbd446 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_last_non_null/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + LAST_VALUE(`bfcol_0` IGNORE NULLS) OVER ( + ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index ea7faca7fb..ea15f155ad 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import typing import pytest @@ -126,6 +127,66 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_first(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation(agg_ops.FirstOp(), expression.deref(col_name)) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + +def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation( + agg_ops.FirstNonNullOp(), expression.deref(col_name) + ) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + +def test_last(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation(agg_ops.LastOp(), expression.deref(col_name)) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + +def test_last_non_null(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_exprs.UnaryAggregation( + agg_ops.LastNonNullOp(), expression.deref(col_name) + ) + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64") + + snapshot.assert_match(sql, "out.sql") + + def test_max(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]]