Skip to content

Commit b8c9713

Browse files
committed
refactor: add agg_ops.QuantileOp to sqlglot compiler
1 parent afe4331 commit b8c9713

File tree

5 files changed

+57
-9
lines changed

5 files changed

+57
-9
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,7 @@ def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
5858
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
5959
else:
6060
key = typing.cast(str, op.name)
61-
return self._registered_ops[key]
61+
if key in self._registered_ops:
62+
return self._registered_ops[key]
63+
return self._registered_ops[type(op)]
6264
return self._registered_ops[op]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,23 @@ def _(
109109
return apply_window_if_present(sge.func("MIN", column.expr), window)
110110

111111

112-
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
112+
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
113113
def _(
114-
op: agg_ops.SizeUnaryOp,
115-
_,
114+
op: agg_ops.QuantileOp,
115+
column: typed_expr.TypedExpr,
116116
window: typing.Optional[window_spec.WindowSpec] = None,
117117
) -> sge.Expression:
118-
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
118+
# TODO: Support interpolation argument
119+
# TODO: Support percentile_disc
120+
result = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
121+
if window is None:
122+
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
123+
result = sge.Window(this=result)
124+
else:
125+
result = apply_window_if_present(result, window)
126+
if op.should_floor_result:
127+
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
128+
return result
119129

120130

121131
@UNARY_OP_REGISTRATION.register(agg_ops.RankOp)
@@ -130,6 +140,15 @@ def _(
130140
)
131141

132142

143+
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
144+
def _(
145+
op: agg_ops.SizeUnaryOp,
146+
_,
147+
window: typing.Optional[window_spec.WindowSpec] = None,
148+
) -> sge.Expression:
149+
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
150+
151+
133152
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
134153
def _(
135154
op: agg_ops.SumOp,

bigframes/operations/aggregations.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
223223

224224
@dataclasses.dataclass(frozen=True)
225225
class QuantileOp(UnaryAggregateOp):
226+
name: typing.ClassVar[str] = "quantile"
226227
q: float
227228
should_floor_result: bool = False
228229

229-
@property
230-
def name(self):
231-
return f"{int(self.q * 100)}%"
232-
233230
@property
234231
def order_independent(self) -> bool:
235232
return True
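(file path missing from this capture — apparently the `out.sql` snapshot asserted by `test_quantile`)

Lines changed: 14 additions & 0 deletions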
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
PERCENTILE_CONT(`bfcol_0`, 0.5) OVER () AS `bfcol_1`,
8+
CAST(FLOOR(PERCENTILE_CONT(`bfcol_0`, 0.5) OVER ()) AS INT64) AS `bfcol_2`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `quantile`,
13+
`bfcol_2` AS `quantile_floor`
14+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,22 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
141141
snapshot.assert_match(sql, "out.sql")
142142

143143

144+
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
145+
col_name = "int64_col"
146+
bf_df = scalar_types_df[[col_name]]
147+
agg_ops_map = {
148+
"quantile": agg_ops.QuantileOp(q=0.5).as_expr(col_name),
149+
"quantile_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr(
150+
col_name
151+
),
152+
}
153+
sql = _apply_unary_agg_ops(
154+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
155+
)
156+
157+
snapshot.assert_match(sql, "out.sql")
158+
159+
144160
def test_rank(scalar_types_df: bpd.DataFrame, snapshot):
145161
col_name = "int64_col"
146162
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)