@@ -329,7 +329,7 @@ def chunk_unique(labels, slicer, nlabels, label_is_present=None):
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)
 
-    return make_bitmask(rows_array, cols_array)
+    return make_bitmask(rows_array, cols_array), nlabels, ilabels
 
 
 # @memoize
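
# A minimal sketch of the chunk-by-label bitmask idea touched by the hunk above:
# one boolean row per chunk, one column per group label, True where that label
# occurs in that chunk. This is an illustration only, not flox's make_bitmask /
# _compute_label_chunk_bitmask; the toy `labels` and `chunks` values are
# assumptions made for the example.
import numpy as np
from scipy.sparse import csr_matrix

labels = np.array([0, 0, 1, 1, 2, 2, 0, 3])  # group label for each array element
chunks = ((3, 3, 2),)                        # dask-style chunk sizes along the grouped axis

nlabels = labels.max() + 1
nchunks = len(chunks[0])

rows, cols = [], []
start = 0
for ichunk, size in enumerate(chunks[0]):
    present = np.unique(labels[start : start + size])  # labels present in this chunk
    rows.extend([ichunk] * len(present))
    cols.extend(present.tolist())
    start += size

bitmask = csr_matrix(
    (np.ones(len(rows), dtype=bool), (rows, cols)), shape=(nchunks, nlabels)
)
print(bitmask.toarray())
# [[ True  True False False]
#  [False  True  True False]
#  [ True False False  True]]
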
@@ -362,8 +362,17 @@ def find_group_cohorts(
     cohorts: dict_values
         Iterable of cohorts
     """
-    # To do this, we must have values in memory so casting to numpy should be safe
-    labels = np.asarray(labels)
+    if not is_duck_array(labels):
+        labels = np.asarray(labels)
+
+    if is_duck_dask_array(labels):
+        import dask
+
+        ((bitmask, nlabels, ilabels),) = dask.compute(
+            dask.delayed(_compute_label_chunk_bitmask)(labels, chunks)
+        )
+    else:
+        bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks)
 
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
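
# The hunk above wraps the bitmask computation in a single dask.delayed task when
# `labels` is itself a dask array, so the label values are materialized once,
# inside that task, instead of being pulled into memory eagerly via np.asarray.
# Below is a sketch of the same delayed/compute pattern using a stand-in function;
# `summarize` and its inputs are assumptions for illustration, not flox internals.
import dask
import dask.array as da
import numpy as np

def summarize(labels, chunks):
    # Stand-in for _compute_label_chunk_bitmask: dask passes in the computed
    # numpy values of any dask collections appearing in the arguments.
    labels = np.asarray(labels)
    return labels.max() + 1, len(chunks[0]), np.unique(labels)

labels = da.from_array(np.array([0, 0, 1, 1, 2, 2, 0, 3]), chunks=3)

# dask.compute returns one result per argument, hence the one-element unpacking.
((nlabels, nchunks, uniques),) = dask.compute(
    dask.delayed(summarize)(labels, labels.chunks)
)
print(nlabels, nchunks, uniques)  # 4 3 [0 1 2 3]
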
@@ -2409,9 +2418,6 @@ def groupby_reduce(
24092418 "Try engine='numpy' or engine='numba' instead."
24102419 )
24112420
2412- if method == "cohorts" and any_by_dask :
2413- raise ValueError (f"method={ method !r} can only be used when grouping by numpy arrays." )
2414-
24152421 reindex = _validate_reindex (
24162422 reindex , func , method , expected_groups , any_by_dask , is_duck_dask_array (array )
24172423 )
@@ -2443,6 +2449,12 @@ def groupby_reduce(
         # can't do it if we are grouping by dask array but don't have expected_groups
         any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
     )
+
+    if method == "cohorts" and not factorize_early:
+        raise ValueError(
+            "method='cohorts' can only be used when grouping by dask arrays if `expected_groups` is provided."
+        )
+
     expected_: pd.RangeIndex | None
     if factorize_early:
         bys, final_groups, grp_shape = _factorize_multiple(
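
# Taken together, the two groupby_reduce hunks above relax the old rule that
# method="cohorts" required a numpy `by`: a dask-backed `by` is now accepted as
# long as `expected_groups` is supplied, so the labels can be factorized early.
# A hedged usage sketch follows; the exact arguments are assumptions for
# illustration, not the authoritative flox API.
import dask.array as da
import numpy as np
from flox.core import groupby_reduce

array = da.random.random((12,), chunks=4)
by = da.from_array(np.tile([0, 1, 2], 4), chunks=4)  # dask-backed group labels

# With a dask `by` and no expected_groups, method="cohorts" would hit the new
# ValueError added above; providing expected_groups lets cohort detection run.
result, groups = groupby_reduce(
    array, by, func="sum", expected_groups=np.array([0, 1, 2]), method="cohorts"
)
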