|
16 | 16 | from dataclasses import dataclass |
17 | 17 | import datetime |
18 | 18 | import functools |
19 | | -import itertools |
20 | 19 | import typing |
21 | 20 | from typing import Iterable, List, Mapping, Optional, Sequence, Tuple |
22 | 21 |
|
@@ -267,21 +266,96 @@ def compute_values(self, assignments: Sequence[ex.Expression]): |
267 | 266 | ) |
268 | 267 |
|
def compute_general_expression(self, assignments: Sequence[ex.Expression]):
    """
    Apply a batch of column expressions on top of the current plan.

    Every expression in ``assignments`` is paired with a freshly generated
    column id (``ids.ColumnId.unique()``) and wrapped in a ``nodes.ColumnDef``.
    The resulting definitions are handed to
    ``expression_factoring.apply_col_exprs_to_plan`` together with
    ``self.node``, and whatever plan that helper returns becomes the root of
    the returned value.

    Per the original author's description, the expressions are expected to
    preserve the length of the input columns and may be scalar operations or
    window functions.

    Args:
        assignments (Sequence[ex.Expression]): Expressions to evaluate, one
            output column per expression.

    Returns:
        A 2-tuple ``(array_value, column_ids)``:
            - ``array_value``: an ``ArrayValue`` wrapping the new plan root.
            - ``column_ids``: a tuple holding the ``id`` of each generated
              ``ColumnDef``, in the same order as ``assignments``.
              NOTE(review): these are the objects produced by
              ``ids.ColumnId.unique()``; confirm whether callers expect plain
              strings here.
    """
    column_defs = []
    for expression in assignments:
        column_defs.append(nodes.ColumnDef(expression, ids.ColumnId.unique()))

    # TODO: Push this to rewrite later to go from block expression to planning form
    updated_plan = expression_factoring.apply_col_exprs_to_plan(
        self.node, column_defs
    )

    output_ids = tuple(column_def.id for column_def in column_defs)
    return ArrayValue(updated_plan), output_ids
284 | 294 |
|
def compute_general_reduction(
    self,
    assignments: Sequence[ex.Expression],
    by_column_ids: typing.Sequence[str] = (),
    *,
    dropna: bool = False,
):
    """
    Reduce the current plan with aggregation expressions, optionally grouped.

    Two paths exist:

    * Direct path: when every assignment is an ``agg_expressions.Aggregation``
      whose children are all ``ex.DerefOp`` or ``ex.ScalarConstantExpression``
      instances, a single ``nodes.AggregateNode`` is built directly on
      ``self.node`` and ``dropna`` is forwarded to that node.
    * General path: otherwise, when ``dropna`` is set, one ``nodes.FilterNode``
      with a not-null predicate is stacked per grouping column, and the
      expressions are then passed to
      ``expression_factoring.apply_agg_exprs_to_plan``.

    Note (from the original author): intermediate aggregations, i.e. those
    feeding further aggregations, must be windowizable. Notably excluded are
    approx quantile and top count ops.

    Args:
        assignments (Sequence[ex.Expression]): Aggregation expressions to
            compute.
        by_column_ids (typing.Sequence[str], optional): Ids of the grouping
            columns. The default empty tuple means no grouping keys are
            passed on.
        dropna (bool, optional): Whether nulls in the grouping columns should
            be excluded; see the two paths above for how it is applied.
            Defaults to False.

    Returns:
        ArrayValue: Wraps the plan node produced by the reduction. Unlike
        ``compute_general_expression``, the generated column ids are not
        returned.
    """
    simple_input_types = (ex.DerefOp, ex.ScalarConstantExpression)

    # shortcircuit to keep things simple if all aggs are simple
    # TODO: Fully unify paths once rewriters are strong enough to simplify complexity from full path
    def _reads_plain_inputs(candidate):
        if not isinstance(candidate, agg_expressions.Aggregation):
            return False
        return all(isinstance(arg, simple_input_types) for arg in candidate.children)

    if all(map(_reads_plain_inputs, assignments)):
        direct_aggs = tuple(
            (aggregation, ids.ColumnId.unique()) for aggregation in assignments
        )
        aggregate_node = nodes.AggregateNode(
            child=self.node,
            aggregations=direct_aggs,  # type: ignore
            by_column_ids=tuple(ex.deref(key) for key in by_column_ids),
            dropna=dropna,
        )
        return ArrayValue(aggregate_node)

    # General path: drop rows with null grouping keys up front when requested.
    filtered_plan = self.node
    keys_to_filter = by_column_ids if dropna else ()
    for key in keys_to_filter:
        filtered_plan = nodes.FilterNode(
            filtered_plan, ops.notnull_op.as_expr(key)
        )

    column_defs = []
    for expression in assignments:
        column_defs.append(nodes.ColumnDef(expression, ids.ColumnId.unique()))

    # TODO: Push this to rewrite later to go from block expression to planning form
    grouping_keys = [ex.deref(key) for key in by_column_ids]
    reduced_plan = expression_factoring.apply_agg_exprs_to_plan(
        filtered_plan, column_defs, grouping_keys=grouping_keys
    )
    return ArrayValue(reduced_plan)
| 358 | + |
285 | 359 | def project_to_id(self, expression: ex.Expression): |
286 | 360 | array_val, ids = self.compute_values( |
287 | 361 | [expression], |
|
0 commit comments