Skip to content

Commit bab9576

Browse files
committed
refactor: support agg_ops.FirstOp, FirstNonNullOp in the sqlglot compiler
1 parent 7600001 commit bab9576

File tree

6 files changed

+86
-6
lines changed

6 files changed

+86
-6
lines changed

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,4 @@ def _(
5050
if window is None:
5151
# ROW_NUMBER always needs an OVER clause.
5252
return sge.Window(this=result)
53-
return apply_window_if_present(result, window)
53+
return apply_window_if_present(result, window, include_framing_clauses=False)

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,28 @@ def _(
104104
column: typed_expr.TypedExpr,
105105
window: typing.Optional[window_spec.WindowSpec] = None,
106106
) -> sge.Expression:
107-
return apply_window_if_present(sge.func("DENSE_RANK"), window)
107+
return apply_window_if_present(sge.func("DENSE_RANK"), window, include_framing_clauses=False)
108+
109+
110+
@UNARY_OP_REGISTRATION.register(agg_ops.FirstOp)
111+
def _(
112+
op: agg_ops.FirstOp,
113+
column: typed_expr.TypedExpr,
114+
window: typing.Optional[window_spec.WindowSpec] = None,
115+
) -> sge.Expression:
116+
# FIRST_VALUE in BQ respects nulls by default.
117+
return apply_window_if_present(sge.FirstValue(this=column.expr), window)
118+
119+
120+
@UNARY_OP_REGISTRATION.register(agg_ops.FirstNonNullOp)
121+
def _(
122+
op: agg_ops.FirstNonNullOp,
123+
column: typed_expr.TypedExpr,
124+
window: typing.Optional[window_spec.WindowSpec] = None,
125+
) -> sge.Expression:
126+
return apply_window_if_present(
127+
sge.IgnoreNulls(this=sge.FirstValue(this=column.expr)), window
128+
)
108129

109130

110131
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
@@ -182,7 +203,7 @@ def _(
182203
column: typed_expr.TypedExpr,
183204
window: typing.Optional[window_spec.WindowSpec] = None,
184205
) -> sge.Expression:
185-
return apply_window_if_present(sge.func("RANK"), window)
206+
return apply_window_if_present(sge.func("RANK"), window, include_framing_clauses=False)
186207

187208

188209
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
28+
include_framing_clauses: bool = True,
2829
) -> sge.Expression:
2930
if window is None:
3031
return value
@@ -64,11 +65,11 @@ def apply_window_if_present(
6465
if not window.bounds and not order:
6566
return sge.Window(this=value, partition_by=group_by)
6667

67-
if not window.bounds:
68+
if not window.bounds and not include_framing_clauses:
6869
return sge.Window(this=value, partition_by=group_by, order=order)
6970

7071
kind = (
71-
"ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE"
72+
"RANGE" if isinstance(window.bounds, window_spec.RangeWindowBounds) else "ROWS"
7273
)
7374

7475
start: typing.Union[int, float, None] = None
@@ -125,7 +126,7 @@ def get_window_order_by(
125126
nulls_first=nulls_first,
126127
)
127128
)
128-
elif not nulls_first and not desc:
129+
elif (not nulls_first) and (not desc):
129130
order_by.append(
130131
sge.Ordered(
131132
this=is_null_expr,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `bfcol_0` IS NULL
10+
THEN NULL
11+
ELSE FIRST_VALUE(`bfcol_0`) OVER (
12+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
13+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
14+
)
15+
END AS `bfcol_1`
16+
FROM `bfcte_0`
17+
)
18+
SELECT
19+
`bfcol_1` AS `agg_int64`
20+
FROM `bfcte_1`
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
FIRST_VALUE(`bfcol_0` IGNORE NULLS) OVER (
9+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
10+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
11+
) AS `bfcol_1`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_1` AS `agg_int64`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,28 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
126126
snapshot.assert_match(sql, "out.sql")
127127

128128

129+
def test_first(scalar_types_df: bpd.DataFrame, snapshot):
130+
col_name = "int64_col"
131+
bf_df = scalar_types_df[[col_name]]
132+
agg_expr = agg_exprs.UnaryAggregation(agg_ops.FirstOp(), expression.deref(col_name))
133+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
134+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
135+
136+
snapshot.assert_match(sql, "out.sql")
137+
138+
139+
def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot):
140+
col_name = "int64_col"
141+
bf_df = scalar_types_df[[col_name]]
142+
agg_expr = agg_exprs.UnaryAggregation(
143+
agg_ops.FirstNonNullOp(), expression.deref(col_name)
144+
)
145+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
146+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
147+
148+
snapshot.assert_match(sql, "out.sql")
149+
150+
129151
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
130152
col_name = "int64_col"
131153
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)