@@ -431,16 +431,11 @@ def rank(
431431
432432 columns = columns or tuple (col for col in block .value_columns )
433433 labels = [block .col_id_to_label [id ] for id in columns ]
434- # Step 1: Calculate row numbers for each row
435- # Identify null values to be treated according to na_option param
436- rownum_col_ids = []
437- nullity_col_ids = []
434+
435+ result_exprs = []
438436 for col in columns :
439- block , nullity_col_id = block .apply_unary_op (
440- col ,
441- ops .isnull_op ,
442- )
443- nullity_col_ids .append (nullity_col_id )
437+ # Step 1: Calculate row numbers for each row
438+ # Identify null values to be treated according to na_option param
444439 window_ordering = (
445440 ordering .OrderingExpression (
446441 ex .deref (col ),
@@ -451,87 +446,66 @@ def rank(
451446 ),
452447 )
453448 # Count_op ignores nulls, so if na_option is "top" or "bottom", we instead count the nullity columns, where nulls have been mapped to bools
454- block , rownum_id = block . apply_window_op (
455- col if na_option == "keep" else nullity_col_id ,
456- agg_ops . dense_rank_op if method == "dense" else agg_ops . count_op ,
457- window_spec = windows . unbound (
458- grouping_keys = grouping_cols , ordering = window_ordering
459- )
449+ target_expr = (
450+ ex . deref ( col ) if na_option == "keep" else ops . isnull_op . as_expr ( col )
451+ )
452+ window_op = agg_ops . dense_rank_op if method == "dense" else agg_ops . count_op
453+ window_spec = (
454+ windows . unbound ( grouping_keys = grouping_cols , ordering = window_ordering )
460455 if method == "dense"
461456 else windows .rows (
462457 end = 0 , ordering = window_ordering , grouping_keys = grouping_cols
463- ),
464- skip_reproject_unsafe = (col != columns [- 1 ]),
458+ )
459+ )
460+ result_expr : ex .Expression = agg_expressions .WindowExpression (
461+ agg_expressions .UnaryAggregation (window_op , target_expr ), window_spec
465462 )
466463 if pct :
467- block , max_id = block .apply_window_op (
468- rownum_id , agg_ops .max_op , windows .unbound (grouping_keys = grouping_cols )
464+ result_expr = ops .div_op .as_expr (
465+ result_expr ,
466+ agg_expressions .WindowExpression (
467+ agg_expressions .UnaryAggregation (agg_ops .max_op , result_expr ),
468+ windows .unbound (grouping_keys = grouping_cols ),
469+ ),
469470 )
470- block , rownum_id = block .project_expr (ops .div_op .as_expr (rownum_id , max_id ))
471-
472- rownum_col_ids .append (rownum_id )
473-
474- # Step 2: Apply aggregate to groups of like input values.
475- # This step is skipped for method=='first' or 'dense'
476- if method in ["average" , "min" , "max" ]:
477- agg_op = {
478- "average" : agg_ops .mean_op ,
479- "min" : agg_ops .min_op ,
480- "max" : agg_ops .max_op ,
481- }[method ]
482- post_agg_rownum_col_ids = []
483- for i in range (len (columns )):
484- block , result_id = block .apply_window_op (
485- rownum_col_ids [i ],
486- agg_op ,
487- window_spec = windows .unbound (grouping_keys = (columns [i ], * grouping_cols )),
488- skip_reproject_unsafe = (i < (len (columns ) - 1 )),
471+ # Step 2: Apply aggregate to groups of like input values.
472+ # This step is skipped for method=='first' or 'dense'
473+ if method in ["average" , "min" , "max" ]:
474+ agg_op = {
475+ "average" : agg_ops .mean_op ,
476+ "min" : agg_ops .min_op ,
477+ "max" : agg_ops .max_op ,
478+ }[method ]
479+ result_expr = agg_expressions .WindowExpression (
480+ agg_expressions .UnaryAggregation (agg_op , result_expr ),
481+ windows .unbound (grouping_keys = (col , * grouping_cols )),
489482 )
490- post_agg_rownum_col_ids .append (result_id )
491- rownum_col_ids = post_agg_rownum_col_ids
492-
493- # Pandas masks all values where any grouping column is null
494- # Note: we use pd.NA instead of float('nan')
495- if grouping_cols :
496- predicate = functools .reduce (
497- ops .and_op .as_expr ,
498- [ops .notnull_op .as_expr (column_id ) for column_id in grouping_cols ],
499- )
500- block = block .project_exprs (
501- [
502- ops .where_op .as_expr (
503- ex .deref (col ),
504- predicate ,
505- ex .const (None ),
506- )
507- for col in rownum_col_ids
508- ],
509- labels = labels ,
510- )
511- rownum_col_ids = list (block .value_columns [- len (rownum_col_ids ) :])
512-
513- # Step 3: post processing: mask null values and cast to float
514- if method in ["min" , "max" , "first" , "dense" ]:
515- # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
516- return (
517- block .select_columns (rownum_col_ids )
518- .multi_apply_unary_op (ops .AsTypeOp (pd .Float64Dtype ()))
519- .with_column_labels (labels )
520- )
521- if na_option == "keep" :
522- # For na_option "keep", null inputs must produce null outputs
523- exprs = []
524- for i in range (len (columns )):
525- exprs .append (
526- ops .where_op .as_expr (
527- ex .const (pd .NA , dtype = pd .Float64Dtype ()),
528- nullity_col_ids [i ],
529- rownum_col_ids [i ],
530- )
483+ # Pandas masks all values where any grouping column is null
484+ # Note: we use pd.NA instead of float('nan')
485+ if grouping_cols :
486+ predicate = functools .reduce (
487+ ops .and_op .as_expr ,
488+ [ops .notnull_op .as_expr (column_id ) for column_id in grouping_cols ],
489+ )
490+ result_expr = ops .where_op .as_expr (
491+ result_expr ,
492+ predicate ,
493+ ex .const (None ),
531494 )
532- return block .project_exprs (exprs , labels = labels , drop = True )
533495
534- return block .select_columns (rownum_col_ids ).with_column_labels (labels )
496+ # Step 3: post processing: mask null values and cast to float
497+ if method in ["min" , "max" , "first" , "dense" ]:
498+ # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
499+ result_expr = ops .AsTypeOp (pd .Float64Dtype ()).as_expr (result_expr )
500+ elif na_option == "keep" :
501+ # For na_option "keep", null inputs must produce null outputs
502+ result_expr = ops .where_op .as_expr (
503+ ex .const (pd .NA , dtype = pd .Float64Dtype ()),
504+ ops .isnull_op .as_expr (col ),
505+ result_expr ,
506+ )
507+ result_exprs .append (result_expr )
508+ return block .project_block_exprs (result_exprs , labels = labels , drop = True )
535509
536510
537511def dropna (
0 commit comments