Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion bigframes/core/compile/sqlglot/aggregate_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import sqlglot.expressions as sge

from bigframes.core import expression
from bigframes.core import expression, window_spec
from bigframes.core.compile.sqlglot.aggregations import (
binary_compiler,
nullary_compiler,
Expand Down Expand Up @@ -56,3 +56,21 @@ def compile_aggregate(
return binary_compiler.compile(aggregate.op, left, right)
else:
raise ValueError(f"Unexpected aggregation: {aggregate}")


def compile_analytic(
aggregate: expression.Aggregation,
window: window_spec.WindowSpec,
) -> sge.Expression:
if isinstance(aggregate, expression.NullaryAggregation):
return nullary_compiler.compile(aggregate.op)
if isinstance(aggregate, expression.UnaryAggregation):
column = typed_expr.TypedExpr(
scalar_compiler.compile_scalar_expression(aggregate.arg),
aggregate.arg.output_type,
)
return unary_compiler.compile(aggregate.op, column, window)
elif isinstance(aggregate, expression.BinaryAggregation):
raise NotImplementedError("binary analytic operations not yet supported")
else:
raise ValueError(f"Unexpected analytic operation: {aggregate}")
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import sqlglot.expressions as sge

from bigframes import dtypes
from bigframes.core import window_spec
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
Expand All @@ -42,8 +43,11 @@ def _(
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
expr = column.expr
if column.dtype == dtypes.BOOL_DTYPE:
expr = sge.Cast(this=column.expr, to="INT64")
# Will be null if all inputs are null. Pandas defaults to zero sum though.
expr = apply_window_if_present(sge.func("SUM", column.expr), window)
expr = apply_window_if_present(sge.func("SUM", expr), window)
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))


Expand Down
66 changes: 66 additions & 0 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,72 @@ def compile_aggregate(

return child.aggregate(aggregations, by_cols, tuple(dropna_cols))

@_compile_node.register
def compile_window(
self, node: nodes.WindowOpNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
window_spec = node.window_spec
if node.expression.op.order_independent and window_spec.is_unbounded:
# notably percentile_cont does not support ordering clause
window_spec = window_spec.without_order()

window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)

inputs: tuple[sge.Expression, ...] = tuple(
scalar_compiler.compile_scalar_expression(expression.DerefOp(column))
for column in node.expression.column_references
)

clauses: list[tuple[sge.Expression, sge.Expression]] = []
if node.expression.op.skips_nulls and not node.never_skip_nulls:
for column in inputs:
clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null()))

if window_spec.min_periods and len(inputs) > 0:
if node.expression.op.skips_nulls:
# Most operations do not count NULL values towards min_periods
not_null_columns = [
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
for column in inputs
]
# All inputs must be non-null for observation to count
if not not_null_columns:
is_observation_expr: sge.Expression = sge.convert(True)
else:
is_observation_expr = not_null_columns[0]
for expr in not_null_columns[1:]:
is_observation_expr = sge.And(
this=is_observation_expr, expression=expr
)
is_observation = ir._cast(is_observation_expr, "INT64")
else:
# Operations like count treat even NULLs as valid observations
# for the sake of min_periods notnull is just used to convert
# null values to non-null (FALSE) values to be counted.
is_observation = ir._cast(
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
"INT64",
)

observation_count = windows.apply_window_if_present(
sge.func("SUM", is_observation), window_spec
)
clauses.append(
(
observation_count < sge.convert(window_spec.min_periods),
sge.Null(),
)
)
if clauses:
when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
window_op = sge.Case(ifs=when_expressions, default=window_op)

# TODO: check if we can directly window the expression.
return child.window(
window_op=window_op,
output_column_id=node.output_name.sql,
)


def _replace_unsupported_ops(node: nodes.BigFrameNode):
node = nodes.bottom_up(node, rewrite.rewrite_slice)
Expand Down
2 changes: 1 addition & 1 deletion bigframes/core/compile/sqlglot/scalar_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

