From 38bcf9ab27b35f063a590ff0fca9bc3a58895af3 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 8 Oct 2025 23:28:15 +0000 Subject: [PATCH 1/3] refactor: support agg_ops.ShiftOp for the sqlglot compiler --- .../sqlglot/aggregations/unary_compiler.py | 17 ++++++++++++ .../test_unary_compiler/test_shift/lag.sql | 13 +++++++++ .../test_unary_compiler/test_shift/lead.sql | 13 +++++++++ .../test_unary_compiler/test_shift/noop.sql | 13 +++++++++ .../aggregations/test_unary_compiler.py | 27 +++++++++++++++++++ 5 files changed, 83 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lag.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lead.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/noop.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 16bd3ef099..73e31c1d44 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -240,6 +240,23 @@ def _( return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window) +@UNARY_OP_REGISTRATION.register(agg_ops.ShiftOp) +def _( + op: agg_ops.ShiftOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + if op.periods == 0: # No-op + return column.expr + if op.periods > 0: + return apply_window_if_present( + sge.func("LAG", column.expr, sge.convert(op.periods)), window + ) + return apply_window_if_present( + sge.func("LEAD", column.expr, sge.convert(-op.periods)), window + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.SumOp) def _( op: agg_ops.SumOp, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lag.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lag.sql new file mode 100644 index 0000000000..59e2c47edf --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lag.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `lag` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lead.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lead.sql new file mode 100644 index 0000000000..5c82b5db39 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/lead.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + LEAD(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `lead` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/noop.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/noop.sql new file mode 100644 index 0000000000..fef4a2bde8 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_shift/noop.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_0` AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `noop` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index ea15f155ad..0803d569d7 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -271,6 +271,33 @@ def test_rank(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_shift(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + + # Test lag + lag_op = agg_exprs.UnaryAggregation( + agg_ops.ShiftOp(periods=1), expression.deref(col_name) + ) + lag_sql = _apply_unary_window_op(bf_df, lag_op, window, "lag") + snapshot.assert_match(lag_sql, "lag.sql") + + # Test lead + lead_op = agg_exprs.UnaryAggregation( + agg_ops.ShiftOp(periods=-1), expression.deref(col_name) + ) + lead_sql = _apply_unary_window_op(bf_df, lead_op, window, "lead") + snapshot.assert_match(lead_sql, "lead.sql") + + # Test no-op + noop_op = agg_exprs.UnaryAggregation( + agg_ops.ShiftOp(periods=0), expression.deref(col_name) + ) + noop_sql = _apply_unary_window_op(bf_df, noop_op, window, "noop") + snapshot.assert_match(noop_sql, "noop.sql") + + def test_sum(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col"]] agg_ops_map = { From 149ebc9f0d6aa0f701b299a65f73b0c6d7197e88 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 8 Oct 2025 23:47:25 +0000 Subject: [PATCH 2/3] refactor: add agg_ops.DiffOp --- .../sqlglot/aggregations/unary_compiler.py | 17 ++++++++++++++ .../test_diff/diff_bool.sql | 13 +++++++++++ .../test_diff/diff_int.sql | 13 +++++++++++ .../aggregations/test_unary_compiler.py | 22 +++++++++++++++++++ 4 files changed, 65 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_bool.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_int.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 73e31c1d44..8cb6d0092b 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -151,6 +151,23 @@ def _( ) +@UNARY_OP_REGISTRATION.register(agg_ops.DiffOp) +def _( + op: agg_ops.DiffOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)] + shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window) + if column.dtype in (dtypes.BOOL_DTYPE, dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE): + if column.dtype == dtypes.BOOL_DTYPE: + return sge.NEQ(this=column.expr, expression=shifted) + else: + return sge.Sub(this=column.expr, expression=shifted) + else: + raise TypeError(f"Cannot perform diff on type {column.dtype}") + + @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp) def _( op: agg_ops.MaxOp, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_bool.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_bool.sql new file mode 100644 index 0000000000..6c7d37c037 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_bool.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_0` <> LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `diff_bool` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_int.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_int.sql new file mode 100644 index 0000000000..1ce4953d87 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff/diff_int.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_0` - LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `diff_int` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 0803d569d7..a83a494e55 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -127,6 +127,28 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_diff(scalar_types_df: bpd.DataFrame, snapshot): + # Test integer + int_col = "int64_col" + bf_df_int = scalar_types_df[[int_col]] + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(int_col),)) + int_op = agg_exprs.UnaryAggregation( + agg_ops.DiffOp(periods=1), expression.deref(int_col) + ) + int_sql = _apply_unary_window_op(bf_df_int, int_op, window, "diff_int") + snapshot.assert_match(int_sql, "diff_int.sql") + + # Test boolean + bool_col = "bool_col" + bf_df_bool = scalar_types_df[[bool_col]] + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(bool_col),)) + bool_op = agg_exprs.UnaryAggregation( + agg_ops.DiffOp(periods=1), expression.deref(bool_col) + ) + bool_sql = _apply_unary_window_op(bf_df_bool, bool_op, window, "diff_bool") + snapshot.assert_match(bool_sql, "diff_bool.sql") + + def test_first(scalar_types_df: bpd.DataFrame, snapshot): if sys.version_info < (3, 12): pytest.skip( From 4f3c92006edf58834cb3d00eb1fb67c6a75c7ee2 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 9 Oct 2025 20:23:57 +0000 Subject: [PATCH 3/3] exclude_framing_clause --- .../core/compile/sqlglot/aggregations/unary_compiler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 8cb6d0092b..cfa27909c6 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -267,10 +267,14 @@ def _( return column.expr if op.periods > 0: return apply_window_if_present( - sge.func("LAG", column.expr, sge.convert(op.periods)), window + sge.func("LAG", column.expr, sge.convert(op.periods)), + window, + include_framing_clauses=False, ) return apply_window_if_present( - sge.func("LEAD", column.expr, sge.convert(-op.periods)), window + sge.func("LEAD", column.expr, sge.convert(-op.periods)), + window, + include_framing_clauses=False, )