Skip to content

Commit bf2e4f8

Browse files
committed
refactor: add agg_ops.QuantileOp to sqlglot compiler
1 parent caa824a commit bf2e4f8

File tree

5 files changed

+53
-5
lines changed

5 files changed

+53
-5
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,5 +58,7 @@ def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
5858
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
5959
else:
6060
key = typing.cast(str, op.name)
61-
return self._registered_ops[key]
61+
if key in self._registered_ops:
62+
return self._registered_ops[key]
63+
return self._registered_ops[type(op)]
6264
return self._registered_ops[op]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,25 @@ def _(
9797
return apply_window_if_present(sge.func("MIN", column.expr), window)
9898

9999

100+
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
101+
def _(
102+
op: agg_ops.QuantileOp,
103+
column: typed_expr.TypedExpr,
104+
window: typing.Optional[window_spec.WindowSpec] = None,
105+
) -> sge.Expression:
106+
# TODO: Support interpolation argument
107+
# TODO: Support percentile_disc
108+
result = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
109+
if window is None:
110+
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
111+
result = sge.Window(this=result)
112+
else:
113+
result = apply_window_if_present(result, window)
114+
if op.should_floor_result:
115+
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
116+
return result
117+
118+
100119
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
101120
def _(
102121
op: agg_ops.SizeUnaryOp,

bigframes/operations/aggregations.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -223,13 +223,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
223223

224224
@dataclasses.dataclass(frozen=True)
225225
class QuantileOp(UnaryAggregateOp):
226+
name: typing.ClassVar[str] = "quantile"
226227
q: float
227228
should_floor_result: bool = False
228229

229-
@property
230-
def name(self):
231-
return f"{int(self.q * 100)}%"
232-
233230
@property
234231
def order_independent(self) -> bool:
235232
return True
Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
PERCENTILE_CONT(`bfcol_0`, 0.5) OVER () AS `bfcol_1`,
8+
CAST(FLOOR(PERCENTILE_CONT(`bfcol_0`, 0.5) OVER ()) AS INT64) AS `bfcol_2`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `quantile`,
13+
`bfcol_2` AS `quantile_floor`
14+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,22 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
104104
snapshot.assert_match(sql, "out.sql")
105105

106106

107+
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
108+
col_name = "int64_col"
109+
bf_df = scalar_types_df[[col_name]]
110+
agg_ops_map = {
111+
"quantile": agg_ops.QuantileOp(q=0.5).as_expr(col_name),
112+
"quantile_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr(
113+
col_name
114+
),
115+
}
116+
sql = _apply_unary_agg_ops(
117+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
118+
)
119+
120+
snapshot.assert_match(sql, "out.sql")
121+
122+
107123
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
108124
bf_df = scalar_types_df[["int64_col", "bool_col"]]
109125
agg_ops_map = {

0 commit comments

Comments
 (0)