@@ -129,12 +129,12 @@ def quantile(
129129 window_spec = window ,
130130 )
131131 quantile_cols .append (quantile_col )
132- block , _ = block .aggregate (
133- grouping_column_ids ,
132+ block = block .aggregate (
134133 tuple (
135134 agg_expressions .UnaryAggregation (agg_ops .AnyValueOp (), ex .deref (col ))
136135 for col in quantile_cols
137136 ),
137+ grouping_column_ids ,
138138 column_labels = pd .Index (labels ),
139139 dropna = dropna ,
140140 )
@@ -358,12 +358,12 @@ def value_counts(
358358 if grouping_keys and drop_na :
359359 # only need this if grouping_keys is involved, otherwise the drop_na in the aggregation will handle it for us
360360 block = dropna (block , columns , how = "any" )
361- block , agg_ids = block .aggregate (
362- by_column_ids = (* grouping_keys , * columns ),
361+ block = block .aggregate (
363362 aggregations = [agg_expressions .NullaryAggregation (agg_ops .size_op )],
363+ by_column_ids = (* grouping_keys , * columns ),
364364 dropna = drop_na and not grouping_keys ,
365365 )
366- count_id = agg_ids [0 ]
366+ count_id = block . value_columns [0 ]
367367 if normalize :
368368 unbound_window = windows .unbound (grouping_keys = tuple (grouping_keys ))
369369 block , total_count_id = block .apply_window_op (
@@ -621,40 +621,28 @@ def skew(
621621 original_columns = skew_column_ids
622622 column_labels = block .select_columns (original_columns ).column_labels
623623
624- block , delta3_ids = _mean_delta_to_power (
625- block , 3 , original_columns , grouping_column_ids
626- )
627624 # counts, moment3 for each column
628625 aggregations = []
629- for i , col in enumerate (original_columns ):
626+ for col in original_columns :
627+ delta3_expr = _mean_delta_to_power (3 , col )
630628 count_agg = agg_expressions .UnaryAggregation (
631629 agg_ops .count_op ,
632630 ex .deref (col ),
633631 )
634632 moment3_agg = agg_expressions .UnaryAggregation (
635633 agg_ops .mean_op ,
636- ex . deref ( delta3_ids [ i ]) ,
634+ delta3_expr ,
637635 )
638636 variance_agg = agg_expressions .UnaryAggregation (
639637 agg_ops .PopVarOp (),
640638 ex .deref (col ),
641639 )
642- aggregations .extend ([count_agg , moment3_agg , variance_agg ])
640+ skew_expr = _skew_from_moments_and_count (count_agg , moment3_agg , variance_agg )
641+ aggregations .append (skew_expr )
643642
644- block , agg_ids = block .aggregate (
645- by_column_ids = grouping_column_ids , aggregations = aggregations
643+ block = block .aggregate (
644+ aggregations , grouping_column_ids , column_labels = column_labels
646645 )
647-
648- skew_ids = []
649- for i , col in enumerate (original_columns ):
650- # Corresponds to order of aggregations in preceding loop
651- count_id , moment3_id , var_id = agg_ids [i * 3 : (i * 3 ) + 3 ]
652- block , skew_id = _skew_from_moments_and_count (
653- block , count_id , moment3_id , var_id
654- )
655- skew_ids .append (skew_id )
656-
657- block = block .select_columns (skew_ids ).with_column_labels (column_labels )
658646 if not grouping_column_ids :
659647 # When ungrouped, transpose result row into a series
660648 # perform transpose last, so as to not invalidate cache
@@ -671,36 +659,23 @@ def kurt(
671659) -> blocks .Block :
672660 original_columns = skew_column_ids
673661 column_labels = block .select_columns (original_columns ).column_labels
674-
675- block , delta4_ids = _mean_delta_to_power (
676- block , 4 , original_columns , grouping_column_ids
677- )
678662 # counts, moment4 for each column
679- aggregations = []
680- for i , col in enumerate (original_columns ):
663+ kurt_exprs = []
664+ for col in original_columns :
665+ delta_4_expr = _mean_delta_to_power (4 , col )
681666 count_agg = agg_expressions .UnaryAggregation (agg_ops .count_op , ex .deref (col ))
682- moment4_agg = agg_expressions .UnaryAggregation (
683- agg_ops .mean_op , ex .deref (delta4_ids [i ])
684- )
667+ moment4_agg = agg_expressions .UnaryAggregation (agg_ops .mean_op , delta_4_expr )
685668 variance_agg = agg_expressions .UnaryAggregation (
686669 agg_ops .PopVarOp (), ex .deref (col )
687670 )
688- aggregations .extend ([count_agg , moment4_agg , variance_agg ])
689-
690- block , agg_ids = block .aggregate (
691- by_column_ids = grouping_column_ids , aggregations = aggregations
692- )
693671
694- kurt_ids = []
695- for i , col in enumerate (original_columns ):
696672 # Corresponds to order of aggregations in preceding loop
697- count_id , moment4_id , var_id = agg_ids [i * 3 : (i * 3 ) + 3 ]
698- block , kurt_id = _kurt_from_moments_and_count (
699- block , count_id , moment4_id , var_id
700- )
701- kurt_ids .append (kurt_id )
673+ kurt_expr = _kurt_from_moments_and_count (count_agg , moment4_agg , variance_agg )
674+ kurt_exprs .append (kurt_expr )
702675
703- block = block .select_columns (kurt_ids ).with_column_labels (column_labels )
676+ block = block .aggregate (
677+ kurt_exprs , grouping_column_ids , column_labels = column_labels
678+ )
704679 if not grouping_column_ids :
705680 # When ungrouped, transpose result row into a series
706681 # perform transpose last, so as to not invalidate cache
@@ -711,38 +686,30 @@ def kurt(
711686
712687
713688def _mean_delta_to_power (
714- block : blocks .Block ,
715689 n_power : int ,
716- column_ids : typing .Sequence [str ],
717- grouping_column_ids : typing .Sequence [str ],
718- ) -> typing .Tuple [blocks .Block , typing .Sequence [str ]]:
690+ val_id : str ,
691+ ) -> ex .Expression :
719692 """Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
720- window = windows .unbound (grouping_keys = tuple (grouping_column_ids ))
721- block , mean_ids = block .multi_apply_window_op (column_ids , agg_ops .mean_op , window )
722- delta_ids = []
723- for val_id , mean_val_id in zip (column_ids , mean_ids ):
724- delta = ops .sub_op .as_expr (val_id , mean_val_id )
725- delta_power = ops .pow_op .as_expr (delta , ex .const (n_power ))
726- block , delta_power_id = block .project_expr (delta_power )
727- delta_ids .append (delta_power_id )
728- return block , delta_ids
693+ mean_expr = agg_expressions .UnaryAggregation (agg_ops .mean_op , ex .deref (val_id ))
694+ delta = ops .sub_op .as_expr (val_id , mean_expr )
695+ return ops .pow_op .as_expr (delta , ex .const (n_power ))
729696
730697
731698def _skew_from_moments_and_count (
732- block : blocks . Block , count_id : str , moment3_id : str , moment2_id : str
733- ) -> typing . Tuple [ blocks . Block , str ] :
699+ count : ex . Expression , moment3 : ex . Expression , moment2 : ex . Expression
700+ ) -> ex . Expression :
734701 # Calculate skew using count, third moment and population variance
735702 # See G1 estimator:
736703 # https://en.wikipedia.org/wiki/Skewness#Sample_skewness
737704 moments_estimator = ops .div_op .as_expr (
738- moment3_id , ops .pow_op .as_expr (moment2_id , ex .const (3 / 2 ))
705+ moment3 , ops .pow_op .as_expr (moment2 , ex .const (3 / 2 ))
739706 )
740707
741- countminus1 = ops .sub_op .as_expr (count_id , ex .const (1 ))
742- countminus2 = ops .sub_op .as_expr (count_id , ex .const (2 ))
708+ countminus1 = ops .sub_op .as_expr (count , ex .const (1 ))
709+ countminus2 = ops .sub_op .as_expr (count , ex .const (2 ))
743710 adjustment = ops .div_op .as_expr (
744711 ops .unsafe_pow_op .as_expr (
745- ops .mul_op .as_expr (count_id , countminus1 ), ex .const (1 / 2 )
712+ ops .mul_op .as_expr (count , countminus1 ), ex .const (1 / 2 )
746713 ),
747714 countminus2 ,
748715 )
@@ -751,14 +718,14 @@ def _skew_from_moments_and_count(
751718
752719 # Need to produce NA if have less than 3 data points
753720 cleaned_skew = ops .where_op .as_expr (
754- skew , ops .ge_op .as_expr (count_id , ex .const (3 )), ex .const (None )
721+ skew , ops .ge_op .as_expr (count , ex .const (3 )), ex .const (None )
755722 )
756- return block . project_expr ( cleaned_skew )
723+ return cleaned_skew
757724
758725
759726def _kurt_from_moments_and_count (
760- block : blocks . Block , count_id : str , moment4_id : str , moment2_id : str
761- ) -> typing . Tuple [ blocks . Block , str ] :
727+ count : ex . Expression , moment4 : ex . Expression , moment2 : ex . Expression
728+ ) -> ex . Expression :
762729 # Kurtosis is often defined as the second standardize moment: moment(4)/moment(2)**2
763730 # Pandas however uses Fisher’s estimator, implemented below
764731 # numerator = (count + 1) * (count - 1) * moment4
@@ -767,28 +734,26 @@ def _kurt_from_moments_and_count(
767734 # kurtosis = (numerator / denominator) - adjustment
768735
769736 numerator = ops .mul_op .as_expr (
770- moment4_id ,
737+ moment4 ,
771738 ops .mul_op .as_expr (
772- ops .sub_op .as_expr (count_id , ex .const (1 )),
773- ops .add_op .as_expr (count_id , ex .const (1 )),
739+ ops .sub_op .as_expr (count , ex .const (1 )),
740+ ops .add_op .as_expr (count , ex .const (1 )),
774741 ),
775742 )
776743
777744 # Denominator
778- countminus2 = ops .sub_op .as_expr (count_id , ex .const (2 ))
779- countminus3 = ops .sub_op .as_expr (count_id , ex .const (3 ))
745+ countminus2 = ops .sub_op .as_expr (count , ex .const (2 ))
746+ countminus3 = ops .sub_op .as_expr (count , ex .const (3 ))
780747
781748 # Denominator
782749 denominator = ops .mul_op .as_expr (
783- ops .unsafe_pow_op .as_expr (moment2_id , ex .const (2 )),
750+ ops .unsafe_pow_op .as_expr (moment2 , ex .const (2 )),
784751 ops .mul_op .as_expr (countminus2 , countminus3 ),
785752 )
786753
787754 # Adjustment
788755 adj_num = ops .mul_op .as_expr (
789- ops .unsafe_pow_op .as_expr (
790- ops .sub_op .as_expr (count_id , ex .const (1 )), ex .const (2 )
791- ),
756+ ops .unsafe_pow_op .as_expr (ops .sub_op .as_expr (count , ex .const (1 )), ex .const (2 )),
792757 ex .const (3 ),
793758 )
794759 adj_denom = ops .mul_op .as_expr (countminus2 , countminus3 )
@@ -799,9 +764,9 @@ def _kurt_from_moments_and_count(
799764
800765 # Need to produce NA if have less than 4 data points
801766 cleaned_kurt = ops .where_op .as_expr (
802- kurt , ops .ge_op .as_expr (count_id , ex .const (4 )), ex .const (None )
767+ kurt , ops .ge_op .as_expr (count , ex .const (4 )), ex .const (None )
803768 )
804- return block . project_expr ( cleaned_kurt )
769+ return cleaned_kurt
805770
806771
807772def align (