@@ -403,7 +403,7 @@ def find_group_cohorts(
     # Invert the label_chunks mapping so we know which labels occur together.
     def invert(x) -> tuple[np.ndarray, ...]:
         arr = label_chunks[x]
-        return tuple(arr)
+        return tuple(arr.tolist())
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
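Presumably the motivation for `tuple(arr.tolist())` over `tuple(arr)`: `tlz.groupby` hashes every key, and a tuple of plain Python ints hashes considerably faster than a tuple of NumPy scalar objects. A minimal micro-benchmark sketch (not flox code, just an illustration of the assumption):

```python
import timeit

import numpy as np

# Minimal sketch: hashing a tuple of Python ints vs. a tuple of NumPy scalars.
arr = np.arange(50)
as_scalars = tuple(arr)        # tuple of np.int64 objects
as_ints = tuple(arr.tolist())  # tuple of plain Python ints

print(timeit.timeit(lambda: hash(as_scalars), number=100_000))
print(timeit.timeit(lambda: hash(as_ints), number=100_000))  # typically much faster
```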
@@ -477,22 +477,37 @@ def invert(x) -> tuple[np.ndarray, ...]:
             containment.nnz / math.prod(containment.shape)
         )
     )
-    # Use a threshold to force some merging. We do not use the filtered
-    # containment matrix for estimating "sparsity" because it is a bit
-    # hard to reason about.
+
+    # Next, we loop over groups and merge those that are quite similar.
+    # Use a threshold on containment to always force some merging.
+    # Note that we do not use the filtered containment matrix for estimating "sparsity"
+    # because it is a bit hard to reason about.
     MIN_CONTAINMENT = 0.75  # arbitrary
     mask = containment.data < MIN_CONTAINMENT
+
+    # Now we also know the "exact cohorts" -- cohorts whose constituent groups
+    # occur in exactly the same chunks. We need only examine one member of each,
+    # so loop over the exact cohorts first and zero out the rows of the others.
+    repeated = np.concatenate([v[1:] for v in chunks_cohorts.values()]).astype(int)
+    repeated_idx = np.searchsorted(present_labels, repeated)
+    for i in repeated_idx:
+        mask[containment.indptr[i] : containment.indptr[i + 1]] = True
     containment.data[mask] = 0
     containment.eliminate_zeros()
 
-    # Iterate over labels, beginning with those with most chunks
+    # Figure out all the labels we need to loop over later
+    n_overlapping_labels = containment.astype(bool).sum(axis=1)
+    order = np.argsort(n_overlapping_labels, kind="stable")[::-1]
+    # The order is such that we iterate over labels beginning with those with the most overlaps.
+    # Also filter out any "exact" cohorts; their rows were zeroed out above.
+    order = order[n_overlapping_labels[order] > 0]
+
     logger.debug("find_group_cohorts: merging cohorts")
-    order = np.argsort(containment.sum(axis=LABEL_AXIS), kind="stable")[::-1]
     merged_cohorts = {}
     merged_keys = set()
-    # TODO: we can optimize this to loop over chunk_cohorts instead
-    # by zeroing out rows that are already in a cohort
     for rowidx in order:
+        if present_labels[rowidx] in merged_keys:
+            continue
         cohidx = containment.indices[
             slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
         ]
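To make the row-zeroing step concrete, here is a toy sketch of the same CSR mechanics: mask below-threshold entries, then zero out the entire row of a repeated exact-cohort member via `indptr`. The matrix and the choice of row 1 are made up for illustration; only the slicing pattern mirrors the diff.

```python
import numpy as np
import scipy.sparse

# Toy containment matrix: rows 0 and 1 have the same chunk pattern,
# so row 1 plays the role of a repeated "exact cohort" member.
containment = scipy.sparse.csr_array(
    np.array(
        [
            [1.0, 0.8, 0.2],
            [0.8, 1.0, 0.2],
            [0.1, 0.1, 1.0],
        ]
    )
)
MIN_CONTAINMENT = 0.75
mask = containment.data < MIN_CONTAINMENT

# Zero out all stored entries of row 1 using the CSR indptr bounds.
for i in [1]:
    mask[containment.indptr[i] : containment.indptr[i + 1]] = True

containment.data[mask] = 0
containment.eliminate_zeros()
print(containment.toarray())
# [[1.  0.8 0. ]
#  [0.  0.  0. ]
#  [0.  0.  1. ]]
```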
@@ -507,6 +522,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
     expected_ngroups = present_labels.size
+    assert len(merged_keys) == actual_ngroups
     assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)
 
     # sort by first label in cohort
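For context, the order-by-overlap merge loop that these assertions check can be exercised end to end on a tiny example. All inputs below are hypothetical, but the sketch follows the same steps as the diff: count overlaps per row, iterate in descending order, and skip any label already absorbed via `merged_keys`.

```python
import numpy as np
import scipy.sparse

present_labels = np.array([10, 20, 30])
# Hypothetical post-threshold containment: labels 10 and 20 overlap; 30 is alone.
containment = scipy.sparse.csr_array(
    np.array([[1.0, 0.9, 0.0], [0.9, 1.0, 0.0], [0.0, 0.0, 1.0]])
)

n_overlapping_labels = containment.astype(bool).sum(axis=1)
order = np.argsort(n_overlapping_labels, kind="stable")[::-1]
order = order[n_overlapping_labels[order] > 0]

merged_cohorts = {}
merged_keys = set()
for rowidx in order:
    if present_labels[rowidx] in merged_keys:
        continue  # this label was already absorbed into an earlier cohort
    cohidx = containment.indices[
        containment.indptr[rowidx] : containment.indptr[rowidx + 1]
    ]
    cohort = [lab for lab in present_labels[cohidx] if lab not in merged_keys]
    merged_keys.update(cohort)
    merged_cohorts[present_labels[rowidx]] = cohort

print(merged_cohorts)  # label 20 absorbs 10; label 30 stays alone
assert len(merged_keys) == np.concatenate(tuple(merged_cohorts.values())).size
```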