@functools.singledispatch
def compile_scalar_expression(
expression: expression.Expression,
expr: expression.Expression,
) -> sge.Expression:
"""Compiles BigFrames scalar expression into SQLGlot expression."""
raise ValueError(f"Can't compile unrecognized node: {expression}")
Expand Down
7 changes: 7 additions & 0 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,13 @@ def aggregate(
new_expr = new_expr.where(condition, append=False)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def window(
self,
window_op: sge.Expression,
output_column_id: str,
) -> SQLGlotIR:
return self.project(((output_column_id, window_op),))

def insert(
self,
destination: bigquery.TableReference,
Expand Down
103 changes: 67 additions & 36 deletions bigframes/core/rewrite/schema_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from bigframes.core import bigframe_node
from bigframes.core import expression as ex
from bigframes.core import nodes
from bigframes.core import nodes, ordering


def bind_schema_to_tree(
Expand Down Expand Up @@ -79,46 +79,77 @@ def bind_schema_to_node(
if isinstance(node, nodes.AggregateNode):
aggregations = []
for aggregation, id in node.aggregations:
if isinstance(aggregation, ex.UnaryAggregation):
replaced = typing.cast(
ex.Aggregation,
dataclasses.replace(
aggregation,
arg=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(
aggregation.arg, node.child.field_by_id
),
),
),
aggregations.append(
(_bind_schema_to_aggregation_expr(aggregation, node.child), id)
)

return dataclasses.replace(
node,
aggregations=tuple(aggregations),
)

if isinstance(node, nodes.WindowOpNode):
window_spec = dataclasses.replace(
node.window_spec,
grouping_keys=tuple(
typing.cast(
ex.DerefOp, ex.bind_schema_fields(expr, node.child.field_by_id)
)
aggregations.append((replaced, id))
elif isinstance(aggregation, ex.BinaryAggregation):
replaced = typing.cast(
ex.Aggregation,
dataclasses.replace(
aggregation,
left=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(
aggregation.left, node.child.field_by_id
),
),
right=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(
aggregation.right, node.child.field_by_id
),
),
for expr in node.window_spec.grouping_keys
),
ordering=tuple(
ordering.OrderingExpression(
scalar_expression=ex.bind_schema_fields(
expr.scalar_expression, node.child.field_by_id
),
direction=expr.direction,
na_last=expr.na_last,
)
aggregations.append((replaced, id))
else:
aggregations.append((aggregation, id))

for expr in node.window_spec.ordering
),
)
return dataclasses.replace(
node,
aggregations=tuple(aggregations),
expression=_bind_schema_to_aggregation_expr(node.expression, node.child),
window_spec=window_spec,
)

return node


def _bind_schema_to_aggregation_expr(
aggregation: ex.Aggregation,
child: bigframe_node.BigFrameNode,
) -> ex.Aggregation:
assert isinstance(
aggregation, ex.Aggregation
), f"Expected Aggregation, got {type(aggregation)}"

if isinstance(aggregation, ex.UnaryAggregation):
return typing.cast(
ex.Aggregation,
dataclasses.replace(
aggregation,
arg=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(aggregation.arg, child.field_by_id),
),
),
)
elif isinstance(aggregation, ex.BinaryAggregation):
return typing.cast(
ex.Aggregation,
dataclasses.replace(
aggregation,
left=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(aggregation.left, child.field_by_id),
),
right=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(aggregation.right, child.field_by_id),
),
),
)
else:
return aggregation
1 change: 1 addition & 0 deletions bigframes/operations/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def skips_nulls(self):

@dataclasses.dataclass(frozen=True)
class DiffOp(UnaryWindowOp):
name: ClassVar[str] = "diff"
periods: int

@property
Expand Down
31 changes: 29 additions & 2 deletions tests/system/small/engines/test_windowing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import bigquery
import pytest

from bigframes.core import array_value
from bigframes.session import polars_executor
from bigframes.core import array_value, expression, identifiers, nodes, window_spec
import bigframes.operations.aggregations as agg_ops
from bigframes.session import direct_gbq_execution, polars_executor
from bigframes.testing.engine_utils import assert_equivalence_execution

pytest.importorskip("polars")
Expand All @@ -31,3 +33,28 @@ def test_engines_with_offsets(
):
result, _ = scalars_array_value.promote_offsets()
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)


def test_engines_with_rows_window(
scalars_array_value: array_value.ArrayValue,
bigquery_client: bigquery.Client,
):
window = window_spec.WindowSpec(
bounds=window_spec.RowsWindowBounds.from_window_size(3, "left"),
)
window_node = nodes.WindowOpNode(
child=scalars_array_value.node,
expression=expression.UnaryAggregation(
agg_ops.sum_op, expression.deref("int64_too")
),
window_spec=window,
output_name=identifiers.ColumnId("sum_int64"),
never_skip_nulls=False,
skip_reproject_unsafe=False,
)

bq_executor = direct_gbq_execution.DirectGbqExecutor(bigquery_client)
bq_sqlgot_executor = direct_gbq_execution.DirectGbqExecutor(
bigquery_client, compiler="sqlglot"
)
assert_equivalence_execution(window_node, bq_executor, bq_sqlgot_executor)
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
WITH `bfcte_0` AS (
SELECT
`bool_col` AS `bfcol_0`,
`int64_col` AS `bfcol_1`,
`rowindex` AS `bfcol_2`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
`bfcol_2` AS `bfcol_6`,
`bfcol_0` AS `bfcol_7`,
`bfcol_1` AS `bfcol_8`,
`bfcol_0` AS `bfcol_9`
FROM `bfcte_0`
), `bfcte_2` AS (
SELECT
*
FROM `bfcte_1`
WHERE
NOT `bfcol_9` IS NULL
), `bfcte_3` AS (
SELECT
*,
CASE
WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
) < 3
THEN NULL
ELSE COALESCE(
SUM(CAST(`bfcol_7` AS INT64)) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
),
0
)
END AS `bfcol_15`
FROM `bfcte_2`
), `bfcte_4` AS (
SELECT
*
FROM `bfcte_3`
WHERE
NOT `bfcol_9` IS NULL
), `bfcte_5` AS (
SELECT
*,
CASE
WHEN SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
) < 3
THEN NULL
ELSE COALESCE(
SUM(`bfcol_8`) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
),
0
)
END AS `bfcol_21`
FROM `bfcte_4`
)
SELECT
`bfcol_9` AS `bool_col`,
`bfcol_6` AS `rowindex`,
`bfcol_15` AS `bool_col_1`,
`bfcol_21` AS `int64_col`
FROM `bfcte_5`
ORDER BY
`bfcol_9` ASC NULLS LAST,
`bfcol_2` ASC NULLS LAST
Loading