From b8c9713e718fa166196ee771fc356a41d2a74160 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 23 Sep 2025 22:03:58 +0000 Subject: [PATCH 1/5] refactor: add agg_ops.QuantileOp to sqlglot compiler --- .../sqlglot/aggregations/op_registration.py | 4 ++- .../sqlglot/aggregations/unary_compiler.py | 27 ++++++++++++++++--- bigframes/operations/aggregations.py | 5 +--- .../test_unary_compiler/test_quantile/out.sql | 14 ++++++++++ .../aggregations/test_unary_compiler.py | 16 +++++++++++ 5 files changed, 57 insertions(+), 9 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/op_registration.py b/bigframes/core/compile/sqlglot/aggregations/op_registration.py index 996bf5b362..10d2242999 100644 --- a/bigframes/core/compile/sqlglot/aggregations/op_registration.py +++ b/bigframes/core/compile/sqlglot/aggregations/op_registration.py @@ -58,5 +58,7 @@ def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc: raise ValueError(f"The operator must have a 'name' attribute. Got {op}") else: key = typing.cast(str, op.name) - return self._registered_ops[key] + if key in self._registered_ops: + return self._registered_ops[key] + return self._registered_ops[type(op)] return self._registered_ops[op] diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 598a89e4eb..8e4bd0386f 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -109,13 +109,23 @@ def _( return apply_window_if_present(sge.func("MIN", column.expr), window) -@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp) +@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp) def _( - op: agg_ops.SizeUnaryOp, - _, + op: agg_ops.QuantileOp, + column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window) + # TODO: Support interpolation argument + # TODO: Support percentile_disc + result = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q)) + if window is None: + # PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause. + result = sge.Window(this=result) + else: + result = apply_window_if_present(result, window) + if op.should_floor_result: + result = sge.Cast(this=sge.func("FLOOR", result), to="INT64") + return result @UNARY_OP_REGISTRATION.register(agg_ops.RankOp) @@ -130,6 +140,15 @@ def _( ) +@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp) +def _( + op: agg_ops.SizeUnaryOp, + _, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window) + + @UNARY_OP_REGISTRATION.register(agg_ops.SumOp) def _( op: agg_ops.SumOp, diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index f6e8600d42..800d07324e 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -223,13 +223,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT @dataclasses.dataclass(frozen=True) class QuantileOp(UnaryAggregateOp): + name: typing.ClassVar[str] = "quantile" q: float should_floor_result: bool = False - @property - def name(self): - return f"{int(self.q * 100)}%" - @property def order_independent(self) -> bool: return True diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql new file mode 100644 index 0000000000..c1b3d1fffa --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql @@ -0,0 +1,14 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + PERCENTILE_CONT(`bfcol_0`, 0.5) OVER () AS `bfcol_1`, + CAST(FLOOR(PERCENTILE_CONT(`bfcol_0`, 0.5) OVER ()) AS INT64) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `quantile`, + `bfcol_2` AS `quantile_floor` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index bf2523930f..e23e55a1b6 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -141,6 +141,22 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_quantile(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_ops_map = { + "quantile": agg_ops.QuantileOp(q=0.5).as_expr(col_name), + "quantile_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr( + col_name + ), + } + sql = _apply_unary_agg_ops( + bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) + ) + + snapshot.assert_match(sql, "out.sql") + + def test_rank(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From 17f04401590f29a6e47b8fe7b1164d4d83b26313 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 23 Sep 2025 22:07:59 +0000 Subject: [PATCH 2/5] refactor: add agg_ops.ApproxQuartilesOp to sqlglot compiler --- .../sqlglot/aggregations/op_registration.py | 9 +++------ .../sqlglot/aggregations/unary_compiler.py | 20 +++++++++++++++++++ .../test_approx_quartiles/out.sql | 16 +++++++++++++++ .../aggregations/test_unary_compiler.py | 15 ++++++++++++++ 4 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_quartiles/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/op_registration.py b/bigframes/core/compile/sqlglot/aggregations/op_registration.py index 10d2242999..d2c5968e4a 100644 --- a/bigframes/core/compile/sqlglot/aggregations/op_registration.py +++ b/bigframes/core/compile/sqlglot/aggregations/op_registration.py @@ -41,12 +41,9 @@ def arg_checker(*args, **kwargs): ) return item(*args, **kwargs) - if hasattr(op, "name"): - key = typing.cast(str, op.name) - if key in self._registered_ops: - raise ValueError(f"{key} is already registered") - else: - raise ValueError(f"The operator must have a 'name' attribute. Got {op}") + key = op if isinstance(op, type) else type(op) + if key in self._registered_ops: + raise ValueError(f"{key} is already registered") self._registered_ops[key] = item return arg_checker diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 8e4bd0386f..7021e2833c 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -38,6 +38,26 @@ def compile( return UNARY_OP_REGISTRATION[op](op, column, window=window) +@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp) +def _( + op: agg_ops.ApproxQuartilesOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + if window is not None: + raise NotImplementedError("Approx Quartiles with windowing is not supported.") + # APPROX_QUANTILES returns an array of the quartiles, so we need to index it. + # The op.quartile is 1-based for the quartile, but array is 0-indexed. + # The quartiles are Q0, Q1, Q2, Q3, Q4. op.quartile is 1, 2, or 3. + # The array has 5 elements (for N=4 intervals). + # So we want the element at index `op.quartile`. + approx_quantiles_expr = sge.func("APPROX_QUANTILES", column.expr, sge.convert(4)) + return sge.Bracket( + this=approx_quantiles_expr, + expressions=[sge.func("OFFSET", sge.convert(op.quartile))], + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.CountOp) def _( op: agg_ops.CountOp, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_quartiles/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_quartiles/out.sql new file mode 100644 index 0000000000..e7bb16e57c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_quartiles/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(1)] AS `bfcol_1`, + APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(2)] AS `bfcol_2`, + APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(3)] AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `q1`, + `bfcol_2` AS `q2`, + `bfcol_3` AS `q3` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index e23e55a1b6..b83ad5173d 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -63,6 +63,21 @@ def _apply_unary_window_op( return sql +def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_ops_map = { + "q1": agg_ops.ApproxQuartilesOp(quartile=1).as_expr(col_name), + "q2": agg_ops.ApproxQuartilesOp(quartile=2).as_expr(col_name), + "q3": agg_ops.ApproxQuartilesOp(quartile=3).as_expr(col_name), + } + sql = _apply_unary_agg_ops( + bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) + ) + + snapshot.assert_match(sql, "out.sql") + + def test_count(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From 1cf1b983a13650b93f16917713697a29c8cf46f2 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 23 Sep 2025 22:10:46 +0000 Subject: [PATCH 3/5] refactor: add agg_ops.ApproxTopCountOp to sqlglot compiler --- .../compile/sqlglot/aggregations/unary_compiler.py | 11 +++++++++++ .../test_approx_top_count/out.sql | 12 ++++++++++++ .../sqlglot/aggregations/test_unary_compiler.py | 9 +++++++++ 3 files changed, 32 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_top_count/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 7021e2833c..f3edb674dc 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -58,6 +58,17 @@ def _( ) +@UNARY_OP_REGISTRATION.register(agg_ops.ApproxTopCountOp) +def _( + op: agg_ops.ApproxTopCountOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + if window is not None: + raise NotImplementedError("Approx top count with windowing is not supported.") + return sge.func("APPROX_TOP_COUNT", column.expr, sge.convert(op.number)) + + @UNARY_OP_REGISTRATION.register(agg_ops.CountOp) def _( op: agg_ops.CountOp, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_top_count/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_top_count/out.sql new file mode 100644 index 0000000000..b61a72d1b2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_approx_top_count/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + APPROX_TOP_COUNT(`bfcol_0`, 10) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index b83ad5173d..4abf80df19 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -78,6 +78,15 @@ def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_approx_top_count(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.ApproxTopCountOp(number=10).as_expr(col_name) + sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) + + snapshot.assert_match(sql, "out.sql") + + def test_count(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From fe072f33176d61767840badb364805a7ee35db9f Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 24 Sep 2025 20:15:08 +0000 Subject: [PATCH 4/5] fix mypy --- .../core/compile/sqlglot/aggregations/op_registration.py | 4 ++-- .../core/compile/sqlglot/aggregations/unary_compiler.py | 2 +- bigframes/operations/aggregations.py | 5 ++++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bigframes/core/compile/sqlglot/aggregations/op_registration.py b/bigframes/core/compile/sqlglot/aggregations/op_registration.py index d2c5968e4a..b3f257c665 100644 --- a/bigframes/core/compile/sqlglot/aggregations/op_registration.py +++ b/bigframes/core/compile/sqlglot/aggregations/op_registration.py @@ -44,7 +44,7 @@ def arg_checker(*args, **kwargs): key = op if isinstance(op, type) else type(op) if key in self._registered_ops: raise ValueError(f"{key} is already registered") - self._registered_ops[key] = item + self._registered_ops[str(key)] = item return arg_checker return decorator @@ -57,5 +57,5 @@ def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc: key = typing.cast(str, op.name) if key in self._registered_ops: return self._registered_ops[key] - return self._registered_ops[type(op)] + return self._registered_ops[str(type(op))] return self._registered_ops[op] diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index f3edb674dc..11d53cdd4c 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -148,7 +148,7 @@ def _( ) -> sge.Expression: # TODO: Support interpolation argument # TODO: Support percentile_disc - result = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q)) + result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q)) if window is None: # PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause. result = sge.Window(this=result) diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index 800d07324e..f6e8600d42 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -223,10 +223,13 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT @dataclasses.dataclass(frozen=True) class QuantileOp(UnaryAggregateOp): - name: typing.ClassVar[str] = "quantile" q: float should_floor_result: bool = False + @property + def name(self): + return f"{int(self.q * 100)}%" + @property def order_independent(self) -> bool: return True From 9863f29a89adfa45b803193c3ac150338374ddd3 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 24 Sep 2025 20:30:14 +0000 Subject: [PATCH 5/5] fix unit tests --- .../sqlglot/aggregations/op_registration.py | 17 ++++++----------- .../aggregations/test_op_registration.py | 1 - 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/bigframes/core/compile/sqlglot/aggregations/op_registration.py b/bigframes/core/compile/sqlglot/aggregations/op_registration.py index b3f257c665..eb02b8bd50 100644 --- a/bigframes/core/compile/sqlglot/aggregations/op_registration.py +++ b/bigframes/core/compile/sqlglot/aggregations/op_registration.py @@ -41,21 +41,16 @@ def arg_checker(*args, **kwargs): ) return item(*args, **kwargs) - key = op if isinstance(op, type) else type(op) + key = str(op) if key in self._registered_ops: raise ValueError(f"{key} is already registered") - self._registered_ops[str(key)] = item + self._registered_ops[key] = item return arg_checker return decorator def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc: - if isinstance(op, agg_ops.WindowOp): - if not hasattr(op, "name"): - raise ValueError(f"The operator must have a 'name' attribute. Got {op}") - else: - key = typing.cast(str, op.name) - if key in self._registered_ops: - return self._registered_ops[key] - return self._registered_ops[str(type(op))] - return self._registered_ops[op] + key = op if isinstance(op, type) else type(op) + if str(key) not in self._registered_ops: + raise ValueError(f"{key} is already not registered") + return self._registered_ops[str(key)] diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py index e3688f19df..dbdeb2307e 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py @@ -29,7 +29,6 @@ def test_func(op: agg_ops.SizeOp, input: sge.Expression) -> sge.Expression: return input assert reg[agg_ops.SizeOp()](op, input) == test_func(op, input) - assert reg[agg_ops.SizeOp.name](op, input) == test_func(op, input) def test_register_function_first_argument_is_not_agg_op_raise_error():