Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ def _(
if window is None:
# ROW_NUMBER always needs an OVER clause.
return sge.Window(this=result)
return apply_window_if_present(result, window)
return apply_window_if_present(result, window, include_framing_clauses=False)
50 changes: 48 additions & 2 deletions bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,51 @@ def _(
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
return apply_window_if_present(sge.func("DENSE_RANK"), window)
return apply_window_if_present(
sge.func("DENSE_RANK"), window, include_framing_clauses=False
)


@UNARY_OP_REGISTRATION.register(agg_ops.FirstOp)
def _(
op: agg_ops.FirstOp,
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
# FIRST_VALUE in BQ respects nulls by default.
return apply_window_if_present(sge.FirstValue(this=column.expr), window)


@UNARY_OP_REGISTRATION.register(agg_ops.FirstNonNullOp)
def _(
op: agg_ops.FirstNonNullOp,
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
return apply_window_if_present(
sge.IgnoreNulls(this=sge.FirstValue(this=column.expr)), window
)


@UNARY_OP_REGISTRATION.register(agg_ops.LastOp)
def _(
op: agg_ops.LastOp,
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
# LAST_VALUE in BQ respects nulls by default.
return apply_window_if_present(sge.LastValue(this=column.expr), window)


@UNARY_OP_REGISTRATION.register(agg_ops.LastNonNullOp)
def _(
op: agg_ops.LastNonNullOp,
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
return apply_window_if_present(
sge.IgnoreNulls(this=sge.LastValue(this=column.expr)), window
)


@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
Expand Down Expand Up @@ -182,7 +226,9 @@ def _(
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
return apply_window_if_present(sge.func("RANK"), window)
return apply_window_if_present(
sge.func("RANK"), window, include_framing_clauses=False
)


@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
Expand Down
7 changes: 4 additions & 3 deletions bigframes/core/compile/sqlglot/aggregations/windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
def apply_window_if_present(
value: sge.Expression,
window: typing.Optional[window_spec.WindowSpec] = None,
include_framing_clauses: bool = True,
) -> sge.Expression:
if window is None:
return value
Expand Down Expand Up @@ -64,11 +65,11 @@ def apply_window_if_present(
if not window.bounds and not order:
return sge.Window(this=value, partition_by=group_by)

if not window.bounds:
if not window.bounds and not include_framing_clauses:
return sge.Window(this=value, partition_by=group_by, order=order)

kind = (
"ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE"
"RANGE" if isinstance(window.bounds, window_spec.RangeWindowBounds) else "ROWS"
)

start: typing.Union[int, float, None] = None
Expand Down Expand Up @@ -125,7 +126,7 @@ def get_window_order_by(
nulls_first=nulls_first,
)
)
elif not nulls_first and not desc:
elif (not nulls_first) and (not desc):
order_by.append(
sge.Ordered(
this=is_null_expr,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE
WHEN `bfcol_0` IS NULL
THEN NULL
ELSE FIRST_VALUE(`bfcol_0`) OVER (
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
)
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `agg_int64`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
FIRST_VALUE(`bfcol_0` IGNORE NULLS) OVER (
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `agg_int64`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE
WHEN `bfcol_0` IS NULL
THEN NULL
ELSE LAST_VALUE(`bfcol_0`) OVER (
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
)
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `agg_int64`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
LAST_VALUE(`bfcol_0` IGNORE NULLS) OVER (
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `agg_int64`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import typing

import pytest
Expand Down Expand Up @@ -126,6 +127,66 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql, "out.sql")


def test_first(scalar_types_df: bpd.DataFrame, snapshot):
if sys.version_info < (3, 12):
pytest.skip(
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
)
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
agg_expr = agg_exprs.UnaryAggregation(agg_ops.FirstOp(), expression.deref(col_name))
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")

snapshot.assert_match(sql, "out.sql")


def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot):
if sys.version_info < (3, 12):
pytest.skip(
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
)
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
agg_expr = agg_exprs.UnaryAggregation(
agg_ops.FirstNonNullOp(), expression.deref(col_name)
)
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")

snapshot.assert_match(sql, "out.sql")


def test_last(scalar_types_df: bpd.DataFrame, snapshot):
if sys.version_info < (3, 12):
pytest.skip(
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
)
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
agg_expr = agg_exprs.UnaryAggregation(agg_ops.LastOp(), expression.deref(col_name))
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")

snapshot.assert_match(sql, "out.sql")


def test_last_non_null(scalar_types_df: bpd.DataFrame, snapshot):
if sys.version_info < (3, 12):
pytest.skip(
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
)
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
agg_expr = agg_exprs.UnaryAggregation(
agg_ops.LastNonNullOp(), expression.deref(col_name)
)
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")

snapshot.assert_match(sql, "out.sql")


def test_max(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
Expand Down