Skip to content

Commit ca1e44c

Browse files
authored
refactor: add agg_ops.MedianOp compiler to sqlglot (#2108)
* refactor: add agg_ops.MedianOp compiler to sqlglot
* enable engine tests
* enable non-numeric for ibis compiler too
1 parent af6b862 commit ca1e44c

File tree

6 files changed

+75
-9
lines changed

6 files changed

+75
-9
lines changed

bigframes/core/compile/ibis_compiler/aggregate_compiler.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,11 @@ def _(
175175

176176

177177
@compile_unary_agg.register
178-
@numeric_op
179178
def _(
180179
op: agg_ops.MedianOp,
181180
column: ibis_types.NumericColumn,
182181
window=None,
183182
) -> ibis_types.NumericValue:
184-
# TODO(swast): Allow switching between exact and approximate median.
185-
# For now, the best we can do is an approximate median when we're doing
186-
# an aggregation, as PERCENTILE_CONT is only an analytic function.
187183
return cast(ibis_types.NumericValue, column.approx_median())
188184

189185

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def _(
5656
return apply_window_if_present(sge.func("MAX", column.expr), window)
5757

5858

59+
@UNARY_OP_REGISTRATION.register(agg_ops.MedianOp)
def _(
    op: agg_ops.MedianOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``MedianOp`` into ``APPROX_QUANTILES(<column>, 2)[OFFSET(1)]``.

    Args:
        op: The median aggregation being compiled; not inspected here.
        column: Typed expression whose ``expr`` is the aggregated input.
        window: Accepted for signature parity with the other registered
            aggregations; this implementation does not apply it.

    Returns:
        A bracket (subscript) expression that selects the element at
        offset 1 of the ``APPROX_QUANTILES`` result.
    """
    # With 2 requested quantiles, the element at offset 1 is the middle
    # boundary of the returned array, which is used as the median.
    bucket_count = sge.convert(2)
    middle_offset = sge.func("OFFSET", sge.convert(1))
    quantile_array = sge.func("APPROX_QUANTILES", column.expr, bucket_count)
    return sge.Bracket(this=quantile_array, expressions=[middle_offset])
69+
70+
5971
@UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
6072
def _(
6173
op: agg_ops.MinOp,

tests/system/small/engines/test_aggregation.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google.cloud import bigquery
1516
import pytest
1617

1718
from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
1819
import bigframes.operations.aggregations as agg_ops
19-
from bigframes.session import polars_executor
20+
from bigframes.session import direct_gbq_execution, polars_executor
2021
from bigframes.testing.engine_utils import assert_equivalence_execution
2122

2223
pytest.importorskip("polars")
@@ -84,6 +85,21 @@ def test_engines_unary_aggregates(
8485
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
8586

8687

88+
def test_sql_engines_median_op_aggregates(
    scalars_array_value: array_value.ArrayValue,
    bigquery_client: bigquery.Client,
):
    """Check MedianOp gives equivalent results from both direct-GBQ compilers.

    The same aggregation node is executed by a ``DirectGbqExecutor`` built
    with its default compiler and by one built with ``compiler="sqlglot"``.
    """
    median_node = apply_agg_to_all_valid(
        scalars_array_value, agg_ops.MedianOp()
    ).node

    default_executor = direct_gbq_execution.DirectGbqExecutor(bigquery_client)
    sqlglot_executor = direct_gbq_execution.DirectGbqExecutor(
        bigquery_client,
        compiler="sqlglot",
    )

    assert_equivalence_execution(median_node, default_executor, sqlglot_executor)
101+
102+
87103
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
88104
@pytest.mark.parametrize(
89105
"grouping_cols",

tests/system/small/test_series.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,10 +1919,22 @@ def test_mean(scalars_dfs):
19191919
assert math.isclose(pd_result, bf_result)
19201920

19211921

1922-
def test_median(scalars_dfs):
1922+
@pytest.mark.parametrize(
1923+
("col_name"),
1924+
[
1925+
"int64_col",
1926+
# Non-numeric column
1927+
"bytes_col",
1928+
"date_col",
1929+
"datetime_col",
1930+
"time_col",
1931+
"timestamp_col",
1932+
"string_col",
1933+
],
1934+
)
1935+
def test_median(scalars_dfs, col_name):
19231936
scalars_df, scalars_pandas_df = scalars_dfs
1924-
col_name = "int64_col"
1925-
bf_result = scalars_df[col_name].median()
1937+
bf_result = scalars_df[col_name].median(exact=False)
19261938
pd_max = scalars_pandas_df[col_name].max()
19271939
pd_min = scalars_pandas_df[col_name].min()
19281940
# Median is approximate, so just check for plausibility.
@@ -1932,7 +1944,7 @@ def test_median(scalars_dfs):
19321944
def test_median_exact(scalars_dfs):
19331945
scalars_df, scalars_pandas_df = scalars_dfs
19341946
col_name = "int64_col"
1935-
bf_result = scalars_df[col_name].median(exact=True)
1947+
bf_result = scalars_df[col_name].median()
19361948
pd_result = scalars_pandas_df[col_name].median()
19371949
assert math.isclose(pd_result, bf_result)
19381950

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`date_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`string_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
APPROX_QUANTILES(`bfcol_1`, 2)[OFFSET(1)] AS `bfcol_3`,
10+
APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_4`,
11+
APPROX_QUANTILES(`bfcol_2`, 2)[OFFSET(1)] AS `bfcol_5`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_3` AS `int64_col`,
16+
`bfcol_4` AS `date_col`,
17+
`bfcol_5` AS `string_col`
18+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot):
5656
snapshot.assert_match(sql, "out.sql")
5757

5858

59+
def test_median(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for MedianOp over three differently typed columns."""
    # One integer, one date and one string column, in the order the
    # snapshot expects them.
    column_names = ["int64_col", "date_col", "string_col"]
    median_exprs = [agg_ops.MedianOp().as_expr(name) for name in column_names]

    sql = _apply_unary_agg_ops(scalar_types_df, median_exprs, column_names)

    snapshot.assert_match(sql, "out.sql")
69+
70+
5971
def test_min(scalar_types_df: bpd.DataFrame, snapshot):
6072
col_name = "int64_col"
6173
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)