Skip to content

Commit e6c3ba9

Browse files
committed
Merge remote-tracking branch 'origin/main' into b409390651-progress-bar
2 parents d03e5d1 + 1fc563c commit e6c3ba9

File tree

18 files changed

+385
-29
lines changed

18 files changed

+385
-29
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def compile_analytic(
6363
window: window_spec.WindowSpec,
6464
) -> sge.Expression:
6565
if isinstance(aggregate, agg_expressions.NullaryAggregation):
66-
return nullary_compiler.compile(aggregate.op)
66+
return nullary_compiler.compile(aggregate.op, window)
6767
if isinstance(aggregate, agg_expressions.UnaryAggregation):
6868
column = typed_expr.TypedExpr(
6969
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg),

bigframes/core/compile/sqlglot/aggregations/binary_compiler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23+
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2324
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2425
from bigframes.operations import aggregations as agg_ops
2526

@@ -33,3 +34,25 @@ def compile(
3334
window: typing.Optional[window_spec.WindowSpec] = None,
3435
) -> sge.Expression:
3536
return BINARY_OP_REGISTRATION[op](op, left, right, window=window)
37+
38+
39+
@BINARY_OP_REGISTRATION.register(agg_ops.CorrOp)
40+
def _(
41+
op: agg_ops.CorrOp,
42+
left: typed_expr.TypedExpr,
43+
right: typed_expr.TypedExpr,
44+
window: typing.Optional[window_spec.WindowSpec] = None,
45+
) -> sge.Expression:
46+
result = sge.func("CORR", left.expr, right.expr)
47+
return apply_window_if_present(result, window)
48+
49+
50+
@BINARY_OP_REGISTRATION.register(agg_ops.CovOp)
51+
def _(
52+
op: agg_ops.CovOp,
53+
left: typed_expr.TypedExpr,
54+
right: typed_expr.TypedExpr,
55+
window: typing.Optional[window_spec.WindowSpec] = None,
56+
) -> sge.Expression:
57+
result = sge.func("COVAR_SAMP", left.expr, right.expr)
58+
return apply_window_if_present(result, window)

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,15 @@ def _(
3939
window: typing.Optional[window_spec.WindowSpec] = None,
4040
) -> sge.Expression:
4141
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
42+
43+
44+
@NULLARY_OP_REGISTRATION.register(agg_ops.RowNumberOp)
45+
def _(
46+
op: agg_ops.RowNumberOp,
47+
window: typing.Optional[window_spec.WindowSpec] = None,
48+
) -> sge.Expression:
49+
result: sge.Expression = sge.func("ROW_NUMBER")
50+
if window is None:
51+
# ROW_NUMBER always needs an OVER clause.
52+
return sge.Window(this=result)
53+
return apply_window_if_present(result, window)

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,16 @@ def arg_checker(*args, **kwargs):
4141
)
4242
return item(*args, **kwargs)
4343

44-
if hasattr(op, "name"):
45-
key = typing.cast(str, op.name)
46-
if key in self._registered_ops:
47-
raise ValueError(f"{key} is already registered")
48-
else:
49-
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
44+
key = str(op)
45+
if key in self._registered_ops:
46+
raise ValueError(f"{key} is already registered")
5047
self._registered_ops[key] = item
5148
return arg_checker
5249

5350
return decorator
5451

