diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index cd4fe2a6..39f6ae72 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -27,6 +27,7 @@ def assign_labels( :return: Tuple of class assignments, per-class spike proportions, and per-class firing rates. """ + n_neurons = spikes.size(2) if rates is None: @@ -34,30 +35,39 @@ def assign_labels( # Sum over time dimension (spike ordering doesn't matter). spikes = spikes.sum(1) - + for i in range(n_labels): + # Create mask. + mask = (labels == i) # Count the number of samples with this label. - n_labeled = torch.sum(labels == i).float() + n_labeled = mask.sum().float() if n_labeled > 0: # Get indices of samples with this label. - indices = torch.nonzero(labels == i).view(-1) - - # Compute average firing rates for this label. - selected_spikes = torch.index_select( - spikes, dim=0, index=torch.tensor(indices) - ) - rates[:, i] = alpha * rates[:, i] + ( - torch.sum(selected_spikes, 0) / n_labeled - ) + label_sum = spikes[mask].sum(0) + # Update rates. + rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled) # Compute proportions of spike activity per class. - proportions = rates / rates.sum(1, keepdim=True) - proportions[proportions != proportions] = 0 # Set NaNs to 0 + total_activity = rates.sum(1, keepdim=True) + proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates)) + # Noise for random tie breaking. + eps = 1e-6 # Small enough not to distort real decisions + noise = eps * torch.randn_like(proportions) + # Neuron assignments are the labels they fire most for. - assignments = torch.max(proportions, 1)[1] + assignments = torch.argmax(proportions + noise, dim=1) + + # Uniform assignment for silent neurons + silent_mask = total_activity.squeeze() == 0 + n_silent = silent_mask.sum() + if n_silent > 0: + assignments[silent_mask] = torch.randint( + 0, n_labels, (n_silent,), device=spikes.device + ) + return assignments, proportions, rates