@@ -1598,7 +1598,11 @@ def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
15981598
15991599
16001600def _factorize_multiple (
1601- by : T_Bys , expected_groups : T_ExpectIndexTuple , any_by_dask : bool , reindex : bool
1601+ by : T_Bys ,
1602+ expected_groups : T_ExpectIndexTuple ,
1603+ any_by_dask : bool ,
1604+ reindex : bool ,
1605+ sort : bool = True ,
16021606) -> tuple [tuple [np .ndarray ], tuple [np .ndarray , ...], tuple [int , ...]]:
16031607 if any_by_dask :
16041608 import dask .array
@@ -1617,6 +1621,7 @@ def _factorize_multiple(
16171621 expected_groups = expected_groups ,
16181622 fastpath = True ,
16191623 reindex = reindex ,
1624+ sort = sort ,
16201625 )
16211626
16221627 fg , gs = [], []
@@ -1643,6 +1648,7 @@ def _factorize_multiple(
16431648 expected_groups = expected_groups ,
16441649 fastpath = True ,
16451650 reindex = reindex ,
1651+ sort = sort ,
16461652 )
16471653
16481654 return (group_idx ,), found_groups , grp_shape
@@ -1833,21 +1839,28 @@ def groupby_reduce(
18331839 # (pd.IntervalIndex or not)
18341840 expected_groups = _convert_expected_groups_to_index (expected_groups , isbins , sort )
18351841
1836- is_binning = any ([ isinstance ( e , pd . IntervalIndex ) for e in expected_groups ])
1837-
1838- # TODO: could restrict this to dask-only
1839- factorize_early = ( nby > 1 ) or (
1840- is_binning and method == "cohorts" and is_duck_dask_array ( array )
1842+ # Don't factorize "early only when
1843+ # grouping by dask arrays, and not having expected_groups
1844+ factorize_early = not (
1845+ # can't do it if we are grouping by dask array but don't have expected_groups
1846+ any ( is_dask and ex_ is None for is_dask , ex_ in zip ( by_is_dask , expected_groups ) )
18411847 )
18421848 if factorize_early :
18431849 bys , final_groups , grp_shape = _factorize_multiple (
1844- bys , expected_groups , any_by_dask = any_by_dask , reindex = reindex
1850+ bys ,
1851+ expected_groups ,
1852+ any_by_dask = any_by_dask ,
1853+ # This is the only way it makes sense I think.
1854+ # reindex controls what's actually allocated in chunk_reduce
1855+ # At this point, we care about an accurate conversion to codes.
1856+ reindex = True ,
1857+ sort = sort ,
18451858 )
18461859 expected_groups = (pd .RangeIndex (math .prod (grp_shape )),)
18471860
18481861 assert len (bys ) == 1
1849- by_ = bys [ 0 ]
1850- expected_groups = expected_groups [ 0 ]
1862+ ( by_ ,) = bys
1863+ ( expected_groups ,) = expected_groups
18511864
18521865 if axis is None :
18531866 axis_ = tuple (array .ndim + np .arange (- by_ .ndim , 0 ))
0 commit comments