
Commit 0eed1c5

Merge branch 'main' into add-py-314
2 parents 1e95ef5 + 4d5de14

File tree

14 files changed: +664 -272 lines changed

bigframes/core/array_value.py

Lines changed: 83 additions & 9 deletions
@@ -16,7 +16,6 @@
 from dataclasses import dataclass
 import datetime
 import functools
-import itertools
 import typing
 from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
 
@@ -267,21 +266,96 @@ def compute_values(self, assignments: Sequence[ex.Expression]):
         )
 
     def compute_general_expression(self, assignments: Sequence[ex.Expression]):
+        """
+        Applies arbitrary column expressions to the current execution block.
+
+        This method transforms the logical plan by applying a sequence of expressions that
+        preserve the length of the input columns. It supports both scalar operations
+        and window functions. Each expression is assigned a unique internal column identifier.
+
+        Args:
+            assignments (Sequence[ex.Expression]): A sequence of expression objects
+                representing the transformations to apply to the columns.
+
+        Returns:
+            Tuple[ArrayValue, Tuple[str, ...]]: A tuple containing:
+                - An `ArrayValue` wrapping the new root node of the updated logical plan.
+                - A tuple of strings representing the unique column IDs generated for
+                  each expression in the assignments.
+        """
         named_exprs = [
             nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
         ]
         # TODO: Push this to rewrite later to go from block expression to planning form
-        # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
-        fragments = tuple(
-            itertools.chain.from_iterable(
-                expression_factoring.fragmentize_expression(expr)
-                for expr in named_exprs
-            )
-        )
+        new_root = expression_factoring.apply_col_exprs_to_plan(self.node, named_exprs)
+
         target_ids = tuple(named_expr.id for named_expr in named_exprs)
-        new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids)
         return (ArrayValue(new_root), target_ids)
 
+    def compute_general_reduction(
+        self,
+        assignments: Sequence[ex.Expression],
+        by_column_ids: typing.Sequence[str] = (),
+        *,
+        dropna: bool = False,
+    ):
+        """
+        Applies arbitrary aggregation expressions to the block, optionally grouped by keys.
+
+        This method handles reduction operations (e.g., sum, mean, count) that collapse
+        multiple input rows into a single scalar value per group. If grouping keys are
+        provided, the operation is performed per group; otherwise, it is a global reduction.
+
+        Note: Intermediate aggregations (those that are inputs to further aggregations)
+        must be windowizable. Notably excluded are approx quantile, top count ops.
+
+        Args:
+            assignments (Sequence[ex.Expression]): A sequence of aggregation expressions
+                to be calculated.
+            by_column_ids (typing.Sequence[str], optional): A sequence of column IDs
+                to use as grouping keys. Defaults to an empty tuple (global reduction).
+            dropna (bool, optional): If True, rows containing null values in the
+                `by_column_ids` columns will be filtered out before the reduction
+                is applied. Defaults to False.
+
+        Returns:
+            ArrayValue:
+                The new root node representing the aggregation/group-by result.
+        """
+        plan = self.node
+
+        # shortcircuit to keep things simple if all aggs are simple
+        # TODO: Fully unify paths once rewriters are strong enough to simplify complexity from full path
+        def _is_direct_agg(agg_expr):
+            return isinstance(agg_expr, agg_expressions.Aggregation) and all(
+                isinstance(child, (ex.DerefOp, ex.ScalarConstantExpression))
+                for child in agg_expr.children
+            )
+
+        if all(_is_direct_agg(agg) for agg in assignments):
+            agg_defs = tuple((agg, ids.ColumnId.unique()) for agg in assignments)
+            return ArrayValue(
+                nodes.AggregateNode(
+                    child=self.node,
+                    aggregations=agg_defs,  # type: ignore
+                    by_column_ids=tuple(map(ex.deref, by_column_ids)),
+                    dropna=dropna,
+                )
+            )
+
+        if dropna:
+            for col_id in by_column_ids:
+                plan = nodes.FilterNode(plan, ops.notnull_op.as_expr(col_id))
+
+        named_exprs = [
+            nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
+        ]
+        # TODO: Push this to rewrite later to go from block expression to planning form
+        new_root = expression_factoring.apply_agg_exprs_to_plan(
+            plan, named_exprs, grouping_keys=[ex.deref(by) for by in by_column_ids]
+        )
+        return ArrayValue(new_root)
 
     def project_to_id(self, expression: ex.Expression):
         array_val, ids = self.compute_values(
             [expression],
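The two methods added above defer whole expression trees into the logical plan instead of materializing intermediate columns. A minimal usage sketch follows; the import paths and the pre-existing `ArrayValue` named `arr` (with columns "x" and "key") are illustrative assumptions, not part of this commit:

    # Sketch only: import paths below are assumptions about the bigframes layout.
    from bigframes.core import agg_expressions
    from bigframes.core import expression as ex
    import bigframes.operations as ops
    import bigframes.operations.aggregations as agg_ops

    # `arr` is assumed to be an existing ArrayValue with columns "x" and "key".

    # Length-preserving expression: x + 1. Returns the new ArrayValue and one
    # generated column ID per input expression.
    arr2, (plus_one_id,) = arr.compute_general_expression(
        [ops.add_op.as_expr("x", ex.const(1))]
    )

    # Composite reduction: mean(x) - min(x) per "key". The aggregations are
    # nested inside a scalar expression, so _is_direct_agg is False and the
    # apply_agg_exprs_to_plan path is taken.
    spread = ops.sub_op.as_expr(
        agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref("x")),
        agg_expressions.UnaryAggregation(agg_ops.min_op, ex.deref("x")),
    )
    arr3 = arr.compute_general_reduction([spread], by_column_ids=["key"], dropna=True)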

bigframes/core/block_transforms.py

Lines changed: 45 additions & 80 deletions
@@ -129,12 +129,12 @@ def quantile(
             window_spec=window,
         )
         quantile_cols.append(quantile_col)
-    block, _ = block.aggregate(
-        grouping_column_ids,
+    block = block.aggregate(
         tuple(
             agg_expressions.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
             for col in quantile_cols
         ),
+        grouping_column_ids,
         column_labels=pd.Index(labels),
         dropna=dropna,
     )
@@ -358,12 +358,12 @@ def value_counts(
     if grouping_keys and drop_na:
         # only need this if grouping_keys is involved, otherwise the drop_na in the aggregation will handle it for us
         block = dropna(block, columns, how="any")
-    block, agg_ids = block.aggregate(
-        by_column_ids=(*grouping_keys, *columns),
+    block = block.aggregate(
         aggregations=[agg_expressions.NullaryAggregation(agg_ops.size_op)],
+        by_column_ids=(*grouping_keys, *columns),
         dropna=drop_na and not grouping_keys,
     )
-    count_id = agg_ids[0]
+    count_id = block.value_columns[0]
     if normalize:
         unbound_window = windows.unbound(grouping_keys=tuple(grouping_keys))
         block, total_count_id = block.apply_window_op(
@@ -621,40 +621,28 @@ def skew(
     original_columns = skew_column_ids
     column_labels = block.select_columns(original_columns).column_labels
 
-    block, delta3_ids = _mean_delta_to_power(
-        block, 3, original_columns, grouping_column_ids
-    )
     # counts, moment3 for each column
     aggregations = []
-    for i, col in enumerate(original_columns):
+    for col in original_columns:
+        delta3_expr = _mean_delta_to_power(3, col)
         count_agg = agg_expressions.UnaryAggregation(
             agg_ops.count_op,
             ex.deref(col),
         )
         moment3_agg = agg_expressions.UnaryAggregation(
             agg_ops.mean_op,
-            ex.deref(delta3_ids[i]),
+            delta3_expr,
         )
         variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(),
             ex.deref(col),
         )
-        aggregations.extend([count_agg, moment3_agg, variance_agg])
+        skew_expr = _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
+        aggregations.append(skew_expr)
 
-    block, agg_ids = block.aggregate(
-        by_column_ids=grouping_column_ids, aggregations=aggregations
+    block = block.aggregate(
+        aggregations, grouping_column_ids, column_labels=column_labels
     )
-
-    skew_ids = []
-    for i, col in enumerate(original_columns):
-        # Corresponds to order of aggregations in preceding loop
-        count_id, moment3_id, var_id = agg_ids[i * 3 : (i * 3) + 3]
-        block, skew_id = _skew_from_moments_and_count(
-            block, count_id, moment3_id, var_id
-        )
-        skew_ids.append(skew_id)
-
-    block = block.select_columns(skew_ids).with_column_labels(column_labels)
     if not grouping_column_ids:
         # When ungrouped, transpose result row into a series
        # perform transpose last, so as to not invalidate cache
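An aside on the math: the skew path now folds `_mean_delta_to_power` and `_skew_from_moments_and_count` into a single deferred aggregation expression (see the helper rewrites further below). The G1 estimator that expression encodes can be checked against pandas with a small plain-Python sketch; this is review-side verification code, not part of the commit, and it assumes pandas is installed:

    import math

    import pandas as pd

    def g1_skew(xs):
        # G1 sample skewness, the same formula the expression tree encodes:
        # (m3 / m2**(3/2)) * sqrt(n * (n - 1)) / (n - 2), with NA below 3 points.
        n = len(xs)
        if n < 3:
            return None
        mean = sum(xs) / n
        m2 = sum((x - mean) ** 2 for x in xs) / n  # population variance (PopVarOp)
        m3 = sum((x - mean) ** 3 for x in xs) / n  # mean of _mean_delta_to_power(3, col)
        return (m3 / m2 ** (3 / 2)) * math.sqrt(n * (n - 1)) / (n - 2)

    xs = [1.0, 2.0, 4.0, 8.0]
    assert abs(g1_skew(xs) - pd.Series(xs).skew()) < 1e-9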
@@ -671,36 +659,23 @@ def kurt(
 ) -> blocks.Block:
     original_columns = skew_column_ids
     column_labels = block.select_columns(original_columns).column_labels
-
-    block, delta4_ids = _mean_delta_to_power(
-        block, 4, original_columns, grouping_column_ids
-    )
     # counts, moment4 for each column
-    aggregations = []
-    for i, col in enumerate(original_columns):
+    kurt_exprs = []
+    for col in original_columns:
+        delta_4_expr = _mean_delta_to_power(4, col)
         count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
-        moment4_agg = agg_expressions.UnaryAggregation(
-            agg_ops.mean_op, ex.deref(delta4_ids[i])
-        )
+        moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
         variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(), ex.deref(col)
         )
-        aggregations.extend([count_agg, moment4_agg, variance_agg])
-
-    block, agg_ids = block.aggregate(
-        by_column_ids=grouping_column_ids, aggregations=aggregations
-    )
 
-    kurt_ids = []
-    for i, col in enumerate(original_columns):
         # Corresponds to order of aggregations in preceding loop
-        count_id, moment4_id, var_id = agg_ids[i * 3 : (i * 3) + 3]
-        block, kurt_id = _kurt_from_moments_and_count(
-            block, count_id, moment4_id, var_id
-        )
-        kurt_ids.append(kurt_id)
+        kurt_expr = _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
+        kurt_exprs.append(kurt_expr)
 
-    block = block.select_columns(kurt_ids).with_column_labels(column_labels)
+    block = block.aggregate(
+        kurt_exprs, grouping_column_ids, column_labels=column_labels
+    )
     if not grouping_column_ids:
         # When ungrouped, transpose result row into a series
         # perform transpose last, so as to not invalidate cache
@@ -711,38 +686,30 @@ def kurt(
 
 
 def _mean_delta_to_power(
-    block: blocks.Block,
     n_power: int,
-    column_ids: typing.Sequence[str],
-    grouping_column_ids: typing.Sequence[str],
-) -> typing.Tuple[blocks.Block, typing.Sequence[str]]:
+    val_id: str,
+) -> ex.Expression:
     """Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
-    window = windows.unbound(grouping_keys=tuple(grouping_column_ids))
-    block, mean_ids = block.multi_apply_window_op(column_ids, agg_ops.mean_op, window)
-    delta_ids = []
-    for val_id, mean_val_id in zip(column_ids, mean_ids):
-        delta = ops.sub_op.as_expr(val_id, mean_val_id)
-        delta_power = ops.pow_op.as_expr(delta, ex.const(n_power))
-        block, delta_power_id = block.project_expr(delta_power)
-        delta_ids.append(delta_power_id)
-    return block, delta_ids
+    mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref(val_id))
+    delta = ops.sub_op.as_expr(val_id, mean_expr)
+    return ops.pow_op.as_expr(delta, ex.const(n_power))
 
 
 def _skew_from_moments_and_count(
-    block: blocks.Block, count_id: str, moment3_id: str, moment2_id: str
-) -> typing.Tuple[blocks.Block, str]:
+    count: ex.Expression, moment3: ex.Expression, moment2: ex.Expression
+) -> ex.Expression:
     # Calculate skew using count, third moment and population variance
     # See G1 estimator:
     # https://en.wikipedia.org/wiki/Skewness#Sample_skewness
     moments_estimator = ops.div_op.as_expr(
-        moment3_id, ops.pow_op.as_expr(moment2_id, ex.const(3 / 2))
+        moment3, ops.pow_op.as_expr(moment2, ex.const(3 / 2))
     )
 
-    countminus1 = ops.sub_op.as_expr(count_id, ex.const(1))
-    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
+    countminus1 = ops.sub_op.as_expr(count, ex.const(1))
+    countminus2 = ops.sub_op.as_expr(count, ex.const(2))
     adjustment = ops.div_op.as_expr(
         ops.unsafe_pow_op.as_expr(
-            ops.mul_op.as_expr(count_id, countminus1), ex.const(1 / 2)
+            ops.mul_op.as_expr(count, countminus1), ex.const(1 / 2)
         ),
         countminus2,
     )
@@ -751,14 +718,14 @@ def _skew_from_moments_and_count(
 
     # Need to produce NA if have less than 3 data points
     cleaned_skew = ops.where_op.as_expr(
-        skew, ops.ge_op.as_expr(count_id, ex.const(3)), ex.const(None)
+        skew, ops.ge_op.as_expr(count, ex.const(3)), ex.const(None)
    )
-    return block.project_expr(cleaned_skew)
+    return cleaned_skew
 
 
 def _kurt_from_moments_and_count(
-    block: blocks.Block, count_id: str, moment4_id: str, moment2_id: str
-) -> typing.Tuple[blocks.Block, str]:
+    count: ex.Expression, moment4: ex.Expression, moment2: ex.Expression
+) -> ex.Expression:
     # Kurtosis is often defined as the second standardized moment: moment(4)/moment(2)**2
     # Pandas however uses Fisher’s estimator, implemented below
     # numerator = (count + 1) * (count - 1) * moment4
@@ -767,28 +734,26 @@ def _kurt_from_moments_and_count(
     # kurtosis = (numerator / denominator) - adjustment
 
     numerator = ops.mul_op.as_expr(
-        moment4_id,
+        moment4,
         ops.mul_op.as_expr(
-            ops.sub_op.as_expr(count_id, ex.const(1)),
-            ops.add_op.as_expr(count_id, ex.const(1)),
+            ops.sub_op.as_expr(count, ex.const(1)),
+            ops.add_op.as_expr(count, ex.const(1)),
         ),
     )
 
     # Denominator
-    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
-    countminus3 = ops.sub_op.as_expr(count_id, ex.const(3))
+    countminus2 = ops.sub_op.as_expr(count, ex.const(2))
+    countminus3 = ops.sub_op.as_expr(count, ex.const(3))
 
     # Denominator
     denominator = ops.mul_op.as_expr(
-        ops.unsafe_pow_op.as_expr(moment2_id, ex.const(2)),
+        ops.unsafe_pow_op.as_expr(moment2, ex.const(2)),
         ops.mul_op.as_expr(countminus2, countminus3),
     )
 
     # Adjustment
     adj_num = ops.mul_op.as_expr(
-        ops.unsafe_pow_op.as_expr(
-            ops.sub_op.as_expr(count_id, ex.const(1)), ex.const(2)
-        ),
+        ops.unsafe_pow_op.as_expr(ops.sub_op.as_expr(count, ex.const(1)), ex.const(2)),
         ex.const(3),
     )
     adj_denom = ops.mul_op.as_expr(countminus2, countminus3)
@@ -799,9 +764,9 @@ def _kurt_from_moments_and_count(
 
     # Need to produce NA if have less than 4 data points
     cleaned_kurt = ops.where_op.as_expr(
-        kurt, ops.ge_op.as_expr(count_id, ex.const(4)), ex.const(None)
+        kurt, ops.ge_op.as_expr(count, ex.const(4)), ex.const(None)
     )
-    return block.project_expr(cleaned_kurt)
+    return cleaned_kurt
 
 
 def align(
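Likewise, the Fisher kurtosis formula spelled out in the comments and code of `_kurt_from_moments_and_count` can be sanity-checked against pandas in plain Python (again a review-side sketch assuming pandas, not commit code):

    import pandas as pd

    def fisher_kurt(xs):
        # Fisher's adjusted kurtosis, following the commented formula:
        #   numerator   = (n + 1) * (n - 1) * m4
        #   denominator = m2**2 * (n - 2) * (n - 3)
        #   adjustment  = 3 * (n - 1)**2 / ((n - 2) * (n - 3))
        # with NA below 4 data points.
        n = len(xs)
        if n < 4:
            return None
        mean = sum(xs) / n
        m2 = sum((x - mean) ** 2 for x in xs) / n  # population variance
        m4 = sum((x - mean) ** 4 for x in xs) / n  # fourth central moment
        numerator = (n + 1) * (n - 1) * m4
        denominator = m2**2 * (n - 2) * (n - 3)
        adjustment = 3 * (n - 1) ** 2 / ((n - 2) * (n - 3))
        return numerator / denominator - adjustment

    xs = [1.0, 2.0, 4.0, 8.0, 16.0]
    assert abs(fisher_kurt(xs) - pd.Series(xs).kurt()) < 1e-9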
