Skip to content

Commit bc0edac

Browse files
committed
refactor: add agg_ops.MedianOp compiler to sqlglot
1 parent 60056ca commit bc0edac

File tree

4 files changed

+56
-1
lines changed

4 files changed

+56
-1
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@ def _(
5656
return apply_window_if_present(sge.func("MAX", column.expr), window)
5757

5858

59+
@UNARY_OP_REGISTRATION.register(agg_ops.MedianOp)
60+
def _(
61+
op: agg_ops.MedianOp,
62+
column: typed_expr.TypedExpr,
63+
window: typing.Optional[window_spec.WindowSpec] = None,
64+
) -> sge.Expression:
65+
# TODO(swast): Allow switching between exact and approximate median.
66+
# For now, the best we can do is an approximate median when we're doing
67+
# an aggregation, as PERCENTILE_CONT is only an analytic function.
# NOTE(review): `window` is accepted but never applied in this implementation.
68+
approx_quantiles = sge.func("APPROX_QUANTILES", column.expr, sge.convert(2))
69+
return sge.Bracket(
70+
this=approx_quantiles, expressions=[sge.func("OFFSET", sge.convert(1))]
71+
)
72+
73+
5974
@UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
6075
def _(
6176
op: agg_ops.MinOp,

tests/system/small/engines/test_aggregation.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google.cloud import bigquery
1516
import pytest
1617

1718
from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
1819
import bigframes.operations.aggregations as agg_ops
19-
from bigframes.session import polars_executor
20+
from bigframes.session import direct_gbq_execution, polars_executor
2021
from bigframes.testing.engine_utils import assert_equivalence_execution
2122

2223
pytest.importorskip("polars")
@@ -84,6 +85,24 @@ def test_engines_unary_aggregates(
8485
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
8586

8687

88+
@pytest.mark.parametrize(
89+
"op",
90+
[agg_ops.MedianOp],
91+
)
92+
def test_sql_engines_unary_aggregates(
93+
scalars_array_value: array_value.ArrayValue,
94+
bigquery_client: bigquery.Client,
95+
op,
96+
):
97+
# TODO: This is not working yet; investigate and fix.
98+
node = apply_agg_to_all_valid(scalars_array_value, op).node
99+
left_engine = direct_gbq_execution.DirectGbqExecutor(bigquery_client)
100+
right_engine = direct_gbq_execution.DirectGbqExecutor(
101+
bigquery_client, compiler="sqlglot"
102+
)
103+
assert_equivalence_execution(node, left_engine, right_engine)
104+
105+
87106
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
88107
@pytest.mark.parametrize(
89108
"grouping_cols",
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
APPROX_QUANTILES(`bfcol_0`, 2)[OFFSET(1)] AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot):
5656
snapshot.assert_match(sql, "out.sql")
5757

5858

59+
def test_median(scalar_types_df: bpd.DataFrame, snapshot):
60+
col_name = "int64_col"
61+
bf_df = scalar_types_df[[col_name]]
62+
agg_expr = agg_ops.MedianOp().as_expr(col_name)
63+
sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])
64+
65+
snapshot.assert_match(sql, "out.sql")
66+
67+
5968
def test_min(scalar_types_df: bpd.DataFrame, snapshot):
6069
col_name = "int64_col"
6170
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)