bigframes/core/compile/ibis_compiler/aggregate_compiler.py (0 additions, 4 deletions)
@@ -175,15 +175,11 @@ def _(


@compile_unary_agg.register
-@numeric_op
def _(
    op: agg_ops.MedianOp,
    column: ibis_types.NumericColumn,
    window=None,
) -> ibis_types.NumericValue:
-    # TODO(swast): Allow switching between exact and approximate median.
-    # For now, the best we can do is an approximate median when we're doing
-    # an aggregation, as PERCENTILE_CONT is only an analytic function.
    return cast(ibis_types.NumericValue, column.approx_median())


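The TODO removed above records why the ibis compiler settles for an approximate median: in BigQuery, PERCENTILE_CONT is an analytic (window) function only, while APPROX_QUANTILES is a true aggregate. A hedged sketch of the two SQL shapes, assuming only the sqlglot package (the table `t` and column `x` are hypothetical, not taken from this PR):

import sqlglot

# Approximate median works as a plain aggregate function.
approx = "SELECT APPROX_QUANTILES(x, 2)[OFFSET(1)] AS m FROM t"
# An exact median needs an OVER clause (plus DISTINCT to collapse the
# per-row results), since PERCENTILE_CONT cannot be a GROUP BY aggregate.
exact = "SELECT DISTINCT PERCENTILE_CONT(x, 0.5) OVER () AS m FROM t"
for sql in (approx, exact):
    # Round-trip through sqlglot's BigQuery dialect to confirm both parse.
    print(sqlglot.parse_one(sql, read="bigquery").sql(dialect="bigquery"))
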
bigframes/core/compile/sqlglot/aggregations/unary_compiler.py (12 additions, 0 deletions)
@@ -56,6 +56,18 @@ def _(
    return apply_window_if_present(sge.func("MAX", column.expr), window)


+@UNARY_OP_REGISTRATION.register(agg_ops.MedianOp)
+def _(
+    op: agg_ops.MedianOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    approx_quantiles = sge.func("APPROX_QUANTILES", column.expr, sge.convert(2))
+    return sge.Bracket(
+        this=approx_quantiles, expressions=[sge.func("OFFSET", sge.convert(1))]
+    )


@UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
def _(
    op: agg_ops.MinOp,
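
For reference, a minimal standalone sketch of what the new MedianOp rule builds, using sqlglot directly (the column name is hypothetical): APPROX_QUANTILES(x, 2) asks BigQuery for 2 quantiles, which come back as the 3 boundaries [approx min, approx median, approx max], and OFFSET(1) indexes out the middle one.

import sqlglot.expressions as sge

# APPROX_QUANTILES(int64_col, 2) -> array of 3 approximate quantile boundaries.
approx_quantiles = sge.func(
    "APPROX_QUANTILES", sge.column("int64_col"), sge.convert(2)
)
# Bracket with OFFSET(1) selects element 1 of the array: the approximate median.
median = sge.Bracket(
    this=approx_quantiles, expressions=[sge.func("OFFSET", sge.convert(1))]
)
print(median.sql(dialect="bigquery"))
# Prints: APPROX_QUANTILES(int64_col, 2)[OFFSET(1)]
# which matches the expression in the snapshot SQL below.
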
tests/system/small/engines/test_aggregation.py (17 additions, 1 deletion)
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from google.cloud import bigquery
import pytest

from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
import bigframes.operations.aggregations as agg_ops
-from bigframes.session import polars_executor
+from bigframes.session import direct_gbq_execution, polars_executor
from bigframes.testing.engine_utils import assert_equivalence_execution

pytest.importorskip("polars")
@@ -84,6 +85,21 @@ def test_engines_unary_aggregates(
    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)


+def test_sql_engines_median_op_aggregates(
+    scalars_array_value: array_value.ArrayValue,
+    bigquery_client: bigquery.Client,
+):
+    node = apply_agg_to_all_valid(
+        scalars_array_value,
+        agg_ops.MedianOp(),
+    ).node
+    left_engine = direct_gbq_execution.DirectGbqExecutor(bigquery_client)
+    right_engine = direct_gbq_execution.DirectGbqExecutor(
+        bigquery_client, compiler="sqlglot"
+    )
+    assert_equivalence_execution(node, left_engine, right_engine)


@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
@pytest.mark.parametrize(
"grouping_cols",
tests/system/small/test_series.py (16 additions, 4 deletions)
@@ -1919,10 +1919,22 @@ def test_mean(scalars_dfs):
    assert math.isclose(pd_result, bf_result)


-def test_median(scalars_dfs):
+@pytest.mark.parametrize(
+    ("col_name"),
+    [
+        "int64_col",
+        # Non-numeric column
+        "bytes_col",
+        "date_col",
+        "datetime_col",
+        "time_col",
+        "timestamp_col",
+        "string_col",
+    ],
+)
+def test_median(scalars_dfs, col_name):
     scalars_df, scalars_pandas_df = scalars_dfs
-    col_name = "int64_col"
-    bf_result = scalars_df[col_name].median()
+    bf_result = scalars_df[col_name].median(exact=False)
     pd_max = scalars_pandas_df[col_name].max()
     pd_min = scalars_pandas_df[col_name].min()
     # Median is approximate, so just check for plausibility.
@@ -1932,7 +1944,7 @@ def test_median_exact(scalars_dfs):
def test_median_exact(scalars_dfs):
    scalars_df, scalars_pandas_df = scalars_dfs
    col_name = "int64_col"
-    bf_result = scalars_df[col_name].median(exact=True)
+    bf_result = scalars_df[col_name].median()
    pd_result = scalars_pandas_df[col_name].median()
    assert math.isclose(pd_result, bf_result)

@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `date_col` AS `bfcol_0`,
+    `int64_col` AS `bfcol_1`,
+    `string_col` AS `bfcol_2`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    APPROX_QUANTILES(`bfcol_1`, 2)[OFFSET(1)] AS `bfcol_3`,
+    APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_4`,
+    APPROX_QUANTILES(`bfcol_2`, 2)[OFFSET(1)] AS `bfcol_5`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_3` AS `int64_col`,
+  `bfcol_4` AS `date_col`,
+  `bfcol_5` AS `string_col`
+FROM `bfcte_1`
@@ -56,6 +56,18 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot):
    snapshot.assert_match(sql, "out.sql")


+def test_median(scalar_types_df: bpd.DataFrame, snapshot):
+    bf_df = scalar_types_df
+    ops_map = {
+        "int64_col": agg_ops.MedianOp().as_expr("int64_col"),
+        "date_col": agg_ops.MedianOp().as_expr("date_col"),
+        "string_col": agg_ops.MedianOp().as_expr("string_col"),
+    }
+    sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
+
+    snapshot.assert_match(sql, "out.sql")


def test_min(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]