From 920f673b86c85803a161f0e36a36955c295e6f32 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 13 Nov 2025 23:58:19 +0000 Subject: [PATCH 1/2] refactor: add agg_ops.QcutOp to the sqlglot compiler --- .../sqlglot/aggregations/unary_compiler.py | 37 ++++++++++++++ .../compile/sqlglot/aggregations/windows.py | 7 ++- bigframes/core/compile/sqlglot/sqlglot_ir.py | 7 ++- bigframes/core/reshape/tile.py | 6 --- .../test_unary_compiler/test_qcut/out.sql | 51 +++++++++++++++++++ .../aggregations/test_unary_compiler.py | 21 ++++++++ 6 files changed, 121 insertions(+), 8 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index e2bd6b8382..14da8dd555 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -400,6 +400,43 @@ def _( return apply_window_if_present(expr, window) +@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp) +def _( + op: agg_ops.QcutOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + percent_ranks_order_by = sge.Ordered(this=column.expr, desc=False) + percent_ranks = apply_window_if_present( + sge.func("PERCENT_RANK"), + window, + include_framing_clauses=False, + order_by_override=[percent_ranks_order_by], + ) + if isinstance(op.quantiles, int): + scaled_rank = percent_ranks * sge.convert(op.quantiles) + # Calculate the 0-based bucket index. + bucket_index = sge.func("CEIL", scaled_rank) - sge.convert(1) + safe_bucket_index = sge.func("GREATEST", bucket_index, 0) + + return sge.If( + this=sge.Is(this=column.expr, expression=sge.Null()), + true=sge.Null(), + false=sge.Cast(this=safe_bucket_index, to="INT64"), + ) + else: + case = sge.Case() + first_quantile = sge.convert(op.quantiles[0]) + case = case.when( + sge.LT(this=percent_ranks, expression=first_quantile), sge.Null() + ) + for bucket_n in range(len(op.quantiles) - 1): + quantile = sge.convert(op.quantiles[bucket_n + 1]) + bucket = sge.convert(bucket_n) + case = case.when(sge.LTE(this=percent_ranks, expression=quantile), bucket) + return case.else_(sge.Null()) + + @UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp) def _( op: agg_ops.QuantileOp, diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index 099f5832da..b775d6666a 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -26,6 +26,7 @@ def apply_window_if_present( value: sge.Expression, window: typing.Optional[window_spec.WindowSpec] = None, include_framing_clauses: bool = True, + order_by_override: typing.Optional[typing.List[sge.Ordered]] = None, ) -> sge.Expression: if window is None: return value @@ -44,7 +45,11 @@ def apply_window_if_present( else: order_by = get_window_order_by(window.ordering) - order = sge.Order(expressions=order_by) if order_by else None + order = None + if order_by_override is not None and len(order_by_override) > 0: + order = sge.Order(expressions=order_by_override) + elif order_by: + order = sge.Order(expressions=order_by) group_by = ( [ diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 3473968450..fd3bdd532f 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -637,7 +637,12 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: - sqlglot_type = sgt.from_bigframes_dtype(dtype) + sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None + if sqlglot_type is None: + if value is not None: + raise ValueError("Cannot infer SQLGlot type from None dtype.") + return sge.Null() + if value is None: return _cast(sge.Null(), sqlglot_type) elif dtype == dtypes.BYTES_DTYPE: diff --git a/bigframes/core/reshape/tile.py b/bigframes/core/reshape/tile.py index a2efa8f927..961870616c 100644 --- a/bigframes/core/reshape/tile.py +++ b/bigframes/core/reshape/tile.py @@ -22,8 +22,6 @@ import pandas as pd import bigframes -import bigframes.constants -import bigframes.core.expression as ex import bigframes.core.ordering as order import bigframes.core.utils as utils import bigframes.core.window_spec as window_specs @@ -165,7 +163,6 @@ def qcut( f"Only duplicates='drop' is supported in BigQuery DataFrames so far. {constants.FEEDBACK_LINK}" ) block = x._block - label = block.col_id_to_label[x._value_column] block, nullity_id = block.apply_unary_op(x._value_column, ops.notnull_op) block, result = block.apply_window_op( x._value_column, @@ -175,9 +172,6 @@ def qcut( ordering=(order.ascending_over(x._value_column),), ), ) - block, result = block.project_expr( - ops.where_op.as_expr(result, nullity_id, ex.const(None)), label=label - ) return bigframes.series.Series(block.select_column(result)) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql new file mode 100644 index 0000000000..79f847f92f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql @@ -0,0 +1,51 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col`, + `rowindex` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + NOT `int64_col` IS NULL AS `bfcol_4` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + IF( + `int64_col` IS NULL, + NULL, + CAST(GREATEST( + CEIL(PERCENT_RANK() OVER (PARTITION BY `bfcol_4` ORDER BY `int64_col` ASC) * 4) - 1, + 0 + ) AS INT64) + ) AS `bfcol_5` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + NOT `int64_col` IS NULL AS `bfcol_9` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + CASE + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) < 0 + THEN NULL + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.25 + THEN 0 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.5 + THEN 1 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.75 + THEN 2 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 1 + THEN 3 + ELSE NULL + END AS `bfcol_10` + FROM `bfcte_3` +) +SELECT + `rowindex`, + `int64_col`, + `bfcol_5` AS `qcut_w_int`, + `bfcol_10` AS `qcut_w_list` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index ab9f7febbf..184cb3925f 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -435,6 +435,27 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql_window, "window_out.sql") +def test_qcut(scalar_types_df: bpd.DataFrame, snapshot): + if sys.version_info < (3, 12): + pytest.skip( + "Skipping test due to inconsistent SQL formatting on Python < 3.12.", + ) + + col_name = "int64_col" + bf = scalar_types_df[[col_name]] + bf["qcut_w_int"] = bpd.qcut(bf[col_name], q=4, labels=False, duplicates="drop") + + q_list = tuple([0, 0.25, 0.5, 0.75, 1]) + bf["qcut_w_list"] = bpd.qcut( + scalar_types_df[col_name], + q=q_list, + labels=False, + duplicates="drop", + ) + + snapshot.assert_match(bf.sql, "out.sql") + + def test_quantile(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From 8588e725eabe966d0bfce49ac06bb143e0616abb Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 21 Nov 2025 00:50:36 +0000 Subject: [PATCH 2/2] undo changes in tile.py --- bigframes/core/reshape/tile.py | 6 ++++ .../test_unary_compiler/test_qcut/out.sql | 32 ++++++++++++------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/bigframes/core/reshape/tile.py b/bigframes/core/reshape/tile.py index 961870616c..a2efa8f927 100644 --- a/bigframes/core/reshape/tile.py +++ b/bigframes/core/reshape/tile.py @@ -22,6 +22,8 @@ import pandas as pd import bigframes +import bigframes.constants +import bigframes.core.expression as ex import bigframes.core.ordering as order import bigframes.core.utils as utils import bigframes.core.window_spec as window_specs @@ -163,6 +165,7 @@ def qcut( f"Only duplicates='drop' is supported in BigQuery DataFrames so far. {constants.FEEDBACK_LINK}" ) block = x._block + label = block.col_id_to_label[x._value_column] block, nullity_id = block.apply_unary_op(x._value_column, ops.notnull_op) block, result = block.apply_window_op( x._value_column, @@ -172,6 +175,9 @@ def qcut( ordering=(order.ascending_over(x._value_column),), ), ) + block, result = block.project_expr( + ops.where_op.as_expr(result, nullity_id, ex.const(None)), label=label + ) return bigframes.series.Series(block.select_column(result)) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql index 79f847f92f..1aa2e436ca 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql @@ -23,29 +23,39 @@ WITH `bfcte_0` AS ( ), `bfcte_3` AS ( SELECT *, - NOT `int64_col` IS NULL AS `bfcol_9` + IF(`bfcol_4`, `bfcol_5`, NULL) AS `bfcol_6` FROM `bfcte_2` ), `bfcte_4` AS ( + SELECT + *, + NOT `int64_col` IS NULL AS `bfcol_10` + FROM `bfcte_3` +), `bfcte_5` AS ( SELECT *, CASE - WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) < 0 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) < 0 THEN NULL - WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.25 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.25 THEN 0 - WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.5 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.5 THEN 1 - WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.75 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.75 THEN 2 - WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 1 + WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 1 THEN 3 ELSE NULL - END AS `bfcol_10` - FROM `bfcte_3` + END AS `bfcol_11` + FROM `bfcte_4` +), `bfcte_6` AS ( + SELECT + *, + IF(`bfcol_10`, `bfcol_11`, NULL) AS `bfcol_12` + FROM `bfcte_5` ) SELECT `rowindex`, `int64_col`, - `bfcol_5` AS `qcut_w_int`, - `bfcol_10` AS `qcut_w_list` -FROM `bfcte_4` \ No newline at end of file + `bfcol_6` AS `qcut_w_int`, + `bfcol_12` AS `qcut_w_list` +FROM `bfcte_6` \ No newline at end of file