from collections import namedtuple
from collections.abc import Sequence
from functools import partial, reduce
+from itertools import product
from numbers import Integral
from typing import (
    TYPE_CHECKING,
import numpy_groupies as npg
import pandas as pd
import toolz as tlz
+from scipy.sparse import csc_array

from . import xrdtypes
from .aggregate_flox import _prepare_for_flox
@@ -203,6 +205,16 @@ def _unique(a: np.ndarray) -> np.ndarray:
    return np.sort(pd.unique(a.reshape(-1)))


+def slices_from_chunks(chunks):
+    """slightly modified from dask.array.core.slices_from_chunks to be lazy"""
+    cumdims = [tlz.accumulate(operator.add, bds, 0) for bds in chunks]
+    slices = (
+        (slice(s, s + dim) for s, dim in zip(starts, shapes))
+        for starts, shapes in zip(cumdims, chunks)
+    )
+    return product(*slices)
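+# Illustrative example (hypothetical chunks, not from the module): slices_from_chunks(((2, 2), (3,)))
+# lazily yields one indexer tuple per block: (slice(0, 2), slice(0, 3)), then (slice(2, 4), slice(0, 3)).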
+
+
@memoize
def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
    """
@@ -215,9 +227,10 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
    Parameters
    ----------
    labels : np.ndarray
-        mD Array of group labels
+        mD Array of integer group codes, factorized so that -1
+        represents NaNs.
    chunks : tuple
-        nD array that is being reduced
+        chunks of the array being reduced
    merge : bool, optional
        Attempt to merge cohorts when one cohort's chunks are a subset
        of another cohort's chunks.
@@ -227,33 +240,59 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
    cohorts: dict_values
        Iterable of cohorts
    """
-    import dask
-
    # To do this, we must have values in memory so casting to numpy should be safe
    labels = np.asarray(labels)

-    # Build an array with the shape of labels, but where every element is the "chunk number"
-    # 1. First subset the array appropriately
-    axis = range(-labels.ndim, 0)
-    # Easier to create a dask array and use the .blocks property
-    array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)
-    labels = np.broadcast_to(labels, array.shape[-labels.ndim :])
-
-    # Iterate over each block and create a new block of same shape with "chunk number"
-    shape = tuple(array.blocks.shape[ax] for ax in axis)
-    # Use a numpy object array to enable assignment in the loop
-    # TODO: is it possible to just use a nested list?
-    # That is what we need for `np.block`
-    blocks = np.empty(shape, dtype=object)
-    array_chunks = tuple(np.array(c) for c in array.chunks)
-    for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
-        chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
-        blocks[blockindex] = np.full(chunkshape, idx)
-    which_chunk = np.block(blocks.tolist()).reshape(-1)
-
-    raveled = labels.reshape(-1)
-    # these are chunks where a label is present
-    label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
+    shape = tuple(sum(c) for c in chunks)
+    nchunks = math.prod(len(c) for c in chunks)
+
+    # assumes that `labels` are factorized
+    nlabels = labels.max() + 1
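+    # e.g. pd.factorize(np.array(["a", np.nan, "b"], dtype=object))[0] gives array([0, -1, 1]),
+    # so codes run 0..nlabels-1 and -1 marks missing values (illustrative example).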
+
+    labels = np.broadcast_to(labels, shape[-labels.ndim :])
+
+    rows = []
+    cols = []
+    # Add one to handle the -1 sentinel value
+    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
+    ilabels = np.arange(nlabels)
+    for idx, region in enumerate(slices_from_chunks(chunks)):
+        # This is a quite fast way to find unique integers, when we know how many there are
+        # inspired by a similar idea in numpy_groupies for first, last
+        # instead of explicitly finding uniques, repeatedly write True to the same location
+        subset = labels[region]
+        # The reshape is not strictly necessary but is about 100ms faster on a test problem.
+        label_is_present[subset.reshape(-1)] = True
+        # skip the -1 sentinel by slicing
+        uniques = ilabels[label_is_present[:-1]]
+        rows.append([idx] * len(uniques))
+        cols.append(uniques)
+        label_is_present[:] = False
+    rows_array = np.concatenate(rows)
+    cols_array = np.concatenate(cols)
+    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
+    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
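+    # In CSC layout, column `lab` stores its nonzero row indices (here, the chunks that
+    # contain that label) in bitmask.indices[bitmask.indptr[lab] : bitmask.indptr[lab + 1]].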
+    label_chunks = {
+        lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
+        for lab in range(nlabels)
+    }
+
+    ## numpy bitmask approach, faster than finding uniques, but lots of memory
+    # bitmask = np.zeros((nchunks, nlabels), dtype=bool)
+    # for idx, region in enumerate(slices_from_chunks(chunks)):
+    #     bitmask[idx, labels[region]] = True
+    # bitmask = bitmask[:, :-1]
+    # chunk = np.arange(nchunks)  # [:, np.newaxis] * bitmask
+    # label_chunks = {lab: chunk[bitmask[:, lab]] for lab in range(nlabels - 1)}
+
+    ## Pandas GroupBy approach, quite slow!
+    # which_chunk = np.empty(shape, dtype=np.int64)
+    # for idx, region in enumerate(slices_from_chunks(chunks)):
+    #     which_chunk[region] = idx
+    # which_chunk = which_chunk.reshape(-1)
+    # raveled = labels.reshape(-1)
+    # # these are chunks where a label is present
+    # label_chunks = pd.Series(which_chunk).groupby(raveled).unique()

    # These invert the label_chunks mapping so we know which labels occur together.
    def invert(x) -> tuple[np.ndarray, ...]:
@@ -264,33 +303,31 @@ def invert(x) -> tuple[np.ndarray, ...]:

    # If our dataset has chunksize one along the axis,
    # then no merging is possible.
-    single_chunks = all((ac == 1).all() for ac in array_chunks)
+    single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
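+    # e.g. chunks=((1, 1, 1, 1),) -> every block holds a single element along the
+    # reduced axis (illustrative example).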

-    if merge and not single_chunks:
+    if not single_chunks and merge:
        # First sort by number of chunks occupied by cohort
        sorted_chunks_cohorts = dict(
            sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
        )

-        items = tuple(sorted_chunks_cohorts.items())
+        items = tuple((k, set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
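+        # `set(k)` is computed once per cohort here so the subset checks below do not
+        # rebuild it on every comparison; falsy (empty) keys are dropped.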

        merged_cohorts = {}
-        merged_keys = []
+        merged_keys = set()

        # Now we iterate starting with the longest number of chunks,
        # and then merge in cohorts that are present in a subset of those chunks
        # I think this is suboptimal and must fail at some point.
        # But it might work for most cases. There must be a better way...
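+        # e.g. a cohort occupying chunks (0, 1, 2) absorbs one seen only in chunks (0, 2),
+        # since the latter's chunks are a subset of the former's (illustrative example).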
-        for idx, (k1, v1) in enumerate(items):
+        for idx, (k1, set_k1, v1) in enumerate(items):
            if k1 in merged_keys:
                continue
            merged_cohorts[k1] = copy.deepcopy(v1)
-            for k2, v2 in items[idx + 1 :]:
-                if k2 in merged_keys:
-                    continue
-                if set(k2).issubset(set(k1)):
+            for k2, set_k2, v2 in items[idx + 1 :]:
+                if k2 not in merged_keys and set_k2.issubset(set_k1):
                    merged_cohorts[k1].extend(v2)
-                    merged_keys.append(k2)
+                    merged_keys.update((k2,))

        # make sure each cohort is sorted after merging
        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
@@ -1373,7 +1410,6 @@ def dask_groupby_agg(

    inds = tuple(range(array.ndim))
    name = f"groupby_{agg.name}"
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis)

    if expected_groups is None and reindex:
        expected_groups = _get_expected_groups(by, sort=sort)
@@ -1394,6 +1430,9 @@ def dask_groupby_agg(
        by = dask.array.from_array(by, chunks=chunks)
    _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

+    # tokenize here since `by` has already been hashed if it's numpy
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis)
+
    # preprocess the array:
    # - for argreductions, this zips the index together with the array block
    # - not necessary for blockwise with argreductions
@@ -1510,7 +1549,7 @@ def dask_groupby_agg(
                index = pd.Index(cohort)
                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
                reindexed = dask.array.map_blocks(
-                    reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+                    reindex_intermediates, subset, agg, index, meta=subset._meta
                )
                # now that we have reindexed, we can set reindex=True explicitly
                reduced_.append(