51 changes: 35 additions & 16 deletions CollaborativeCoding/metrics/F1.py
@@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn as nn

@@ -52,13 +53,14 @@ def __init__(self, num_classes, macro_averaging=False):

self.num_classes = num_classes
self.macro_averaging = macro_averaging

self.y_true = []
self.y_pred = []
# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
self.tp = torch.zeros(num_classes)
self.fp = torch.zeros(num_classes)
self.fn = torch.zeros(num_classes)

def _micro_F1(self):
def _micro_F1(self, target, preds):
"""
Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.

@@ -69,6 +71,11 @@ def _micro_F1(self):
torch.Tensor
The micro-averaged F1 score.
"""
for i in range(self.num_classes):
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
self.fn[i] += torch.sum((preds != i) & (target == i)).float()

tp = torch.sum(self.tp)
fp = torch.sum(self.fp)
fn = torch.sum(self.fn)
@@ -81,7 +88,7 @@ def _micro_F1(self):
) # Avoid division by zero
return f1

def _macro_F1(self):
def _macro_F1(self, target, preds):
"""
Compute the Macro F1 score by calculating the F1 score per class and averaging.

@@ -93,6 +100,12 @@ def _macro_F1(self):
torch.Tensor
The macro-averaged F1 score.
"""
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
for i in range(self.num_classes):
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
self.fn[i] += torch.sum((preds != i) & (target == i)).float()

precision_per_class = self.tp / (
self.tp + self.fp + 1e-8
) # Avoid division by zero
@@ -133,18 +146,24 @@ def forward(self, preds, target):
The computed F1 score (either micro or macro, based on `macro_averaging`).
"""
preds = torch.argmax(preds, dim=-1)
self.y_true.append(target)
self.y_pred.append(preds)

def __returnmetric__(self):
if len(self.y_true) == 0 or len(self.y_pred) == 0:
return np.nan
if isinstance(self.y_true, list):
if len(self.y_true) == 1:
self.y_true = self.y_true[0]
self.y_pred = self.y_pred[0]
else:
self.y_true = torch.cat(self.y_true)
self.y_pred = torch.cat(self.y_pred)
return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred)

def __reset__(self):
self.y_true = []
self.y_pred = []
# Also clear the running TP/FP/FN counts so scores do not accumulate across resets
self.tp = torch.zeros(self.num_classes)
self.fp = torch.zeros(self.num_classes)
self.fn = torch.zeros(self.num_classes)
return None

# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
for i in range(self.num_classes):
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
self.fn[i] += torch.sum((preds != i) & (target == i)).float()

if self.macro_averaging:
# Calculate Macro F1 score
f1_score = self._macro_F1()
else:
# Calculate Micro F1 score
f1_score = self._micro_F1()

return f1_score
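
With this change, the metric is driven in three steps: forward() argmaxes the logits and stores the batch, __returnmetric__() concatenates the stored batches and computes the micro or macro score, and __reset__() clears the stored state. Below is a minimal usage sketch, assuming F1Score subclasses nn.Module (so instances are callable) and that the import path matches the file header above; the batch shapes are illustrative only.

import torch

from CollaborativeCoding.metrics.F1 import F1Score  # import path assumed from the diff header

f1 = F1Score(num_classes=3, macro_averaging=True)

# Accumulate a few hypothetical batches of logits and integer targets.
for _ in range(2):
    logits = torch.randn(8, 3)           # model outputs of shape (batch, num_classes)
    targets = torch.randint(0, 3, (8,))  # ground-truth labels
    f1(logits, targets)                  # forward() argmaxes and stores preds/targets

score = f1.__returnmetric__()  # concatenates the stored batches and computes macro F1
f1.__reset__()                 # clears the stored predictions and targets
print(score)
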
33 changes: 24 additions & 9 deletions tests/test_metrics.py
@@ -71,17 +71,32 @@ def test_recall():
def test_f1score():
import torch

f1_metric = F1Score(num_classes=3)
preds = torch.tensor(
[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
)
# Example case with known output
y_true = torch.tensor([0, 1, 2, 2, 1, 0])  # True labels
y_pred = torch.tensor([0, 1, 1, 2, 0, 0])  # Predicted labels
# forward() argmaxes over the last dimension, so express the predictions as one-hot scores
y_scores = torch.nn.functional.one_hot(y_pred, num_classes=3).float()

# Create F1Score object for micro and macro averaging
f1_micro = F1Score(num_classes=3, macro_averaging=False)
f1_macro = F1Score(num_classes=3, macro_averaging=True)

# Update the metrics with the class scores and targets (forward expects (preds, target))
f1_micro(y_scores, y_true)
f1_macro(y_scores, y_true)

# Get F1 scores
micro_f1_score = f1_micro.__returnmetric__()
macro_f1_score = f1_macro.__returnmetric__()

# Check if outputs are tensors
assert isinstance(micro_f1_score, torch.Tensor), "Micro F1 score should be a tensor."
assert isinstance(macro_f1_score, torch.Tensor), "Macro F1 score should be a tensor."

target = torch.tensor([0, 1, 0, 2])
# Check that F1 scores are between 0 and 1
assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."
assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1."

f1_metric(preds, target)
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
print(f"Micro F1 Score: {micro_f1_score.item()}")
print(f"Macro F1 Score: {macro_f1_score.item()}")


def test_precision():
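
For reference, the example used in test_f1score (y_true = [0, 1, 2, 2, 1, 0], y_pred = [0, 1, 1, 2, 0, 0]) gives per-class counts TP = [2, 1, 1], FP = [1, 1, 0], FN = [0, 1, 1]. The snippet below is a standalone cross-check of the expected scores, not part of the repository; it assumes macro averaging means the mean of the per-class F1 scores, as the docstring describes, and reuses the 1e-8 epsilon from the diff.

import torch


def f1_from_counts(tp, fp, fn, macro_averaging=False):
    # Mirror the counting scheme in F1Score: per-class precision/recall with an
    # epsilon to avoid division by zero, pooled (micro) or averaged (macro).
    if macro_averaging:
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        return (2 * precision * recall / (precision + recall + 1e-8)).mean()
    tp, fp, fn = tp.sum(), fp.sum(), fn.sum()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    return 2 * precision * recall / (precision + recall + 1e-8)


# Counts for y_true = [0, 1, 2, 2, 1, 0] vs y_pred = [0, 1, 1, 2, 0, 0]
tp = torch.tensor([2.0, 1.0, 1.0])
fp = torch.tensor([1.0, 1.0, 0.0])
fn = torch.tensor([0.0, 1.0, 1.0])
print(f1_from_counts(tp, fp, fn))                        # micro F1 ≈ 0.667
print(f1_from_counts(tp, fp, fn, macro_averaging=True))  # macro F1 ≈ 0.656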