Skip to content

Commit a553b47

Browse files
committed
refactor: fix test_dataframe_aggregates_median on agg_ops.QuantileOp
1 parent c6c3330 commit a553b47

File tree

3 files changed

+20
-13
lines changed

3 files changed

+20
-13
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,14 +515,18 @@ def _(
515515
column: typed_expr.TypedExpr,
516516
window: typing.Optional[window_spec.WindowSpec] = None,
517517
) -> sge.Expression:
518-
# TODO: Support interpolation argument
519-
# TODO: Support percentile_disc
520-
result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
518+
expr = column.expr
519+
if column.dtype == dtypes.BOOL_DTYPE:
520+
expr = sge.Cast(this=expr, to="INT64")
521+
522+
result: sge.Expression = sge.func("PERCENTILE_CONT", expr, sge.convert(op.q))
521523
if window is None:
522-
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
524+
# PERCENTILE_CONT is a navigation function, not an aggregate function,
525+
# so it always needs an OVER clause.
523526
result = sge.Window(this=result)
524527
else:
525528
result = apply_window_if_present(result, window)
529+
526530
if op.should_floor_result:
527531
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
528532
return result
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
WITH `bfcte_0` AS (
22
SELECT
3+
`bool_col`,
34
`int64_col`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
PERCENTILE_CONT(`int64_col`, 0.5) OVER () AS `bfcol_1`,
8-
CAST(FLOOR(PERCENTILE_CONT(`int64_col`, 0.5) OVER ()) AS INT64) AS `bfcol_2`
8+
PERCENTILE_CONT(`int64_col`, 0.5) OVER () AS `bfcol_4`,
9+
PERCENTILE_CONT(CAST(`bool_col` AS INT64), 0.5) OVER () AS `bfcol_5`,
10+
CAST(FLOOR(PERCENTILE_CONT(`int64_col`, 0.5) OVER ()) AS INT64) AS `bfcol_6`
911
FROM `bfcte_0`
1012
)
1113
SELECT
12-
`bfcol_1` AS `quantile`,
13-
`bfcol_2` AS `quantile_floor`
14+
`bfcol_4` AS `int64`,
15+
`bfcol_5` AS `bool`,
16+
`bfcol_6` AS `int64_w_floor`
1417
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -491,12 +491,12 @@ def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
491491

492492

493493
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
494-
col_name = "int64_col"
495-
bf_df = scalar_types_df[[col_name]]
494+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
496495
agg_ops_map = {
497-
"quantile": agg_ops.QuantileOp(q=0.5).as_expr(col_name),
498-
"quantile_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr(
499-
col_name
496+
"int64": agg_ops.QuantileOp(q=0.5).as_expr("int64_col"),
497+
"bool": agg_ops.QuantileOp(q=0.5).as_expr("bool_col"),
498+
"int64_w_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr(
499+
"int64_col"
500500
),
501501
}
502502
sql = _apply_unary_agg_ops(

0 commit comments

Comments
 (0)