diff --git a/CollaborativeCoding/metrics/F1.py b/CollaborativeCoding/metrics/F1.py
index 0833389..33c0a4d 100644
--- a/CollaborativeCoding/metrics/F1.py
+++ b/CollaborativeCoding/metrics/F1.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -52,13 +53,14 @@ def __init__(self, num_classes, macro_averaging=False):
         self.num_classes = num_classes
         self.macro_averaging = macro_averaging
-
+        self.y_true = []
+        self.y_pred = []
         # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
         self.tp = torch.zeros(num_classes)
         self.fp = torch.zeros(num_classes)
         self.fn = torch.zeros(num_classes)
 
-    def _micro_F1(self):
+    def _micro_F1(self, target, preds):
         """
         Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
 
@@ -69,6 +71,11 @@ def _micro_F1(self):
         torch.Tensor
             The micro-averaged F1 score.
         """
+        for i in range(self.num_classes):
+            self.tp[i] += torch.sum((preds == i) & (target == i)).float()
+            self.fp[i] += torch.sum((preds == i) & (target != i)).float()
+            self.fn[i] += torch.sum((preds != i) & (target == i)).float()
+
         tp = torch.sum(self.tp)
         fp = torch.sum(self.fp)
         fn = torch.sum(self.fn)
@@ -81,7 +88,7 @@ def _micro_F1(self):
         )  # Avoid division by zero
         return f1
 
-    def _macro_F1(self):
+    def _macro_F1(self, target, preds):
         """
         Compute the Macro F1 score by calculating the F1 score per class and averaging.
 
@@ -93,6 +100,12 @@ def _macro_F1(self):
         torch.Tensor
             The macro-averaged F1 score.
         """
+        # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
+        for i in range(self.num_classes):
+            self.tp[i] += torch.sum((preds == i) & (target == i)).float()
+            self.fp[i] += torch.sum((preds == i) & (target != i)).float()
+            self.fn[i] += torch.sum((preds != i) & (target == i)).float()
+
         precision_per_class = self.tp / (
             self.tp + self.fp + 1e-8
         )  # Avoid division by zero
@@ -133,18 +146,24 @@ def forward(self, preds, target):
             The computed F1 score (either micro or macro, based on `macro_averaging`).
""" preds = torch.argmax(preds, dim=-1) + self.y_true.append(target) + self.y_pred.append(preds) + + def __returnmetric__(self): + if self.y_true == [] or self.y_pred == []: + return np.nan + if isinstance(self.y_true, list): + if len(self.y_true) == 1: + self.y_true = self.y_true[0] + self.y_pred = self.y_pred[0] + else: + self.y_true = torch.cat(self.y_true) + self.y_pred = torch.cat(self.y_pred) + return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred) + + def __reset__(self): + self.y_true = [] + self.y_pred = [] + return None - # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class - for i in range(self.num_classes): - self.tp[i] += torch.sum((preds == i) & (target == i)).float() - self.fp[i] += torch.sum((preds == i) & (target != i)).float() - self.fn[i] += torch.sum((preds != i) & (target == i)).float() - if self.macro_averaging: - # Calculate Macro F1 score - f1_score = self._macro_F1() - else: - # Calculate Micro F1 score - f1_score = self._micro_F1() - - return f1_score diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 67db356..25b9974 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -71,17 +71,32 @@ def test_recall(): def test_f1score(): import torch - f1_metric = F1Score(num_classes=3) - preds = torch.tensor( - [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]] - ) + # Example case with known output + y_true = torch.tensor([0, 1, 2, 2, 1, 0]) # True labels + y_pred = torch.tensor([0, 1, 1, 2, 0, 0]) # Predicted labels + + # Create F1Score object for micro and macro averaging + f1_micro = F1Score(num_classes=3, macro_averaging=False) + f1_macro = F1Score(num_classes=3, macro_averaging=True) + + # Update F1 score with predictions + f1_micro(y_true, y_pred) + f1_macro(y_true, y_pred) + + # Get F1 scores + micro_f1_score = f1_micro.__returnmetric__() + macro_f1_score = f1_macro.__returnmetric__() + + # Check if outputs are tensors + assert isinstance(micro_f1_score, torch.Tensor), "Micro F1 score should be a tensor." + assert isinstance(macro_f1_score, torch.Tensor), "Macro F1 score should be a tensor." - target = torch.tensor([0, 1, 0, 2]) + # Check that F1 scores are between 0 and 1 + assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1." + assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1." - f1_metric(preds, target) - assert f1_metric.tp.sum().item() > 0, "Expected some true positives." - assert f1_metric.fp.sum().item() > 0, "Expected some false positives." - assert f1_metric.fn.sum().item() > 0, "Expected some false negatives." + print(f"Micro F1 Score: {micro_f1_score.item()}") + print(f"Macro F1 Score: {macro_f1_score.item()}") def test_precision():