Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions bindsnet/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,47 @@ def assign_labels(
:return: Tuple of class assignments, per-class spike proportions, and per-class
firing rates.
"""

n_neurons = spikes.size(2)

if rates is None:
rates = torch.zeros((n_neurons, n_labels), device=spikes.device)

# Sum over time dimension (spike ordering doesn't matter).
spikes = spikes.sum(1)

for i in range(n_labels):
# Create mask.
mask = (labels == i)
# Count the number of samples with this label.
n_labeled = torch.sum(labels == i).float()
n_labeled = mask.sum().float()

if n_labeled > 0:
# Get indices of samples with this label.
indices = torch.nonzero(labels == i).view(-1)

# Compute average firing rates for this label.
selected_spikes = torch.index_select(
spikes, dim=0, index=torch.tensor(indices)
)
rates[:, i] = alpha * rates[:, i] + (
torch.sum(selected_spikes, 0) / n_labeled
)
label_sum = spikes[mask].sum(0)
# Update rates.
rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled)

# Compute proportions of spike activity per class.
proportions = rates / rates.sum(1, keepdim=True)
proportions[proportions != proportions] = 0 # Set NaNs to 0
total_activity = rates.sum(1, keepdim=True)
proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates))

# Noise for random tie breaking.
eps = 1e-6 # Small enough not to distort real decisions
noise = eps * torch.randn_like(proportions)

# Neuron assignments are the labels they fire most for.
assignments = torch.max(proportions, 1)[1]
assignments = torch.argmax(proportions + noise, dim=1)

# Uniform assignment for silent neurons
silent_mask = total_activity.squeeze() == 0
n_silent = silent_mask.sum()

if n_silent > 0:
assignments[silent_mask] = torch.randint(
0, n_labels, (n_silent,), device=spikes.device
)

return assignments, proportions, rates


Expand Down