2525REFERENCE_ENGINE = polars_executor .PolarsExecutor ()
2626
2727
28+ def apply_agg_to_all_valid (
29+ array : array_value .ArrayValue , op : agg_ops .UnaryAggregateOp , excluded_cols = []
30+ ) -> array_value .ArrayValue :
31+ """
32+ Apply the aggregation to every column in the array that has a compatible datatype.
33+ """
34+ exprs_by_name = []
35+ for arg in array .column_ids :
36+ if arg in excluded_cols :
37+ continue
38+ try :
39+ _ = op .output_type (array .get_column_type (arg ))
40+ expr = expression .UnaryAggregation (op , expression .deref (arg ))
41+ name = f"{ arg } -{ op .name } "
42+ exprs_by_name .append ((expr , name ))
43+ except TypeError :
44+ continue
45+ assert len (exprs_by_name ) > 0
46+ new_arr = array .aggregate (exprs_by_name )
47+ return new_arr
48+
49+
2850@pytest .mark .parametrize ("engine" , ["polars" , "bq" ], indirect = True )
2951def test_engines_aggregate_size (
3052 scalars_array_value : array_value .ArrayValue ,
@@ -48,6 +70,20 @@ def test_engines_aggregate_size(
4870 assert_equivalence_execution (node , REFERENCE_ENGINE , engine )
4971
5072
73+ @pytest .mark .parametrize ("engine" , ["polars" , "bq" ], indirect = True )
74+ @pytest .mark .parametrize (
75+ "op" ,
76+ [agg_ops .min_op , agg_ops .max_op , agg_ops .mean_op , agg_ops .sum_op , agg_ops .count_op ],
77+ )
78+ def test_engines_unary_aggregates (
79+ scalars_array_value : array_value .ArrayValue ,
80+ engine ,
81+ op ,
82+ ):
83+ node = apply_agg_to_all_valid (scalars_array_value , op ).node
84+ assert_equivalence_execution (node , REFERENCE_ENGINE , engine )
85+
86+
5187@pytest .mark .parametrize ("engine" , ["polars" , "bq" ], indirect = True )
5288@pytest .mark .parametrize (
5389 "grouping_cols" ,
0 commit comments