From bda9024196fefc07be91dcb48c6515161c77ce89 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Mon, 3 Mar 2025 10:17:57 +0100 Subject: [PATCH] Fixed macro averaging in recall --- CollaborativeCoding/metrics/recall.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CollaborativeCoding/metrics/recall.py b/CollaborativeCoding/metrics/recall.py index 385f974..0443f8e 100644 --- a/CollaborativeCoding/metrics/recall.py +++ b/CollaborativeCoding/metrics/recall.py @@ -79,7 +79,7 @@ def __compute_macro_averaging(self, y_true, y_pred): recall = 0 for i in range(self.num_classes): tp = (y_true[:, i] * y_pred[:, i]).sum() - fn = torch.sum(~y_pred[y_true[:, i].bool()].bool()) + fn = (y_true[:, i] * (1 - y_pred[:, i])).sum() recall += tp / (tp + fn) recall /= self.num_classes @@ -87,7 +87,7 @@ def __compute_macro_averaging(self, y_true, y_pred): def __compute_micro_averaging(self, y_true, y_pred): true_positives = (y_true * y_pred).sum() - false_negatives = torch.sum(~y_pred[y_true.bool()].bool()) + false_negatives = (y_true * (1 - y_pred)).sum() recall = true_positives / (true_positives + false_negatives) return recall