diff --git a/.gitignore b/.gitignore
index 3afaf91..376d6ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,9 @@ env2/*
 ruffian.sh
 localtest.sh
 
+# Johanthings
+formatting.x
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/CollaborativeCoding/dataloaders/mnist_4_9.py b/CollaborativeCoding/dataloaders/mnist_4_9.py
index 5de281b..9ec8f50 100644
--- a/CollaborativeCoding/dataloaders/mnist_4_9.py
+++ b/CollaborativeCoding/dataloaders/mnist_4_9.py
@@ -20,7 +20,14 @@ class MNISTDataset4_9(Dataset):
         Whether to train the model or not, by default False
     """
 
-    def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False):
+    def __init__(
+        self,
+        data_path: Path,
+        sample_ids: np.ndarray,
+        train: bool = False,
+        transform=None,
+        nr_channels: int = 1,
+    ):
         super.__init__()
         self.data_path = data_path
         self.mnist_path = self.data_path / "MNIST"
@@ -52,4 +59,7 @@ def __getitem__(self, idx):
 
         image = np.expand_dims(image, axis=0)  # Channel
 
+        if self.transform:
+            image = self.transform(image)
+
         return image, label
diff --git a/CollaborativeCoding/load_data.py b/CollaborativeCoding/load_data.py
index 200368f..b4a247b 100644
--- a/CollaborativeCoding/load_data.py
+++ b/CollaborativeCoding/load_data.py
@@ -4,6 +4,7 @@
 from .dataloaders import (
     Downloader,
     MNISTDataset0_3,
+    MNISTDataset4_9,
     SVHNDataset,
     USPSDataset0_6,
     USPSH5_Digit_7_9_Dataset,
@@ -65,7 +66,9 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
             train_labels, test_labels = downloader.svhn(data_dir=data_dir)
             labels = np.arange(10)
         case "mnist_4-9":
-            raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
+            dataset = MNISTDataset4_9
+            train_labels, test_labels = downloader.mnist(data_dir=data_dir)
+            labels = np.arange(4, 10)
         case _:
             raise NotImplementedError(f"Dataset: {dataset} not implemented.")
diff --git a/CollaborativeCoding/load_metric.py b/CollaborativeCoding/load_metric.py
index 839d9c6..49a60f6 100644
--- a/CollaborativeCoding/load_metric.py
+++ b/CollaborativeCoding/load_metric.py
@@ -25,9 +25,9 @@ class MetricWrapper(nn.Module):
     -------
     __call__(y_true, y_pred)
         Computes the specified metrics on the provided true and predicted labels.
-    __getmetrics__(str_prefix: str = None)
+    getmetrics(str_prefix: str = None)
         Retrieves the computed metrics, optionally prefixed with a string.
-    reset()
+    resetmetric()
         Resets the state of all metric computations.
     Examples
     --------
@@ -36,10 +36,10 @@ class MetricWrapper(nn.Module):
     >>> y_true = [0, 1, 0, 1]
     >>> y_pred = [0, 1, 1, 0]
     >>> metrics(y_true, y_pred)
-    >>> metrics.__getmetrics__()
+    >>> metrics.getmetrics()
     {'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
-    >>> metrics.reset()
-    >>> metrics.__getmetrics__()
+    >>> metrics.resetmetric()
+    >>> metrics.getmetrics()
     {'entropy': [], 'f1': [], 'precision': []}
     """
diff --git a/CollaborativeCoding/metrics/precision.py b/CollaborativeCoding/metrics/precision.py
index a596df7..2b70d7d 100644
--- a/CollaborativeCoding/metrics/precision.py
+++ b/CollaborativeCoding/metrics/precision.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -18,6 +19,8 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
 
         self.num_classes = num_classes
         self.macro_averaging = macro_averaging
+        self.y_true = []
+        self.y_pred = []
 
     def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
         """Compute precision of model
@@ -35,11 +38,10 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
             Precision score
         """
         y_pred = logits.argmax(dim=-1)
-        return (
-            self._macro_avg_precision(y_true, y_pred)
-            if self.macro_averaging
-            else self._micro_avg_precision(y_true, y_pred)
-        )
+
+        # Append to the class-global values
+        self.y_true.append(y_true)
+        self.y_pred.append(y_pred)
 
     def _micro_avg_precision(
         self, y_true: torch.tensor, y_pred: torch.tensor
@@ -58,7 +60,6 @@ def _micro_avg_precision(
         torch.tensor
             Micro-averaged precision
         """
-        print(y_true.shape)
         true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
             1, y_true.unsqueeze(1), 1
         )
@@ -98,6 +99,31 @@ def _macro_avg_precision(
         return torch.nanmean(tp / (tp + fp))
 
+    def __returnmetric__(self):
+        if self.y_true == [] and self.y_pred == []:
+            return np.nan
+        elif self.y_true == [] or self.y_pred == []:
+            raise ValueError("y_true or y_pred is empty.")
+        self.y_true = torch.cat(self.y_true)
+        self.y_pred = torch.cat(self.y_pred)
+
+        return (
+            self._macro_avg_precision(self.y_true, self.y_pred)
+            if self.macro_averaging
+            else self._micro_avg_precision(self.y_true, self.y_pred)
+        )
+
+    def __reset__(self):
+        """Resets the accumulated lists of true and predicted values to empty lists.
+
+        Returns
+        -------
+        None
+            Returns None
+        """
+        self.y_true, self.y_pred = [], []
+        return None
+
 
 if __name__ == "__main__":
     print(
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 3107d73..8f7b4e5 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -85,8 +85,6 @@ def test_f1score():
 
 
 def test_precision():
-    from random import randint
-
     import numpy as np
     import torch
     from sklearn.metrics import precision_score
@@ -100,9 +98,13 @@ def test_precision():
     precision_micro = Precision(num_classes=C)
     precision_macro = Precision(num_classes=C, macro_averaging=True)
 
-    # find scores
-    micro_precision_score = precision_micro(y_true, logits)
-    macro_precision_score = precision_macro(y_true, logits)
+    # run metric object
+    precision_micro(y_true, logits)
+    precision_macro(y_true, logits)
+
+    # get metric scores
+    micro_precision_score = precision_micro.__returnmetric__()
+    macro_precision_score = precision_macro.__returnmetric__()
 
     # check output to be tensor
     assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
@@ -113,12 +115,12 @@ def test_precision():
     assert macro_precision_score.item() >= 0, "Expected non-negative value"
 
     # find predictions
-    y_pred = logits.argmax(dim=-1, keepdims=True)
+    y_pred = logits.argmax(dim=-1)
 
     # check dimension
-    assert y_true.shape == torch.Size([N, 1]) or torch.Size([N])
+    assert y_true.shape == torch.Size([N])
     assert logits.shape == torch.Size([N, C])
-    assert y_pred.shape == torch.Size([N, 1]) or torch.Size([N])
+    assert y_pred.shape == torch.Size([N])
 
     # find true values with scikit learn
     scikit_macro_precision = precision_score(y_true, y_pred, average="macro")
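The updated test cross-checks the accumulated scores against scikit-learn. As a reference for what the two averaging modes in Precision compute, here is a self-contained sketch (toy labels, not taken from the repository's tests) contrasting micro- and macro-averaged precision with sklearn.metrics.precision_score, the same function the test uses:

import torch
from sklearn.metrics import precision_score

# Toy batch; illustrative values only.
y_true = torch.tensor([0, 1, 2, 2, 1, 0])
y_pred = torch.tensor([0, 2, 2, 2, 1, 1])

# Micro averaging pools all classes: total TP / (total TP + total FP).
micro = precision_score(y_true, y_pred, average="micro")

# Macro averaging computes precision per class, then averages the per-class scores.
macro = precision_score(y_true, y_pred, average="macro")

print(f"micro={micro:.3f}  macro={macro:.3f}")  # micro=0.667  macro=0.722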
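MNISTDataset4_9 now accepts an optional transform callable, applied in __getitem__ after the channel dimension is added. Below is a minimal sketch of that hook in isolation, using a dummy array and a hypothetical normalize function in place of a real MNIST sample and transform:

import numpy as np

def normalize(img: np.ndarray) -> np.ndarray:
    # Hypothetical transform: scale the image to zero mean, unit variance.
    return (img - img.mean()) / (img.std() + 1e-8)

# Dummy 28x28 grayscale image standing in for an MNIST sample.
image = np.random.rand(28, 28).astype(np.float32)
image = np.expand_dims(image, axis=0)  # add the channel dimension, as in __getitem__

transform = normalize
if transform:  # same guard as the new `if self.transform:` branch
    image = transform(image)

print(image.shape)  # (1, 28, 28)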