diff --git a/CollaborativeCoding/metrics/recall.py b/CollaborativeCoding/metrics/recall.py index 385f974..0443f8e 100644 --- a/CollaborativeCoding/metrics/recall.py +++ b/CollaborativeCoding/metrics/recall.py @@ -79,7 +79,7 @@ def __compute_macro_averaging(self, y_true, y_pred): recall = 0 for i in range(self.num_classes): tp = (y_true[:, i] * y_pred[:, i]).sum() - fn = torch.sum(~y_pred[y_true[:, i].bool()].bool()) + fn = (y_true[:, i] * (1 - y_pred[:, i])).sum() recall += tp / (tp + fn) recall /= self.num_classes @@ -87,7 +87,7 @@ def __compute_macro_averaging(self, y_true, y_pred): def __compute_micro_averaging(self, y_true, y_pred): true_positives = (y_true * y_pred).sum() - false_negatives = torch.sum(~y_pred[y_true.bool()].bool()) + false_negatives = (y_true * (1 - y_pred)).sum() recall = true_positives / (true_positives + false_negatives) return recall