From c3113becd89c039835ad48f5d43f85e223af12bb Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 24 Nov 2025 22:55:05 +0000 Subject: [PATCH 1/2] refactor: add agg_ops.ProductOp to the sqlglot compiler --- .../compile/sqlglot/aggregations/unary_compiler.py | 11 +++++++++++ .../test_unary_compiler/test_nunique/out.sql | 12 ++++++++++++ .../sqlglot/aggregations/test_unary_compiler.py | 9 +++++++++ 3 files changed, 32 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_nunique/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 14da8dd555..457149c04c 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -386,6 +386,17 @@ def _( return apply_window_if_present(sge.func("MIN", column.expr), window) +@UNARY_OP_REGISTRATION.register(agg_ops.NuniqueOp) +def _( + op: agg_ops.NuniqueOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present( + sge.func("COUNT", sge.Distinct(expressions=[column.expr])), window + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.PopVarOp) def _( op: agg_ops.PopVarOp, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_nunique/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_nunique/out.sql new file mode 100644 index 0000000000..f0b54934b4 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_nunique/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + COUNT(DISTINCT `int64_col`) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 428c76cbb4..a293f0e3f8 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -412,6 +412,15 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql_window_partition, "window_partition_out.sql") +def test_nunique(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.NuniqueOp().as_expr(col_name) + sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) + + snapshot.assert_match(sql, "out.sql") + + def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot): col_names = ["int64_col", "bool_col"] bf_df = scalar_types_df[col_names] From 7c1a859d66622f35c8211e8c771f3585bd223beb Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 24 Nov 2025 23:05:29 +0000 Subject: [PATCH 2/2] refactor: add agg_ops.ProductOp to the sqlglot compiler --- .../sqlglot/aggregations/unary_compiler.py | 52 +++++++++++++++++++ .../test_unary_compiler/test_product/out.sql | 16 ++++++ .../test_product/window_partition_out.sql | 27 ++++++++++ .../aggregations/test_unary_compiler.py | 19 +++++++ 4 files changed, 114 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 457149c04c..171c3cc239 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -411,6 +411,58 @@ def _( return apply_window_if_present(expr, window) +@UNARY_OP_REGISTRATION.register(agg_ops.ProductOp) +def _( + op: agg_ops.ProductOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # Need to short-circuit as log with zeroes is illegal sql + is_zero = sge.EQ(this=column.expr, expression=sge.convert(0)) + + # There is no product sql aggregate function, so must implement as a sum of logs, and then + # apply power after. Note, log and power base must be equal! This impl uses natural log. + logs = ( + sge.Case() + .when(is_zero, sge.convert(0)) + .else_(sge.func("LN", sge.func("ABS", column.expr))) + ) + logs_sum = apply_window_if_present(sge.func("SUM", logs), window) + magnitude = sge.func("EXP", logs_sum) + + # Can't determine sign from logs, so have to determine parity of count of negative inputs + is_negative = ( + sge.Case() + .when( + sge.LT(this=sge.func("SIGN", column.expr), expression=sge.convert(0)), + sge.convert(1), + ) + .else_(sge.convert(0)) + ) + negative_count = apply_window_if_present(sge.func("SUM", is_negative), window) + negative_count_parity = sge.Mod( + this=negative_count, expression=sge.convert(2) + ) # 1 if result should be negative, otherwise 0 + + any_zeroes = apply_window_if_present(sge.func("LOGICAL_OR", is_zero), window) + + float_result = ( + sge.Case() + .when(any_zeroes, sge.convert(0)) + .else_( + sge.Mul( + this=magnitude, + expression=sge.If( + this=sge.EQ(this=negative_count_parity, expression=sge.convert(1)), + true=sge.convert(-1), + false=sge.convert(1), + ), + ) + ) + ) + return float_result + + @UNARY_OP_REGISTRATION.register(agg_ops.QcutOp) def _( op: agg_ops.QcutOp, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql new file mode 100644 index 0000000000..bec1527137 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + CASE + WHEN LOGICAL_OR(`int64_col` = 0) + THEN 0 + ELSE EXP(SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END)) * IF(MOD(SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END), 2) = 1, -1, 1) + END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql new file mode 100644 index 0000000000..9c1650222a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql @@ -0,0 +1,27 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col`, + `string_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN LOGICAL_OR(`int64_col` = 0) OVER (PARTITION BY `string_col`) + THEN 0 + ELSE EXP( + SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END) OVER (PARTITION BY `string_col`) + ) * IF( + MOD( + SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`), + 2 + ) = 1, + -1, + 1 + ) + END AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index a293f0e3f8..5f7d0d7653 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -443,6 +443,25 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql_window, "window_out.sql") +def test_product(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.ProductOp().as_expr(col_name) + sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) + + snapshot.assert_match(sql, "out.sql") + + bf_df_str = scalar_types_df[[col_name, "string_col"]] + window_partition = window_spec.WindowSpec( + grouping_keys=(expression.deref("string_col"),), + ) + sql_window_partition = _apply_unary_window_op( + bf_df_str, agg_expr, window_partition, "agg_int64" + ) + + snapshot.assert_match(sql_window_partition, "window_partition_out.sql") + + def test_qcut(scalar_types_df: bpd.DataFrame, snapshot): if sys.version_info < (3, 12): pytest.skip(