From bc0edacfb1196aaa88e2bc6ec16166fe298e48c7 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 22 Sep 2025 22:54:54 +0000 Subject: [PATCH 1/3] refactor: add agg_ops.MedianOp compiler to sqlglot --- .../sqlglot/aggregations/unary_compiler.py | 15 +++++++++++++ .../system/small/engines/test_aggregation.py | 21 ++++++++++++++++++- .../test_unary_compiler/test_median/out.sql | 12 +++++++++++ .../aggregations/test_unary_compiler.py | 9 ++++++++ 4 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 542bb10670..85230d58d9 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -56,6 +56,21 @@ def _( return apply_window_if_present(sge.func("MAX", column.expr), window) +@UNARY_OP_REGISTRATION.register(agg_ops.MedianOp) +def _( + op: agg_ops.MedianOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + # TODO(swast): Allow switching between exact and approximate median. + # For now, the best we can do is an approximate median when we're doing + # an aggregation, as PERCENTILE_CONT is only an analytic function. + approx_quantiles = sge.func("APPROX_QUANTILES", column.expr, sge.convert(2)) + return sge.Bracket( + this=approx_quantiles, expressions=[sge.func("OFFSET", sge.convert(1))] + ) + + @UNARY_OP_REGISTRATION.register(agg_ops.MinOp) def _( op: agg_ops.MinOp, diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index a4a49c622a..6666a5cb19 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud import bigquery import pytest from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes import bigframes.operations.aggregations as agg_ops -from bigframes.session import polars_executor +from bigframes.session import direct_gbq_execution, polars_executor from bigframes.testing.engine_utils import assert_equivalence_execution pytest.importorskip("polars") @@ -84,6 +85,24 @@ def test_engines_unary_aggregates( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) +@pytest.mark.parametrize( + "op", + [agg_ops.MedianOp], +) +def test_sql_engines_unary_aggregates( + scalars_array_value: array_value.ArrayValue, + bigquery_client: bigquery.Client, + op, +): + # TODO: this is not working?? + node = apply_agg_to_all_valid(scalars_array_value, op).node + left_engine = direct_gbq_execution.DirectGbqExecutor(bigquery_client) + right_engine = direct_gbq_execution.DirectGbqExecutor( + bigquery_client, compiler="sqlglot" + ) + assert_equivalence_execution(node, left_engine, right_engine) + + @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) @pytest.mark.parametrize( "grouping_cols", diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql new file mode 100644 index 0000000000..264d68bed2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql @@ -0,0 +1,12 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 311c039e11..e7a2a9c604 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -56,6 +56,15 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_median(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_expr = agg_ops.MedianOp().as_expr(col_name) + sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) + + snapshot.assert_match(sql, "out.sql") + + def test_min(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From f3f733f8ea82174b430fbad4542643a476ab340b Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 23 Sep 2025 19:02:19 +0000 Subject: [PATCH 2/3] enable engine tests --- .../system/small/engines/test_aggregation.py | 22 ++++++++++------- .../test_unary_compiler/test_median/out.sql | 24 ++++++++++++++++--- .../aggregations/test_unary_compiler.py | 16 +++++++++---- 3 files changed, 47 insertions(+), 15 deletions(-) diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index 6666a5cb19..452cc0c71a 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -85,17 +85,23 @@ def test_engines_unary_aggregates( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize( - "op", - [agg_ops.MedianOp], -) -def test_sql_engines_unary_aggregates( +def test_sql_engines_median_op_aggregates( scalars_array_value: array_value.ArrayValue, bigquery_client: bigquery.Client, - op, ): - # TODO: this is not working?? - node = apply_agg_to_all_valid(scalars_array_value, op).node + node = apply_agg_to_all_valid( + scalars_array_value, + agg_ops.MedianOp(), + # Exclude columns are not supported by Ibis. + excluded_cols=[ + "bytes_col", + "date_col", + "datetime_col", + "time_col", + "timestamp_col", + "string_col", + ], + ).node left_engine = direct_gbq_execution.DirectGbqExecutor(bigquery_client) right_engine = direct_gbq_execution.DirectGbqExecutor( bigquery_client, compiler="sqlglot" diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql index 264d68bed2..7061cc65c3 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql @@ -1,12 +1,30 @@ WITH `bfcte_0` AS ( SELECT - `int64_col` AS `bfcol_0` + `bytes_col` AS `bfcol_0`, + `date_col` AS `bfcol_1`, + `datetime_col` AS `bfcol_2`, + `int64_col` AS `bfcol_3`, + `string_col` AS `bfcol_4`, + `time_col` AS `bfcol_5`, + `timestamp_col` AS `bfcol_6` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_1` + APPROX_QUANTILES(`bfcol_3`, 2)[OFFSET(1)] AS `bfcol_7`, + APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_8`, + APPROX_QUANTILES(`bfcol_1`, 2)[OFFSET(1)] AS `bfcol_9`, + APPROX_QUANTILES(`bfcol_2`, 2)[OFFSET(1)] AS `bfcol_10`, + APPROX_QUANTILES(`bfcol_5`, 2)[OFFSET(1)] AS `bfcol_11`, + APPROX_QUANTILES(`bfcol_6`, 2)[OFFSET(1)] AS `bfcol_12`, + APPROX_QUANTILES(`bfcol_4`, 2)[OFFSET(1)] AS `bfcol_13` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `int64_col` + `bfcol_7` AS `int64_col`, + `bfcol_8` AS `bytes_col`, + `bfcol_9` AS `date_col`, + `bfcol_10` AS `datetime_col`, + `bfcol_11` AS `time_col`, + `bfcol_12` AS `timestamp_col`, + `bfcol_13` AS `string_col` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index e7a2a9c604..1ac295a1da 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -57,10 +57,18 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot): def test_median(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "int64_col" - bf_df = scalar_types_df[[col_name]] - agg_expr = agg_ops.MedianOp().as_expr(col_name) - sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) + bf_df = scalar_types_df + ops_map = { + "int64_col": agg_ops.MedianOp().as_expr("int64_col"), + # Includes columns are not supported by Ibis but supported by BigQuery. + "bytes_col": agg_ops.MedianOp().as_expr("bytes_col"), + "date_col": agg_ops.MedianOp().as_expr("date_col"), + "datetime_col": agg_ops.MedianOp().as_expr("datetime_col"), + "time_col": agg_ops.MedianOp().as_expr("time_col"), + "timestamp_col": agg_ops.MedianOp().as_expr("timestamp_col"), + "string_col": agg_ops.MedianOp().as_expr("string_col"), + } + sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) snapshot.assert_match(sql, "out.sql") From 1edee547964b9790c9f9dd48cdfee1effdeddcfb Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 23 Sep 2025 20:34:13 +0000 Subject: [PATCH 3/3] enable non-numeric for ibis compiler too --- .../ibis_compiler/aggregate_compiler.py | 4 --- .../sqlglot/aggregations/unary_compiler.py | 3 -- .../system/small/engines/test_aggregation.py | 9 ------ tests/system/small/test_series.py | 20 ++++++++++--- .../test_unary_compiler/test_median/out.sql | 30 ++++++------------- .../aggregations/test_unary_compiler.py | 5 ---- 6 files changed, 25 insertions(+), 46 deletions(-) diff --git a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py index b101f4e09f..0106b150e2 100644 --- a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py +++ b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py @@ -175,15 +175,11 @@ def _( @compile_unary_agg.register -@numeric_op def _( op: agg_ops.MedianOp, column: ibis_types.NumericColumn, window=None, ) -> ibis_types.NumericValue: - # TODO(swast): Allow switching between exact and approximate median. - # For now, the best we can do is an approximate median when we're doing - # an aggregation, as PERCENTILE_CONT is only an analytic function. return cast(ibis_types.NumericValue, column.approx_median()) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 85230d58d9..4cb0000894 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -62,9 +62,6 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - # TODO(swast): Allow switching between exact and approximate median. - # For now, the best we can do is an approximate median when we're doing - # an aggregation, as PERCENTILE_CONT is only an analytic function. approx_quantiles = sge.func("APPROX_QUANTILES", column.expr, sge.convert(2)) return sge.Bracket( this=approx_quantiles, expressions=[sge.func("OFFSET", sge.convert(1))] diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index 452cc0c71a..98d5cd4ac8 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -92,15 +92,6 @@ def test_sql_engines_median_op_aggregates( node = apply_agg_to_all_valid( scalars_array_value, agg_ops.MedianOp(), - # Exclude columns are not supported by Ibis. - excluded_cols=[ - "bytes_col", - "date_col", - "datetime_col", - "time_col", - "timestamp_col", - "string_col", - ], ).node left_engine = direct_gbq_execution.DirectGbqExecutor(bigquery_client) right_engine = direct_gbq_execution.DirectGbqExecutor( diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index 0a761a3a3a..d1a252f8dc 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -1919,10 +1919,22 @@ def test_mean(scalars_dfs): assert math.isclose(pd_result, bf_result) -def test_median(scalars_dfs): +@pytest.mark.parametrize( + ("col_name"), + [ + "int64_col", + # Non-numeric column + "bytes_col", + "date_col", + "datetime_col", + "time_col", + "timestamp_col", + "string_col", + ], +) +def test_median(scalars_dfs, col_name): scalars_df, scalars_pandas_df = scalars_dfs - col_name = "int64_col" - bf_result = scalars_df[col_name].median() + bf_result = scalars_df[col_name].median(exact=False) pd_max = scalars_pandas_df[col_name].max() pd_min = scalars_pandas_df[col_name].min() # Median is approximate, so just check for plausibility. @@ -1932,7 +1944,7 @@ def test_median(scalars_dfs): def test_median_exact(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_name = "int64_col" - bf_result = scalars_df[col_name].median(exact=True) + bf_result = scalars_df[col_name].median() pd_result = scalars_pandas_df[col_name].median() assert math.isclose(pd_result, bf_result) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql index 7061cc65c3..bf7006ef87 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_median/out.sql @@ -1,30 +1,18 @@ WITH `bfcte_0` AS ( SELECT - `bytes_col` AS `bfcol_0`, - `date_col` AS `bfcol_1`, - `datetime_col` AS `bfcol_2`, - `int64_col` AS `bfcol_3`, - `string_col` AS `bfcol_4`, - `time_col` AS `bfcol_5`, - `timestamp_col` AS `bfcol_6` + `date_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `string_col` AS `bfcol_2` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - APPROX_QUANTILES(`bfcol_3`, 2)[OFFSET(1)] AS `bfcol_7`, - APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_8`, - APPROX_QUANTILES(`bfcol_1`, 2)[OFFSET(1)] AS `bfcol_9`, - APPROX_QUANTILES(`bfcol_2`, 2)[OFFSET(1)] AS `bfcol_10`, - APPROX_QUANTILES(`bfcol_5`, 2)[OFFSET(1)] AS `bfcol_11`, - APPROX_QUANTILES(`bfcol_6`, 2)[OFFSET(1)] AS `bfcol_12`, - APPROX_QUANTILES(`bfcol_4`, 2)[OFFSET(1)] AS `bfcol_13` + APPROX_QUANTILES(`bfcol_1`, 2)[OFFSET(1)] AS `bfcol_3`, + APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_4`, + APPROX_QUANTILES(`bfcol_2`, 2)[OFFSET(1)] AS `bfcol_5` FROM `bfcte_0` ) SELECT - `bfcol_7` AS `int64_col`, - `bfcol_8` AS `bytes_col`, - `bfcol_9` AS `date_col`, - `bfcol_10` AS `datetime_col`, - `bfcol_11` AS `time_col`, - `bfcol_12` AS `timestamp_col`, - `bfcol_13` AS `string_col` + `bfcol_3` AS `int64_col`, + `bfcol_4` AS `date_col`, + `bfcol_5` AS `string_col` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 1ac295a1da..4f0016a6e7 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -60,12 +60,7 @@ def test_median(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df ops_map = { "int64_col": agg_ops.MedianOp().as_expr("int64_col"), - # Includes columns are not supported by Ibis but supported by BigQuery. - "bytes_col": agg_ops.MedianOp().as_expr("bytes_col"), "date_col": agg_ops.MedianOp().as_expr("date_col"), - "datetime_col": agg_ops.MedianOp().as_expr("datetime_col"), - "time_col": agg_ops.MedianOp().as_expr("time_col"), - "timestamp_col": agg_ops.MedianOp().as_expr("timestamp_col"), "string_col": agg_ops.MedianOp().as_expr("string_col"), } sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))