diff --git a/CollaborativeCoding/load_metric.py b/CollaborativeCoding/load_metric.py
index c47a02a..002260c 100644
--- a/CollaborativeCoding/load_metric.py
+++ b/CollaborativeCoding/load_metric.py
@@ -79,6 +79,7 @@ def _get_metric(self, key):
         raise ValueError(f"Metric {key} not supported")
 
     def __call__(self, y_true, y_pred):
+        y_true, y_pred = y_true.detach().cpu(), y_pred.detach().cpu()
         for key in self.metrics:
             self.metrics[key](y_true, y_pred)
 
diff --git a/CollaborativeCoding/metrics/EntropyPred.py b/CollaborativeCoding/metrics/EntropyPred.py
index b77e8d7..9338f44 100644
--- a/CollaborativeCoding/metrics/EntropyPred.py
+++ b/CollaborativeCoding/metrics/EntropyPred.py
@@ -36,16 +36,14 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
         assert y_logits.size(-1) == self.num_classes, (
             f"y_logit class length: {y_logits.size(-1)}, expected: {self.num_classes}"
         )
         y_pred = nn.Softmax(dim=1)(y_logits)
-        print(f"y_pred: {y_pred}")
         entropy_values = entropy(y_pred, axis=1)
         entropy_values = th.from_numpy(entropy_values)
 
         # Fix numerical errors for perfect guesses
         entropy_values[entropy_values == th.inf] = 0
         entropy_values = th.nan_to_num(entropy_values)
-        print(f"Entropy Values: {entropy_values}")
 
         for sample in entropy_values:
             self.stored_entropy_values.append(sample.item())
 
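For context, here is a minimal standalone sketch (not the repository code) of what the touched `EntropyPred.__call__` path computes: softmax the logits, take the per-sample Shannon entropy with SciPy, and clamp the `inf`/`nan` values that perfect one-hot predictions can produce. The wrapper-side `detach().cpu()` added in the first hunk is mirrored in the usage line at the bottom; the function name `prediction_entropy` and the example logits are illustrative only.

```python
import torch as th
from scipy.stats import entropy
from torch import nn


def prediction_entropy(y_logits: th.Tensor) -> th.Tensor:
    """Per-sample entropy of softmax predictions, with inf/nan sanitisation."""
    y_pred = nn.Softmax(dim=1)(y_logits)      # logits -> class probabilities
    entropy_values = entropy(y_pred, axis=1)  # SciPy returns a NumPy array
    entropy_values = th.from_numpy(entropy_values)

    # Perfect (one-hot) predictions can yield inf/nan from log(0); clamp to 0.
    entropy_values[entropy_values == th.inf] = 0
    return th.nan_to_num(entropy_values)


# Usage: detach and move to CPU first, as MetricWrapper.__call__ now does.
logits = th.tensor([[10.0, -10.0, -10.0], [0.0, 0.0, 0.0]]).detach().cpu()
print(prediction_entropy(logits))  # ~[0.0000, 1.0986]: confident vs. uniform
```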