
Commit 616eccf

fix various problems, migrate rank to new api
1 parent 05497c6 commit 616eccf

File tree

5 files changed: +67 −92 lines changed

bigframes/core/agg_expressions.py

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ def transform_children(
         t: Callable[[expression.Expression], expression.Expression],
     ) -> WindowExpression:
         return WindowExpression(
-            self.analytic_expr.transform_children(t),
+            t(self.analytic_expr),  # type: ignore
             self.window.transform_exprs(t),
         )

bigframes/core/block_transforms.py

Lines changed: 55 additions & 81 deletions
@@ -431,16 +431,11 @@ def rank(
 
     columns = columns or tuple(col for col in block.value_columns)
     labels = [block.col_id_to_label[id] for id in columns]
-    # Step 1: Calculate row numbers for each row
-    # Identify null values to be treated according to na_option param
-    rownum_col_ids = []
-    nullity_col_ids = []
+
+    result_exprs = []
     for col in columns:
-        block, nullity_col_id = block.apply_unary_op(
-            col,
-            ops.isnull_op,
-        )
-        nullity_col_ids.append(nullity_col_id)
+        # Step 1: Calculate row numbers for each row
+        # Identify null values to be treated according to na_option param
         window_ordering = (
             ordering.OrderingExpression(
                 ex.deref(col),
@@ -451,87 +446,66 @@ def rank(
             ),
         )
         # Count_op ignores nulls, so if na_option is "top" or "bottom", we instead count the nullity columns, where nulls have been mapped to bools
-        block, rownum_id = block.apply_window_op(
-            col if na_option == "keep" else nullity_col_id,
-            agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op,
-            window_spec=windows.unbound(
-                grouping_keys=grouping_cols, ordering=window_ordering
-            )
+        target_expr = (
+            ex.deref(col) if na_option == "keep" else ops.isnull_op.as_expr(col)
+        )
+        window_op = agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op
+        window_spec = (
+            windows.unbound(grouping_keys=grouping_cols, ordering=window_ordering)
             if method == "dense"
             else windows.rows(
                 end=0, ordering=window_ordering, grouping_keys=grouping_cols
-            ),
-            skip_reproject_unsafe=(col != columns[-1]),
+            )
+        )
+        result_expr: ex.Expression = agg_expressions.WindowExpression(
+            agg_expressions.UnaryAggregation(window_op, target_expr), window_spec
         )
         if pct:
-            block, max_id = block.apply_window_op(
-                rownum_id, agg_ops.max_op, windows.unbound(grouping_keys=grouping_cols)
+            result_expr = ops.div_op.as_expr(
+                result_expr,
+                agg_expressions.WindowExpression(
+                    agg_expressions.UnaryAggregation(agg_ops.max_op, result_expr),
+                    windows.unbound(grouping_keys=grouping_cols),
+                ),
             )
-            block, rownum_id = block.project_expr(ops.div_op.as_expr(rownum_id, max_id))
-
-        rownum_col_ids.append(rownum_id)
-
-    # Step 2: Apply aggregate to groups of like input values.
-    # This step is skipped for method=='first' or 'dense'
-    if method in ["average", "min", "max"]:
-        agg_op = {
-            "average": agg_ops.mean_op,
-            "min": agg_ops.min_op,
-            "max": agg_ops.max_op,
-        }[method]
-        post_agg_rownum_col_ids = []
-        for i in range(len(columns)):
-            block, result_id = block.apply_window_op(
-                rownum_col_ids[i],
-                agg_op,
-                window_spec=windows.unbound(grouping_keys=(columns[i], *grouping_cols)),
-                skip_reproject_unsafe=(i < (len(columns) - 1)),
+        # Step 2: Apply aggregate to groups of like input values.
+        # This step is skipped for method=='first' or 'dense'
+        if method in ["average", "min", "max"]:
+            agg_op = {
+                "average": agg_ops.mean_op,
+                "min": agg_ops.min_op,
+                "max": agg_ops.max_op,
+            }[method]
+            result_expr = agg_expressions.WindowExpression(
+                agg_expressions.UnaryAggregation(agg_op, result_expr),
+                windows.unbound(grouping_keys=(col, *grouping_cols)),
             )
-            post_agg_rownum_col_ids.append(result_id)
-        rownum_col_ids = post_agg_rownum_col_ids
-
-    # Pandas masks all values where any grouping column is null
-    # Note: we use pd.NA instead of float('nan')
-    if grouping_cols:
-        predicate = functools.reduce(
-            ops.and_op.as_expr,
-            [ops.notnull_op.as_expr(column_id) for column_id in grouping_cols],
-        )
-        block = block.project_exprs(
-            [
-                ops.where_op.as_expr(
-                    ex.deref(col),
-                    predicate,
-                    ex.const(None),
-                )
-                for col in rownum_col_ids
-            ],
-            labels=labels,
-        )
-        rownum_col_ids = list(block.value_columns[-len(rownum_col_ids) :])
-
-    # Step 3: post processing: mask null values and cast to float
-    if method in ["min", "max", "first", "dense"]:
-        # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
-        return (
-            block.select_columns(rownum_col_ids)
-            .multi_apply_unary_op(ops.AsTypeOp(pd.Float64Dtype()))
-            .with_column_labels(labels)
-        )
-    if na_option == "keep":
-        # For na_option "keep", null inputs must produce null outputs
-        exprs = []
-        for i in range(len(columns)):
-            exprs.append(
-                ops.where_op.as_expr(
-                    ex.const(pd.NA, dtype=pd.Float64Dtype()),
-                    nullity_col_ids[i],
-                    rownum_col_ids[i],
-                )
+        # Pandas masks all values where any grouping column is null
+        # Note: we use pd.NA instead of float('nan')
+        if grouping_cols:
+            predicate = functools.reduce(
+                ops.and_op.as_expr,
+                [ops.notnull_op.as_expr(column_id) for column_id in grouping_cols],
+            )
+            result_expr = ops.where_op.as_expr(
+                result_expr,
+                predicate,
+                ex.const(None),
             )
-        return block.project_exprs(exprs, labels=labels, drop=True)
 
-    return block.select_columns(rownum_col_ids).with_column_labels(labels)
+        # Step 3: post processing: mask null values and cast to float
+        if method in ["min", "max", "first", "dense"]:
+            # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
+            result_expr = ops.AsTypeOp(pd.Float64Dtype()).as_expr(result_expr)
+        elif na_option == "keep":
+            # For na_option "keep", null inputs must produce null outputs
+            result_expr = ops.where_op.as_expr(
+                ex.const(pd.NA, dtype=pd.Float64Dtype()),
+                ops.isnull_op.as_expr(col),
+                result_expr,
+            )
+        result_exprs.append(result_expr)
+    return block.project_block_exprs(result_exprs, labels=labels, drop=True)
 
 
 def dropna(
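
For reference, the migrated rank() above still targets ordinary pandas rank() semantics; the Step 1-3 comments map onto the method, na_option, and pct options, now composed as window expressions rather than chained block operations. A minimal pandas-only sketch of that target behavior (standard pandas output, not code from this commit):

import pandas as pd

s = pd.Series([3, 1, None, 3, 7], dtype="Float64")

s.rank(method="average")                  # [2.5, 1.0, <NA>, 2.5, 4.0]  ties share the mean rank
s.rank(method="min")                      # [2.0, 1.0, <NA>, 2.0, 4.0]  ties share the lowest rank
s.rank(method="dense")                    # [2.0, 1.0, <NA>, 2.0, 3.0]  no gaps after ties
s.rank(method="first")                    # [2.0, 1.0, <NA>, 3.0, 4.0]  ties broken by row order
s.rank(method="min", na_option="bottom")  # [2.0, 1.0, 5.0, 2.0, 4.0]   nulls ranked last
s.rank(method="min", pct=True)            # [0.5, 0.25, <NA>, 0.5, 1.0] rank / count of valid rows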

bigframes/core/blocks.py

Lines changed: 1 addition & 0 deletions
@@ -1165,6 +1165,7 @@ def project_block_exprs(
         if drop:
             new_array = new_array.drop_columns(self.value_columns)
 
+        new_array.node.validate_tree()
         return Block(
             new_array,
             index_columns=self.index_columns,

bigframes/core/expression_factoring.py

Lines changed: 4 additions & 10 deletions
@@ -135,11 +135,10 @@ def push_into_tree(
         for child_id in expr.expr.column_references
         if child_id in by_id.keys()
     )
-    # be careful about merging multi-parent ids
     # TODO: Also prevent inlining expensive or non-deterministic
+    # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
     multi_parent_ids = set(id for id in graph.nodes if len(graph.parents(id)) > 2)
     scalar_ids = set(expr.name for expr in exprs if expr.expr.is_scalar_expr)
-    post_ids = (*root.ids, *target_ids)
 
     def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
         results: dict[identifiers.ColumnId, expression.Expression] = dict()
@@ -168,11 +167,8 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
                 id: by_id[id].expr.bind_refs(results, allow_partial_bindings=True)
             }
             results.update(new_exprs)
-        return tuple(
-            NamedExpression(expr, id)
-            for id, expr in results.items()
-            if id in set([*graph.sinks, *target_ids])
-        )
+        # TODO: We can prune expressions that won't be reused here,
+        return tuple(NamedExpression(expr, id) for id, expr in results.items())
 
     def graph_extract_window_expr() -> Optional[
         Tuple[identifiers.ColumnId, agg_expressions.WindowExpression]
@@ -193,19 +189,17 @@ def graph_extract_window_expr() -> Optional[
     curr_root = nodes.ProjectionNode(
         curr_root, tuple((x.expr, x.name) for x in scalar_exprs)
     )
-    curr_root._validate()
     while result := graph_extract_window_expr():
         id, window_expr = result
         curr_root = nodes.WindowOpNode(
             curr_root, window_expr.analytic_expr, window_expr.window, output_name=id
         )
-        curr_root._validate()
     # TODO: Try to get the ordering right earlier, so can avoid this extra node.
+    post_ids = (*root.ids, *target_ids)
     if tuple(curr_root.ids) != post_ids:
         curr_root = nodes.SelectionNode(
             curr_root, tuple(nodes.AliasedRef.identity(id) for id in post_ids)
         )
-        curr_root._validate()
     return curr_root
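
The comment updated in push_into_tree above concerns duplicated work: an expression referenced by more than one consumer ("multi-parent") should stay a named projection rather than be inlined into each consumer. A toy, self-contained illustration of the concern (illustrative only, not bigframes code; the names below are made up):

# Naive textual inlining duplicates a shared definition in every consumer.
defs = {"c": "expensive(a)", "d": "c + 1", "e": "c * 2"}  # "c" has two parents: d and e

def inline(expr: str, sub: dict[str, str]) -> str:
    # Single-pass textual substitution, enough for this toy example.
    for name, body in sub.items():
        expr = expr.replace(name, f"({body})")
    return expr

print(inline(defs["d"], {"c": defs["c"]}))  # (expensive(a)) + 1
print(inline(defs["e"], {"c": defs["c"]}))  # (expensive(a)) * 2
# "expensive(a)" now appears in two places in the compiled text; keeping "c" as its
# own named column avoids recomputing it and keeps the generated text smaller.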

bigframes/core/nodes.py

Lines changed: 6 additions & 0 deletions
@@ -1199,6 +1199,7 @@ def _validate(self):
         for expression, _ in self.assignments:
             # throws TypeError if invalid
             _ = ex.bind_schema_fields(expression, self.child.field_by_id).output_type
+            assert expression.is_scalar_expr
         # Cannot assign to existing variables - append only!
         assert all(name not in self.child.schema.names for _, name in self.assignments)
 
@@ -1404,6 +1405,11 @@ def _validate(self):
             not self.window_spec.is_row_bounded
         ) or self.expression.op.implicitly_inherits_order
         assert all(ref in self.child.ids for ref in self.expression.column_references)
+        assert self.added_field.dtype is not None
+        for agg_child in self.expression.children:
+            assert agg_child.is_scalar_expr
+        for window_expr in self.window_spec.expressions:
+            assert window_expr.is_scalar_expr
 
     @property
     def non_local(self) -> bool:
