diff --git a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py index b101f4e09f..0106b150e2 100644 --- a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py +++ b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py @@ -175,15 +175,11 @@ def _( @compile_unary_agg.register -@numeric_op def _( op: agg_ops.MedianOp, column: ibis_types.NumericColumn, window=None, ) -> ibis_types.NumericValue: - # TODO(swast): Allow switching between exact and approximate median. - # For now, the best we can do is an approximate median when we're doing - # an aggregation, as PERCENTILE_CONT is only an analytic function. return cast(ibis_types.NumericValue, column.approx_median()) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 542bb10670..4cb0000894 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -56,6 +56,18 @@ def _( return apply_window_if_present(sge.func("MAX", column.expr), window) +@UNARY_OP_REGISTRATION.register(agg_ops.MedianOp) +def _( + op: agg_ops.MedianOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + approx_quantiles = sge.func("APPROX_QUANTILES", column.expr, sge.convert(2)) + return sge.Bracket( + this=approx_quantiles, expressions=[sge.func("OFFSET", sge.convert(1))] + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.MinOp) def _( op: agg_ops.MinOp, diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index a4a49c622a..98d5cd4ac8 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud import bigquery import pytest from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes import bigframes.operations.aggregations as agg_ops -from bigframes.session import polars_executor +from bigframes.session import direct_gbq_execution, polars_executor from bigframes.testing.engine_utils import assert_equivalence_execution pytest.importorskip("polars") @@ -84,6 +85,21 @@ def test_engines_unary_aggregates( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) +def test_sql_engines_median_op_aggregates( + scalars_array_value: array_value.ArrayValue, + bigquery_client: bigquery.Client, +): + node = apply_agg_to_all_valid( + scalars_array_value, + agg_ops.MedianOp(), + ).node + left_engine = direct_gbq_execution.DirectGbqExecutor(bigquery_client) + right_engine = direct_gbq_execution.DirectGbqExecutor( + bigquery_client, compiler="sqlglot" + ) + assert_equivalence_execution(node, left_engine, right_engine) + + @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) @pytest.mark.parametrize( "grouping_cols", diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index 0a761a3a3a..d1a252f8dc 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -1919,10 +1919,22 @@ def test_mean(scalars_dfs): assert math.isclose(pd_result, bf_result) -def test_median(scalars_dfs): +@pytest.mark.parametrize( + ("col_name"), + [ + "int64_col", + # Non-numeric column + "bytes_col", + "date_col", + "datetime_col", + "time_col", + "timestamp_col", + "string_col", + ], +) +def test_median(scalars_dfs, col_name): scalars_df, scalars_pandas_df = scalars_dfs - col_name = "int64_col" - bf_result = scalars_df[col_name].median() + bf_result = scalars_df[col_name].median(exact=False) pd_max = scalars_pandas_df[col_name].max() pd_min = scalars_pandas_df[col_name].min() # Median is approximate, so just check for plausibility. @@ -1932,7 +1944,7 @@ def test_median(scalars_dfs): def test_median_exact(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_name = "int64_col" - bf_result = scalars_df[col_name].median(exact=True) + bf_result = scalars_df[col_name].median() pd_result = scalars_pandas_df[col_name].median() assert math.isclose(pd_result, bf_result) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql new file mode 100644 index 0000000000..bf7006ef87 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `date_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `string_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + APPROX_QUANTILES(`bfcol_1`, 2)[OFFSET(1)] AS `bfcol_3`, + APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_4`, + APPROX_QUANTILES(`bfcol_2`, 2)[OFFSET(1)] AS `bfcol_5` + FROM `bfcte_0` +) +SELECT + `bfcol_3` AS `int64_col`, + `bfcol_4` AS `date_col`, + `bfcol_5` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 311c039e11..4f0016a6e7 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -56,6 +56,18 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_median(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df + ops_map = { + "int64_col": agg_ops.MedianOp().as_expr("int64_col"), + "date_col": agg_ops.MedianOp().as_expr("date_col"), + "string_col": agg_ops.MedianOp().as_expr("string_col"), + } + sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + + snapshot.assert_match(sql, "out.sql") + + def test_min(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]]