Skip to content

Commit 7990eb3

Browse files
committed
refactor: add agg_ops.ApproxQuartilesOp to sqlglot compiler
1 parent bf2e4f8 commit 7990eb3

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,9 @@ def arg_checker(*args, **kwargs):
4141
)
4242
return item(*args, **kwargs)
4343

44-
if hasattr(op, "name"):
45-
key = typing.cast(str, op.name)
46-
if key in self._registered_ops:
47-
raise ValueError(f"{key} is already registered")
48-
else:
49-
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
44+
key = op if isinstance(op, type) else type(op)
45+
if key in self._registered_ops:
46+
raise ValueError(f"{key} is already registered")
5047
self._registered_ops[key] = item
5148
return arg_checker
5249

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,26 @@ def compile(
3838
return UNARY_OP_REGISTRATION[op](op, column, window=window)
3939

4040

41+
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
def _(
    op: agg_ops.ApproxQuartilesOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile an approximate-quartile aggregation into a sqlglot expression.

    The generated SQL has the shape
    ``APPROX_QUANTILES(<column>, 4)[OFFSET(<op.quartile>)]``.

    Args:
        op: The quartile operator; ``op.quartile`` selects which boundary
            of the quantile array is returned.
        column: The typed input column whose ``expr`` is aggregated.
        window: Must be ``None``; windowed evaluation is not implemented.

    Returns:
        A bracket (array-subscript) expression over the APPROX_QUANTILES call.

    Raises:
        NotImplementedError: If a window specification is supplied.
    """
    if window is not None:
        raise NotImplementedError("Approx Quartiles with windowing is not supported.")

    # Asking for 4 intervals yields a 5-element array (Q0..Q4). The array is
    # 0-indexed via OFFSET, so the quartile number (1, 2 or 3) is already the
    # correct position and needs no adjustment.
    quantiles_array = sge.func("APPROX_QUANTILES", column.expr, sge.convert(4))
    position = sge.func("OFFSET", sge.convert(op.quartile))
    return sge.Bracket(this=quantiles_array, expressions=[position])
59+
60+
4161
@UNARY_OP_REGISTRATION.register(agg_ops.CountOp)
4262
def _(
4363
op: agg_ops.CountOp,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(1)] AS `bfcol_1`,
8+
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(2)] AS `bfcol_2`,
9+
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(3)] AS `bfcol_3`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_1` AS `q1`,
14+
`bfcol_2` AS `q2`,
15+
`bfcol_3` AS `q3`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,21 @@ def _apply_unary_agg_ops(
3838
return sql
3939

4040

41+
def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot-test the SQL compiled for the Q1, Q2 and Q3 approx-quartile ops."""
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]

    # Output column names and their aggregation expressions, kept in matching order.
    output_names = ["q1", "q2", "q3"]
    quartile_exprs = [
        agg_ops.ApproxQuartilesOp(quartile=number).as_expr(col_name)
        for number in (1, 2, 3)
    ]

    sql = _apply_unary_agg_ops(bf_df, quartile_exprs, output_names)

    snapshot.assert_match(sql, "out.sql")
54+
55+
4156
def test_count(scalar_types_df: bpd.DataFrame, snapshot):
4257
col_name = "int64_col"
4358
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)