@@ -208,15 +208,20 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
208208 # 1. First subset the array appropriately
209209 axis = range (- labels .ndim , 0 )
210210 # Easier to create a dask array and use the .blocks property
211- array = dask .array .ones (tuple (sum (c ) for c in chunks ), chunks = chunks )
211+ array = dask .array .empty (tuple (sum (c ) for c in chunks ), chunks = chunks )
212212 labels = np .broadcast_to (labels , array .shape [- labels .ndim :])
213213
214214 # Iterate over each block and create a new block of same shape with "chunk number"
215215 shape = tuple (array .blocks .shape [ax ] for ax in axis )
216- blocks = np .empty (math .prod (shape ), dtype = object )
217- for idx , block in enumerate (array .blocks .ravel ()):
218- blocks [idx ] = np .full (tuple (block .shape [ax ] for ax in axis ), idx )
219- which_chunk = np .block (blocks .reshape (shape ).tolist ()).reshape (- 1 )
216+ # Use a numpy object array to enable assignment in the loop
217+ # TODO: is it possible to just use a nested list?
218+ # That is what we need for `np.block`
219+ blocks = np .empty (shape , dtype = object )
220+ array_chunks = tuple (np .array (c ) for c in array .chunks )
221+ for idx , blockindex in enumerate (np .ndindex (array .numblocks )):
222+ chunkshape = tuple (c [i ] for c , i in zip (array_chunks , blockindex ))
223+ blocks [blockindex ] = np .full (chunkshape , idx )
224+ which_chunk = np .block (blocks .tolist ()).reshape (- 1 )
220225
221226 raveled = labels .reshape (- 1 )
222227 # these are chunks where a label is present
@@ -229,7 +234,11 @@ def invert(x) -> tuple[np.ndarray, ...]:
229234
230235 chunks_cohorts = tlz .groupby (invert , label_chunks .keys ())
231236
232- if merge :
237+ # If our dataset has chunksize one along the axis,
238+ # then no merging is possible.
239+ single_chunks = all ((ac == 1 ).all () for ac in array_chunks )
240+
241+ if merge and not single_chunks :
233242 # First sort by number of chunks occupied by cohort
234243 sorted_chunks_cohorts = dict (
235244 sorted (chunks_cohorts .items (), key = lambda kv : len (kv [0 ]), reverse = True )
0 commit comments