@@ -1066,7 +1066,8 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
10661066@pytest .mark .parametrize ("func" , ALL_FUNCS )
10671067@pytest .mark .parametrize ("axis" , (- 1 , None ))
10681068@pytest .mark .parametrize ("method" , ["blockwise" , "cohorts" , "map-reduce" ])
1069- def test_cohorts_nd_by (func , method , axis , engine ):
1069+ @pytest .mark .parametrize ("by_is_dask" , [True , False ])
1070+ def test_cohorts_nd_by (by_is_dask , func , method , axis , engine ):
10701071 if (
10711072 ("arg" in func and (axis is None or engine in ["flox" , "numbagg" ]))
10721073 or (method != "blockwise" and func in BLOCKWISE_FUNCS )
@@ -1080,10 +1081,12 @@ def test_cohorts_nd_by(func, method, axis, engine):
10801081 o2 = dask .array .ones ((2 , 3 ), chunks = - 1 )
10811082
10821083 array = dask .array .block ([[o , 2 * o ], [3 * o2 , 4 * o2 ]])
1083- by = array .compute (). astype (np .int64 )
1084+ by = array .astype (np .int64 )
10841085 by [0 , 1 ] = 30
10851086 by [2 , 1 ] = 40
10861087 by [0 , 4 ] = 31
1088+ if not by_is_dask :
1089+ by = by .compute ()
10871090 array = np .broadcast_to (array , (2 , 3 ) + array .shape )
10881091
10891092 if func in ["any" , "all" ]:
@@ -1099,10 +1102,19 @@ def test_cohorts_nd_by(func, method, axis, engine):
10991102 assert_equal (groups , sorted_groups )
11001103 assert_equal (actual , expected )
11011104
1102- actual , groups = groupby_reduce (array , by , sort = False , ** kwargs )
1103- assert_equal (groups , np .array ([1 , 30 , 2 , 31 , 3 , 4 , 40 ], dtype = np .int64 ))
1104- reindexed = reindex_ (actual , groups , pd .Index (sorted_groups ))
1105- assert_equal (reindexed , expected )
1105+ if isinstance (by , dask .array .Array ):
1106+ cache .clear ()
1107+ actual_cohorts = find_group_cohorts (by , array .chunks [- by .ndim :])
1108+ expected_cohorts = find_group_cohorts (by .compute (), array .chunks [- by .ndim :])
1109+ assert actual_cohorts == expected_cohorts
1110+ # assert cache.nbytes
1111+
1112+ if not isinstance (by , dask .array .Array ):
1113+ # Groups are always sorted when `by` is a dask array (cohorts), so the sort=False ordering is only checked for a non-dask `by`
1114+ actual , groups = groupby_reduce (array , by , sort = False , ** kwargs )
1115+ assert_equal (groups , np .array ([1 , 30 , 2 , 31 , 3 , 4 , 40 ], dtype = np .int64 ))
1116+ reindexed = reindex_ (actual , groups , pd .Index (sorted_groups ))
1117+ assert_equal (reindexed , expected )
11061118
11071119
11081120@pytest .mark .parametrize ("func" , ["sum" , "count" ])
0 commit comments