@@ -176,6 +176,7 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
     axis = range(-labels.ndim, 0)
     # Easier to create a dask array and use the .blocks property
     array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+    labels = np.broadcast_to(labels, array.shape[-labels.ndim :])
 
     # Iterate over each block and create a new block of same shape with "chunk number"
     shape = tuple(array.blocks.shape[ax] for ax in axis)
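Editor's note: the `np.broadcast_to` added here is what lets a grouping variable with fewer dimensions than the array participate in cohort detection. A minimal sketch of the NumPy behavior being relied on (shapes are illustrative, not from the source):

```python
import numpy as np

# labels describe only the last axis; the array has an extra leading dim.
labels = np.array([0, 0, 1, 1, 2])  # shape (5,)
array_shape = (3, 5)

# Zero-copy, read-only view whose trailing dims line up with the array.
expanded = np.broadcast_to(labels, array_shape[-labels.ndim :])
print(expanded.shape)  # (5,)

# Broadcasting against more trailing dims prepends axes as needed:
print(np.broadcast_to(labels, (3, 5)).shape)  # (3, 5)
```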
@@ -479,7 +480,7 @@ def factorize_(
        idx, groups = pd.factorize(flat, sort=sort)
 
        found_groups.append(np.array(groups))
-       factorized.append(idx)
+       factorized.append(idx.reshape(groupvar.shape))
 
    grp_shape = tuple(len(grp) for grp in found_groups)
    ngroups = math.prod(grp_shape)
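Editor's note: this change keeps each set of integer codes in the shape of its grouping variable rather than flat. A small sketch of why the reshape is needed: `pd.factorize` only accepts 1D input, so the codes come back flattened.

```python
import numpy as np
import pandas as pd

groupvar = np.array([["a", "b"], ["b", "a"]])
flat = groupvar.reshape(-1)

# pd.factorize returns 1D integer codes plus the unique values.
idx, groups = pd.factorize(flat, sort=True)
print(idx)                          # [0 1 1 0]
print(groups)                       # ['a' 'b']
print(idx.reshape(groupvar.shape))  # [[0 1]
                                    #  [1 0]]
```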
@@ -489,20 +490,18 @@ def factorize_(
        # Restore these after the raveling
        nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
        group_idx[nan_by_mask] = -1
-       group_idx = group_idx.reshape(by[0].shape)
    else:
        group_idx = factorized[0]
 
    if fastpath:
-       return group_idx.reshape(by[0].shape), found_groups, grp_shape
+       return group_idx, found_groups, grp_shape
 
    if np.isscalar(axis) and groupvar.ndim > 1:
        # Not reducing along all dimensions of by
        # this is OK because for 3D by and axis=(1,2),
        # we collapse to a 2D by and axis=-1
        offset_group = True
        group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups)
-       group_idx = group_idx.reshape(-1)
    else:
        size = ngroups
        offset_group = False
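Editor's note: with multiple `by` arrays, the per-variable codes above are collapsed into a single `group_idx` while preserving the -1 sentinel for NaNs and out-of-bin values. A sketch of that masking, assuming the collapse uses `np.ravel_multi_index(..., mode="wrap")` as the "Restore these after the raveling" comment suggests:

```python
from functools import reduce

import numpy as np

codes_a = np.array([0, 1, -1, 1])  # codes for the first `by`
codes_b = np.array([1, 0, 1, -1])  # codes for the second `by`
grp_shape = (2, 2)

# mode="wrap" maps -1 to a valid index, so -1s must be masked back in.
group_idx = np.ravel_multi_index((codes_a, codes_b), grp_shape, mode="wrap")
nan_by_mask = reduce(np.logical_or, [(f == -1) for f in (codes_a, codes_b)])
group_idx[nan_by_mask] = -1
print(group_idx)  # [ 1  2 -1 -1]
```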
@@ -647,6 +646,8 @@ def chunk_reduce(
    else:
        nax = by.ndim
 
+   assert by.ndim <= array.ndim
+
    final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
    final_groups_shape = (1,) * (nax - 1)
 
@@ -667,9 +668,17 @@ def chunk_reduce(
    )
    groups = groups[0]
 
+   if isinstance(axis, Sequence):
+       needs_broadcast = any(
+           group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
+           for ax in range(-len(axis), 0)
+       )
+       if needs_broadcast:
+           group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
    # always reshape to 1D along group dimensions
    newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
    array = array.reshape(newshape)
+   group_idx = group_idx.reshape(-1)
 
    assert group_idx.ndim == 1
    empty = np.all(props.nanmask)
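Editor's note: the block added here handles a `by` that was broadcast against the array. Its factorized codes can carry size-1 dimensions that must be expanded before everything is raveled to 1D, or codes and array elements would fall out of alignment. A toy illustration with made-up shapes:

```python
import numpy as np

array = np.arange(12).reshape(3, 4)
group_idx = np.array([[0, 0, 1, 1]])  # size-1 leading dim from a broadcast `by`

# Expand first, then ravel, so element i of group_idx labels element i of array.
group_idx = np.broadcast_to(group_idx, array.shape[-2:])
array = array.reshape(-1)
group_idx = group_idx.reshape(-1)
print(array.shape, group_idx.shape)  # (12,) (12,)
```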
@@ -1219,7 +1228,9 @@ def dask_groupby_agg(
        # chunk numpy arrays like the input array
        # This removes an extra rechunk-merge layer that would be
        # added otherwise
-       by = dask.array.from_array(by, chunks=tuple(array.chunks[ax] for ax in range(-by.ndim, 0)))
+       chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
+
+       by = dask.array.from_array(by, chunks=chunks)
    _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])
 
    # preprocess the array: for argreductions, this zips the index together with the array block
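Editor's note: a numpy `by` with a size-1 (broadcast) dimension now gets a single length-1 chunk on that axis instead of inheriting the array's chunking there, so `unify_chunks` broadcasts that dimension chunk-wise rather than splitting it. A sketch with illustrative shapes:

```python
import dask.array
import numpy as np

array = dask.array.ones((4, 6), chunks=(2, 3))
by = np.zeros((1, 6))  # broadcast dimension of size 1

chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
by = dask.array.from_array(by, chunks=chunks)
print(by.chunks)  # ((1,), (3, 3))
```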
@@ -1429,7 +1440,7 @@ def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
 
 
 def _validate_reindex(
-    reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
+    reindex: bool | None, func, method: T_Method, expected_groups, any_by_dask: bool
 ) -> bool:
     if reindex is True:
         if _is_arg_reduction(func):
@@ -1447,7 +1458,7 @@ def _validate_reindex(
            reindex = False
 
        elif method == "map-reduce":
-           if expected_groups is None and by_is_dask:
+           if expected_groups is None and any_by_dask:
                reindex = False
            else:
                reindex = True
@@ -1457,8 +1468,9 @@ def _validate_reindex(
 
 
 def _assert_by_is_aligned(shape, by):
+    assert all(b.ndim == by[0].ndim for b in by[1:])
     for idx, b in enumerate(by):
-        if shape[-b.ndim :] != b.shape:
+        if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
             raise ValueError(
                 "`array` and `by` arrays must be aligned "
                 "i.e. array.shape[-by.ndim :] == by.shape. "
@@ -1495,26 +1507,34 @@ def _lazy_factorize_wrapper(*by, **kwargs):
    return group_idx
 
 
-def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
+def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
    kwargs = dict(
        expected_groups=expected_groups,
        axis=None,  # always None, we offset later if necessary.
        fastpath=True,
        reindex=reindex,
    )
-   if by_is_dask:
+   if any_by_dask:
        import dask.array
 
+       # unifying chunks will make sure all arrays in `by` are dask arrays
+       # with compatible chunks, even if there was originally a numpy array
+       inds = tuple(range(by[0].ndim))
+       chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
+
        group_idx = dask.array.map_blocks(
            _lazy_factorize_wrapper,
-           *np.broadcast_arrays(*by),
+           *by_,
+           chunks=tuple(chunks.values()),
            meta=np.array((), dtype=np.int64),
            **kwargs,
        )
        found_groups = tuple(
            None if is_duck_dask_array(b) else pd.unique(b.reshape(-1)) for b in by
        )
-       grp_shape = tuple(len(e) for e in expected_groups)
+       grp_shape = tuple(
+           len(e) if e is not None else len(f) for e, f in zip(expected_groups, found_groups)
+       )
    else:
        group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
 
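Editor's note: the `unify_chunks` call interleaves each `by` array with the same index tuple; this is the mechanism, noted in the new comment, by which numpy members of `by` are coerced to dask arrays chunked like the dask members. A sketch of that calling convention:

```python
import itertools

import dask.array
import numpy as np

by = (dask.array.from_array(np.zeros(6), chunks=3), np.ones(6))  # one dask, one numpy
inds = tuple(range(by[0].ndim))  # (0,)

# Equivalent to dask.array.unify_chunks(by[0], (0,), by[1], (0,))
chunkss, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
print([b.chunks for b in by_])  # [((3, 3),), ((3, 3),)]
print(tuple(chunkss.values()))  # ((3, 3),)
```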
@@ -1644,15 +1664,16 @@ def groupby_reduce(
 
    bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
    nby = len(bys)
-   by_is_dask = any(is_duck_dask_array(b) for b in bys)
+   by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
+   any_by_dask = any(by_is_dask)
 
-   if method in ["split-reduce", "cohorts"] and by_is_dask:
+   if method in ["split-reduce", "cohorts"] and any_by_dask:
        raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
 
    if method == "split-reduce":
        method = "cohorts"
 
-   reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)
+   reindex = _validate_reindex(reindex, func, method, expected_groups, any_by_dask)
 
    if not is_duck_array(array):
        array = np.asarray(array)
@@ -1667,6 +1688,11 @@ def groupby_reduce(
        expected_groups = (None,) * nby
 
    _assert_by_is_aligned(array.shape, bys)
+   for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
+       if is_dask and (reindex or nby > 1) and expect is None:
+           raise ValueError(
+               f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
+           )
 
    if nby == 1 and not isinstance(expected_groups, tuple):
        expected_groups = (np.asarray(expected_groups),)
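Editor's note: the new loop makes the requirement explicit: a dask `by` has unknown labels, so when reindexing (or grouping by multiple variables) the caller must supply `expected_groups` up front. A hedged usage sketch, assuming flox's top-level `groupby_reduce` API:

```python
import dask.array
import flox
import numpy as np

array = dask.array.ones((10,), chunks=5)
by = dask.array.from_array(np.tile([0, 1], 5), chunks=5)  # dask grouper, labels unknown

# flox.groupby_reduce(array, by, func="sum")  # would raise under this check
result, groups = flox.groupby_reduce(array, by, func="sum", expected_groups=[0, 1])
```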
@@ -1686,7 +1712,7 @@ def groupby_reduce(
    )
    if factorize_early:
        bys, final_groups, grp_shape = _factorize_multiple(
-           bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
+           bys, expected_groups, any_by_dask=any_by_dask, reindex=reindex
        )
        expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
 
@@ -1709,7 +1735,7 @@ def groupby_reduce(
 
    # TODO: make sure expected_groups is unique
    if nax == 1 and by_.ndim > 1 and expected_groups is None:
-       if not by_is_dask:
+       if not any_by_dask:
            expected_groups = _get_expected_groups(by_, sort)
        else:
            # When we reduce along all axes, we are guaranteed to see all