Skip to content

Commit 094becc

Browse files
committed
resolve aggregation nodes for dtype and support dropna
1 parent 1fa0b77 commit 094becc

File tree

14 files changed

+154
-39
lines changed

14 files changed

+154
-39
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import sqlglot.expressions as sge
2020

21-
from bigframes import dtypes
2221
from bigframes.core import expression, window_spec
22+
from bigframes.core.compile.sqlglot.expressions import typed_expr
2323
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2424
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2525
import bigframes.operations as ops
@@ -34,14 +34,23 @@ def compile_aggregate(
3434
if isinstance(aggregate, expression.NullaryAggregation):
3535
return compile_nullary_agg(aggregate.op)
3636
if isinstance(aggregate, expression.UnaryAggregation):
37-
column = scalar_compiler.compile_scalar_expression(aggregate.arg)
37+
column = typed_expr.TypedExpr(
38+
scalar_compiler.compile_scalar_expression(aggregate.arg),
39+
aggregate.arg.output_type,
40+
)
3841
if not aggregate.op.order_independent:
3942
return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by) # type: ignore
4043
else:
4144
return compile_unary_agg(aggregate.op, column) # type: ignore
4245
elif isinstance(aggregate, expression.BinaryAggregation):
43-
left = scalar_compiler.compile_scalar_expression(aggregate.left)
44-
right = scalar_compiler.compile_scalar_expression(aggregate.right)
46+
left = typed_expr.TypedExpr(
47+
scalar_compiler.compile_scalar_expression(aggregate.left),
48+
aggregate.left.output_type,
49+
)
50+
right = typed_expr.TypedExpr(
51+
scalar_compiler.compile_scalar_expression(aggregate.right),
52+
aggregate.right.output_type,
53+
)
4554
return compile_binary_agg(aggregate.op, left, right) # type: ignore
4655
else:
4756
raise ValueError(f"Unexpected aggregation: {aggregate}")
@@ -58,8 +67,8 @@ def compile_nullary_agg(
5867
@functools.singledispatch
5968
def compile_binary_agg(
6069
op: ops.aggregations.WindowOp,
61-
left: sge.Expression,
62-
right: sge.Expression,
70+
left: typed_expr.TypedExpr,
71+
right: typed_expr.TypedExpr,
6372
window: typing.Optional[window_spec.WindowSpec] = None,
6473
) -> sge.Expression:
6574
raise ValueError(f"Can't compile unrecognized operation: {op}")
@@ -68,7 +77,7 @@ def compile_binary_agg(
6877
@functools.singledispatch
6978
def compile_unary_agg(
7079
op: ops.aggregations.WindowOp,
71-
column: sge.Expression,
80+
column: typed_expr.TypedExpr,
7281
window: typing.Optional[window_spec.WindowSpec] = None,
7382
) -> sge.Expression:
7483
raise ValueError(f"Can't compile unrecognized operation: {op}")
@@ -77,7 +86,7 @@ def compile_unary_agg(
7786
@functools.singledispatch
7887
def compile_ordered_unary_agg(
7988
op: ops.aggregations.WindowOp,
80-
column: sge.Expression,
89+
column: typed_expr.TypedExpr,
8190
window: typing.Optional[window_spec.WindowSpec] = None,
8291
) -> sge.Expression:
8392
raise ValueError(f"Can't compile unrecognized operation: {op}")
@@ -87,12 +96,13 @@ def compile_ordered_unary_agg(
8796
@compile_unary_agg.register
8897
def _(
8998
op: ops.aggregations.SumOp,
90-
column: sge.Expression,
99+
column: typed_expr.TypedExpr,
91100
window: typing.Optional[window_spec.WindowSpec] = None,
92101
) -> sge.Expression:
93102
# Will be null if all inputs are null. Pandas defaults to zero sum though.
94-
expr = _apply_window_if_present(sge.func("SUM", column), window)
95-
return sge.func("IFNULL", expr, ir._literal(0, dtypes.INT_DTYPE))
103+
expr = _apply_window_if_present(sge.func("SUM", column.expr), window)
104+
# TODO (b/430350912): check `column.dtype` works for all?
105+
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))
96106

97107

98108
def _apply_window_if_present(
@@ -101,4 +111,4 @@ def _apply_window_if_present(
101111
) -> sge.Expression:
102112
if window is not None:
103113
raise NotImplementedError("Can't apply window to the expression.")
104-
return window
114+
return value

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import typing
1919

2020
from google.cloud import bigquery
21+
import sqlglot as sg
2122
import sqlglot.expressions as sge
2223

2324
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
@@ -218,7 +219,7 @@ def compile_filter(
218219
self, node: nodes.FilterNode, child: ir.SQLGlotIR
219220
) -> ir.SQLGlotIR:
220221
condition = scalar_compiler.compile_scalar_expression(node.predicate)
221-
return child.filter(condition)
222+
return child.filter(tuple([condition]))
222223

223224
@_compile_node.register
224225
def compile_join(
@@ -293,12 +294,14 @@ def compile_aggregate(
293294
)
294295

295296
result = child.aggregate(aggregations, by_cols)
296-
# TODO(chelsealin): Support dropna
297-
# TODO: Remove dropna field and use filter node instead
298-
# if node.dropna:
299-
# for key in node.by_column_ids:
300-
# if node.child.field_by_id[key.id].nullable:
301-
# result = result.filter(operations.notnull_op.as_expr(key))
297+
if node.dropna:
298+
conditions = []
299+
for key, by_col in zip(node.by_column_ids, by_cols):
300+
if node.child.field_by_id[key.id].nullable:
301+
conditions.append(
302+
sg.not_(sge.Is(this=by_col, expression=sge.Null()))
303+
)
304+
result = result.filter(tuple(conditions))
302305
return result
303306

304307

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import dataclasses
18+
import functools
1819
import typing
1920

2021
from google.cloud import bigquery
@@ -278,9 +279,13 @@ def limit(
278279

279280
def filter(
280281
self,
281-
condition: sge.Expression,
282+
conditions: tuple[sge.Expression, ...],
282283
) -> SQLGlotIR:
283284
"""Filters the query by adding a WHERE clause."""
285+
condition = _and(conditions)
286+
if condition is None:
287+
return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen)
288+
284289
new_expr = _select_to_cte(
285290
self.expr,
286291
sge.to_identifier(
@@ -314,10 +319,11 @@ def join(
314319
right_ctes = right_select.args.pop("with", [])
315320
merged_ctes = [*left_ctes, *right_ctes]
316321

317-
join_conditions = [
318-
_join_condition(left, right, joins_nulls) for left, right in conditions
319-
]
320-
join_on = sge.And(expressions=join_conditions) if join_conditions else None
322+
join_on = _and(
323+
tuple(
324+
_join_condition(left, right, joins_nulls) for left, right in conditions
325+
)
326+
)
321327

322328
join_type_str = join_type if join_type != "outer" else "full outer"
323329
new_expr = (
@@ -582,6 +588,16 @@ def _table(table: bigquery.TableReference) -> sge.Table:
582588
)
583589

584590

591+
def _and(conditions: tuple[sge.Expression, ...]) -> typing.Optional[sge.Expression]:
592+
"""Chains multiple expressions together using a logical AND."""
593+
if not conditions:
594+
return None
595+
596+
return functools.reduce(
597+
lambda left, right: sge.And(this=left, expression=right), conditions
598+
)
599+
600+
585601
def _join_condition(
586602
left: typed_expr.TypedExpr,
587603
right: typed_expr.TypedExpr,

bigframes/core/rewrite/schema_binding.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
import typing
1617

1718
from bigframes.core import bigframe_node
1819
from bigframes.core import expression as ex
@@ -65,4 +66,49 @@ def bind_schema_to_node(
6566
conditions=conditions,
6667
)
6768

69+
if isinstance(node, nodes.AggregateNode):
70+
aggregations = []
71+
for aggregation, id in node.aggregations:
72+
if isinstance(aggregation, ex.UnaryAggregation):
73+
replaced = typing.cast(
74+
ex.Aggregation,
75+
dataclasses.replace(
76+
aggregation,
77+
arg=typing.cast(
78+
ex.RefOrConstant,
79+
ex.bind_schema_fields(
80+
aggregation.arg, node.child.field_by_id
81+
),
82+
),
83+
),
84+
)
85+
aggregations.append((replaced, id))
86+
elif isinstance(aggregation, ex.BinaryAggregation):
87+
replaced = typing.cast(
88+
ex.Aggregation,
89+
dataclasses.replace(
90+
aggregation,
91+
left=typing.cast(
92+
ex.RefOrConstant,
93+
ex.bind_schema_fields(
94+
aggregation.left, node.child.field_by_id
95+
),
96+
),
97+
right=typing.cast(
98+
ex.RefOrConstant,
99+
ex.bind_schema_fields(
100+
aggregation.right, node.child.field_by_id
101+
),
102+
),
103+
),
104+
)
105+
aggregations.append((replaced, id))
106+
else:
107+
aggregations.append((aggregation, id))
108+
109+
return dataclasses.replace(
110+
node,
111+
aggregations=tuple(aggregations),
112+
)
113+
68114
return node

tests/unit/core/compile/sqlglot/snapshots/test_compile_aggregate/test_compile_aggregate/out.sql

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@ WITH `bfcte_0` AS (
1616
FROM `bfcte_1`
1717
GROUP BY
1818
`bfcol_3`
19+
), `bfcte_3` AS (
20+
SELECT
21+
*
22+
FROM `bfcte_2`
23+
WHERE
24+
NOT `bfcol_3` IS NULL
1925
)
2026
SELECT
2127
`bfcol_3` AS `bool_col`,
2228
`bfcol_6` AS `int64_too`
23-
FROM `bfcte_2`
29+
FROM `bfcte_3`
2430
ORDER BY
2531
`bfcol_3` ASC NULLS LAST
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_too` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
`bfcol_1` AS `bfcol_2`,
10+
`bfcol_0` AS `bfcol_3`
11+
FROM `bfcte_0`
12+
), `bfcte_2` AS (
13+
SELECT
14+
`bfcol_3`,
15+
COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6`
16+
FROM `bfcte_1`
17+
GROUP BY
18+
`bfcol_3`
19+
)
20+
SELECT
21+
`bfcol_3` AS `bool_col`,
22+
`bfcol_6` AS `int64_too`
23+
FROM `bfcte_2`
24+
ORDER BY
25+
`bfcol_3` ASC NULLS LAST

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ WITH `bfcte_1` AS (
2323
*
2424
FROM `bfcte_2`
2525
LEFT JOIN `bfcte_3`
26-
ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0)
27-
AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1)
26+
ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0)
27+
AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1)
2828
)
2929
SELECT
3030
`bfcol_3` AS `int64_col`,

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ WITH `bfcte_1` AS (
2323
*
2424
FROM `bfcte_2`
2525
INNER JOIN `bfcte_3`
26-
ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0')
27-
AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1')
26+
ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0')
27+
AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1')
2828
)
2929
SELECT
3030
`bfcol_2` AS `rowindex_x`,

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ WITH `bfcte_1` AS (
2323
*
2424
FROM `bfcte_2`
2525
INNER JOIN `bfcte_3`
26-
ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0))
27-
AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1))
26+
ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0))
27+
AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1))
2828
)
2929
SELECT
3030
`bfcol_2` AS `rowindex_x`,

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ WITH `bfcte_1` AS (
2323
*
2424
FROM `bfcte_2`
2525
INNER JOIN `bfcte_3`
26-
ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0)
27-
AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1)
26+
ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0)
27+
AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1)
2828
)
2929
SELECT
3030
`bfcol_2` AS `rowindex_x`,

0 commit comments

Comments
 (0)