Skip to content

Commit a3c2522

Browse files
authored
refactor: add agg_ops.QuantileOp, ApproxQuartilesOp and ApproxTopCountOp to sqlglot compiler (#2110)
1 parent 8fc098a commit a3c2522

File tree

7 files changed

+143
-18
lines changed

7 files changed

+143
-18
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,16 @@ def arg_checker(*args, **kwargs):
4141
)
4242
return item(*args, **kwargs)
4343

44-
if hasattr(op, "name"):
45-
key = typing.cast(str, op.name)
46-
if key in self._registered_ops:
47-
raise ValueError(f"{key} is already registered")
48-
else:
49-
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
44+
key = str(op)
45+
if key in self._registered_ops:
46+
raise ValueError(f"{key} is already registered")
5047
self._registered_ops[key] = item
5148
return arg_checker
5249

5350
return decorator
5451

5552
def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
56-
if isinstance(op, agg_ops.WindowOp):
57-
if not hasattr(op, "name"):
58-
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
59-
else:
60-
key = typing.cast(str, op.name)
61-
return self._registered_ops[key]
62-
return self._registered_ops[op]
53+
key = op if isinstance(op, type) else type(op)
54+
if str(key) not in self._registered_ops:
55+
raise ValueError(f"{key} is already not registered")
56+
return self._registered_ops[str(key)]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,37 @@ def compile(
3838
return UNARY_OP_REGISTRATION[op](op, column, window=window)
3939

4040

41+
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
42+
def _(
43+
op: agg_ops.ApproxQuartilesOp,
44+
column: typed_expr.TypedExpr,
45+
window: typing.Optional[window_spec.WindowSpec] = None,
46+
) -> sge.Expression:
47+
if window is not None:
48+
raise NotImplementedError("Approx Quartiles with windowing is not supported.")
49+
# APPROX_QUANTILES returns an array of the quartiles, so we need to index it.
50+
# op.quartile counts quartiles starting at 1 and array offsets start at 0, but offset 0 holds the minimum (Q0), so no adjustment is needed.
51+
# The quartiles are Q0, Q1, Q2, Q3, Q4. op.quartile is 1, 2, or 3.
52+
# The array has 5 elements (for N=4 intervals).
53+
# So we want the element at index `op.quartile`.
54+
approx_quantiles_expr = sge.func("APPROX_QUANTILES", column.expr, sge.convert(4))
55+
return sge.Bracket(
56+
this=approx_quantiles_expr,
57+
expressions=[sge.func("OFFSET", sge.convert(op.quartile))],
58+
)
59+
60+
61+
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxTopCountOp)
62+
def _(
63+
op: agg_ops.ApproxTopCountOp,
64+
column: typed_expr.TypedExpr,
65+
window: typing.Optional[window_spec.WindowSpec] = None,
66+
) -> sge.Expression:
67+
if window is not None:
68+
raise NotImplementedError("Approx top count with windowing is not supported.")
69+
return sge.func("APPROX_TOP_COUNT", column.expr, sge.convert(op.number))
70+
71+
4172
@UNARY_OP_REGISTRATION.register(agg_ops.CountOp)
4273
def _(
4374
op: agg_ops.CountOp,
@@ -109,13 +140,23 @@ def _(
109140
return apply_window_if_present(sge.func("MIN", column.expr), window)
110141

111142

112-
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
143+
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
113144
def _(
114-
op: agg_ops.SizeUnaryOp,
115-
_,
145+
op: agg_ops.QuantileOp,
146+
column: typed_expr.TypedExpr,
116147
window: typing.Optional[window_spec.WindowSpec] = None,
117148
) -> sge.Expression:
118-
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
149+
# TODO: Support interpolation argument
150+
# TODO: Support percentile_disc
151+
result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
152+
if window is None:
153+
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
154+
result = sge.Window(this=result)
155+
else:
156+
result = apply_window_if_present(result, window)
157+
if op.should_floor_result:
158+
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
159+
return result
119160

120161

121162
@UNARY_OP_REGISTRATION.register(agg_ops.RankOp)
@@ -130,6 +171,15 @@ def _(
130171
)
131172

132173

174+
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
175+
def _(
176+
op: agg_ops.SizeUnaryOp,
177+
_,
178+
window: typing.Optional[window_spec.WindowSpec] = None,
179+
) -> sge.Expression:
180+
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
181+
182+
133183
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
134184
def _(
135185
op: agg_ops.SumOp,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(1)] AS `bfcol_1`,
8+
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(2)] AS `bfcol_2`,
9+
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(3)] AS `bfcol_3`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_1` AS `q1`,
14+
`bfcol_2` AS `q2`,
15+
`bfcol_3` AS `q3`
16+
FROM `bfcte_1`
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
APPROX_TOP_COUNT(`bfcol_0`, 10) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
PERCENTILE_CONT(`bfcol_0`, 0.5) OVER () AS `bfcol_1`,
8+
CAST(FLOOR(PERCENTILE_CONT(`bfcol_0`, 0.5) OVER ()) AS INT64) AS `bfcol_2`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `quantile`,
13+
`bfcol_2` AS `quantile_floor`
14+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_op_registration.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def test_func(op: agg_ops.SizeOp, input: sge.Expression) -> sge.Expression:
2929
return input
3030

3131
assert reg[agg_ops.SizeOp()](op, input) == test_func(op, input)
32-
assert reg[agg_ops.SizeOp.name](op, input) == test_func(op, input)
3332

3433

3534
def test_register_function_first_argument_is_not_agg_op_raise_error():

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,30 @@ def _apply_unary_window_op(
6363
return sql
6464

6565

66+
def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):
67+
col_name = "int64_col"
68+
bf_df = scalar_types_df[[col_name]]
69+
agg_ops_map = {
70+
"q1": agg_ops.ApproxQuartilesOp(quartile=1).as_expr(col_name),
71+
"q2": agg_ops.ApproxQuartilesOp(quartile=2).as_expr(col_name),
72+
"q3": agg_ops.ApproxQuartilesOp(quartile=3).as_expr(col_name),
73+
}
74+
sql = _apply_unary_agg_ops(
75+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
76+
)
77+
78+
snapshot.assert_match(sql, "out.sql")
79+
80+
81+
def test_approx_top_count(scalar_types_df: bpd.DataFrame, snapshot):
82+
col_name = "int64_col"
83+
bf_df = scalar_types_df[[col_name]]
84+
agg_expr = agg_ops.ApproxTopCountOp(number=10).as_expr(col_name)
85+
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
86+
87+
snapshot.assert_match(sql, "out.sql")
88+
89+
6690
def test_count(scalar_types_df: bpd.DataFrame, snapshot):
6791
col_name = "int64_col"
6892
bf_df = scalar_types_df[[col_name]]
@@ -141,6 +165,22 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
141165
snapshot.assert_match(sql, "out.sql")
142166

143167

168+
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
169+
col_name = "int64_col"
170+
bf_df = scalar_types_df[[col_name]]
171+
agg_ops_map = {
172+
"quantile": agg_ops.QuantileOp(q=0.5).as_expr(col_name),
173+
"quantile_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr(
174+
col_name
175+
),
176+
}
177+
sql = _apply_unary_agg_ops(
178+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
179+
)
180+
181+
snapshot.assert_match(sql, "out.sql")
182+
183+
144184
def test_rank(scalar_types_df: bpd.DataFrame, snapshot):
145185
col_name = "int64_col"
146186
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)