5552
def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
56-
if isinstance(op, agg_ops.WindowOp):
57-
if not hasattr(op, "name"):
58-
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
59-
else:
60-
key = typing.cast(str, op.name)
61-
return self._registered_ops[key]
62-
return self._registered_ops[op]
53+
key = op if isinstance(op, type) else type(op)
54+
if str(key) not in self._registered_ops:
55+
raise ValueError(f"{key} is already not registered")
56+
return self._registered_ops[str(key)]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,37 @@ def compile(
3838
return UNARY_OP_REGISTRATION[op](op, column, window=window)
3939

4040

41+
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
42+
def _(
43+
op: agg_ops.ApproxQuartilesOp,
44+
column: typed_expr.TypedExpr,
45+
window: typing.Optional[window_spec.WindowSpec] = None,
46+
) -> sge.Expression:
47+
if window is not None:
48+
raise NotImplementedError("Approx Quartiles with windowing is not supported.")
49+
# APPROX_QUANTILES returns an array of the quartiles, so we need to index it.
50+
# The op.quartile is 1-based for the quartile, but array is 0-indexed.
51+
# The quartiles are Q0, Q1, Q2, Q3, Q4. op.quartile is 1, 2, or 3.
52+
# The array has 5 elements (for N=4 intervals).
53+
# So we want the element at index `op.quartile`.
54+
approx_quantiles_expr = sge.func("APPROX_QUANTILES", column.expr, sge.convert(4))
55+
return sge.Bracket(
56+
this=approx_quantiles_expr,
57+
expressions=[sge.func("OFFSET", sge.convert(op.quartile))],
58+
)
59+
60+
61+
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxTopCountOp)
62+
def _(
63+
op: agg_ops.ApproxTopCountOp,
64+
column: typed_expr.TypedExpr,
65+
window: typing.Optional[window_spec.WindowSpec] = None,
66+
) -> sge.Expression:
67+
if window is not None:
68+
raise NotImplementedError("Approx top count with windowing is not supported.")
69+
return sge.func("APPROX_TOP_COUNT", column.expr, sge.convert(op.number))
70+
71+
4172
@UNARY_OP_REGISTRATION.register(agg_ops.CountOp)
4273
def _(
4374
op: agg_ops.CountOp,
@@ -53,10 +84,7 @@ def _(
5384
column: typed_expr.TypedExpr,
5485
window: typing.Optional[window_spec.WindowSpec] = None,
5586
) -> sge.Expression:
56-
# Ranking functions do not support window framing clauses.
57-
return apply_window_if_present(
58-
sge.func("DENSE_RANK"), window, include_framing_clauses=False
59-
)
87+
return apply_window_if_present(sge.func("DENSE_RANK"), window)
6088

6189

6290
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
@@ -109,13 +137,23 @@ def _(
109137
return apply_window_if_present(sge.func("MIN", column.expr), window)
110138

111139

112-
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
140+
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
113141
def _(
114-
op: agg_ops.SizeUnaryOp,
115-
_,
142+
op: agg_ops.QuantileOp,
143+
column: typed_expr.TypedExpr,
116144
window: typing.Optional[window_spec.WindowSpec] = None,
117145
) -> sge.Expression:
118-
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
146+
# TODO: Support interpolation argument
147+
# TODO: Support percentile_disc
148+
result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
149+
if window is None:
150+
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
151+
result = sge.Window(this=result)
152+
else:
153+
result = apply_window_if_present(result, window)
154+
if op.should_floor_result:
155+
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
156+
return result
119157

120158

121159
@UNARY_OP_REGISTRATION.register(agg_ops.RankOp)
@@ -124,10 +162,16 @@ def _(
124162
column: typed_expr.TypedExpr,
125163
window: typing.Optional[window_spec.WindowSpec] = None,
126164
) -> sge.Expression:
127-
# Ranking functions do not support window framing clauses.
128-
return apply_window_if_present(
129-
sge.func("RANK"), window, include_framing_clauses=False
130-
)
165+
return apply_window_if_present(sge.func("RANK"), window)
166+
167+
168+
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
169+
def _(
170+
op: agg_ops.SizeUnaryOp,
171+
_,
172+
window: typing.Optional[window_spec.WindowSpec] = None,
173+
) -> sge.Expression:
174+
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
131175

132176

133177
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
28-
include_framing_clauses: bool = True,
2928
) -> sge.Expression:
3029
if window is None:
3130
return value
@@ -65,7 +64,7 @@ def apply_window_if_present(
6564
if not window.bounds and not order:
6665
return sge.Window(this=value, partition_by=group_by)
6766

68-
if not window.bounds and not include_framing_clauses:
67+
if not window.bounds:
6968
return sge.Window(this=value, partition_by=group_by, order=order)
7069

7170
kind = (
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
CORR(`bfcol_0`, `bfcol_1`) AS `bfcol_2`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_2` AS `corr_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
COVAR_SAMP(`bfcol_0`, `bfcol_1`) AS `bfcol_2`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_2` AS `cov_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ROW_NUMBER() OVER () AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `row_number`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ROW_NUMBER() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `row_number`
13+
FROM `bfcte_1`

0 commit comments

Comments
 (0)