diff --git a/CollaborativeCoding/dataloaders/mnist_4_9.py b/CollaborativeCoding/dataloaders/mnist_4_9.py
index 55d2453..bccdd70 100644
--- a/CollaborativeCoding/dataloaders/mnist_4_9.py
+++ b/CollaborativeCoding/dataloaders/mnist_4_9.py
@@ -43,12 +43,11 @@ def __init__(
         self.labels_path = self.mnist_path / (
             MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
         )
-
-        # Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
-        self.label_shift = lambda x: x-4
-        self.label_restore = lambda x: x+4
-
-
+
+        # Map labels from 4..9 to 0..5 so CrossEntropyLoss receives zero-indexed class targets.
+        self.label_shift = lambda x: x - 4
+        self.label_restore = lambda x: x + 4
+
     def __len__(self):
         return len(self.samples)
 
diff --git a/main.py b/main.py
index 7225388..c34d317 100644
--- a/main.py
+++ b/main.py
@@ -139,7 +139,7 @@ def main():
     for epoch in range(args.epoch):
         # Training loop start
-        print(f"Epoch: {epoch+1}/{args.epoch}")
+        print(f"Epoch: {epoch + 1}/{args.epoch}")
         trainingloss = []
         model.train()
 
         for x, y in tqdm(trainloader, desc="Training"):
diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py
index 4cd192c..b911524 100644
--- a/tests/test_dataloaders.py
+++ b/tests/test_dataloaders.py
@@ -1,6 +1,5 @@
 from pathlib import Path
 
-import numpy as np
 import pytest
 import torch
 from torchvision import transforms
@@ -26,14 +25,19 @@
     ],
 )
 def test_load_data(data_name, expected):
-    print(data_name)
     dataset, _, _ = load_data(
         data_name,
-        data_dir=Path("data"),
+        train=False,
+        data_dir=Path("Data"),
         transform=transforms.ToTensor(),
     )
-    assert isinstance(dataset, expected)
-    assert len(dataset) > 0
-    assert isinstance(dataset[0], tuple)
-    assert isinstance(dataset[0][0], torch.Tensor)
-    assert isinstance(dataset[0][1], int)
+
+    sample = dataset[0]
+    img, label = sample
+
+    assert isinstance(dataset, expected), f"{type(dataset)} != {expected}"
+    assert len(dataset) > 0, "Dataset is empty"
+    assert isinstance(sample, tuple), f"{type(sample)} != tuple"
+    assert isinstance(img, torch.Tensor), f"{type(img)} != torch.Tensor"
+    assert isinstance(label, int), f"{type(label)} != int"
+    assert len(img.size()) == 3, f"Expected 3 image dims (C, H, W), got {len(img.size())}"
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 6b225c3..2390fe5 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,214 +1,57 @@
-from random import randint
+import itertools
 
 import pytest
 
 from CollaborativeCoding.load_metric import MetricWrapper
 
-from CollaborativeCoding.metrics import (
-    Accuracy,
-    EntropyPrediction,
-    F1Score,
-    Precision,
-    Recall,
-)
-
-
-@pytest.mark.parametrize(
-    "metric, num_classes, macro_averaging",
-    [
-        ("f1", randint(2, 10), False),
-        ("f1", randint(2, 10), True),
-        ("recall", randint(2, 10), False),
-        ("recall", randint(2, 10), True),
-        ("accuracy", randint(2, 10), False),
-        ("accuracy", randint(2, 10), True),
-        ("precision", randint(2, 10), False),
-        ("precision", randint(2, 10), True),
-        ("entropy", randint(2, 10), False),
-    ],
-)
-def test_metric_wrapper(metric, num_classes, macro_averaging):
-    import numpy as np
-    import torch
-
-    y_true = torch.arange(num_classes, dtype=torch.int64)
-    logits = torch.rand(num_classes, num_classes)
-
-    metrics = MetricWrapper(
-        metric,
-        num_classes=num_classes,
-        macro_averaging=macro_averaging,
-    )
-
-    metrics(y_true, logits)
-    score = metrics.getmetrics()
-    metrics.resetmetric()
-    empty_score = metrics.getmetrics()
-
-    assert isinstance(score, dict), "Expected a dictionary output."
-    assert metric in score, f"Expected {metric} metric in the output."
-    assert score[metric] >= 0, "Expected a non-negative value."
-    assert np.isnan(empty_score[metric]), "Expected an empty list."
-
-
-def test_recall():
-    import torch
-
-    y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
-    logits = torch.randn(7, 7)
-
-    recall_micro = Recall(7)
-    recall_macro = Recall(7, macro_averaging=True)
-
-    recall_micro(y_true, logits)
-    recall_macro(y_true, logits)
-
-    recall_micro_score = recall_micro.__returnmetric__()
-    recall_macro_score = recall_macro.__returnmetric__()
-
-    assert isinstance(recall_micro_score, torch.Tensor), "Expected a tensor output."
-    assert isinstance(recall_macro_score, torch.Tensor), "Expected a tensor output."
-    assert recall_micro_score.item() >= 0, "Expected a non-negative value."
-    assert recall_macro_score.item() >= 0, "Expected a non-negative value."
+METRICS = ["f1", "recall", "accuracy", "precision", "entropy"]
 
 
-def test_f1score():
-    import torch
-
-    # Example case with known output
-    y_true = torch.tensor([0, 1, 2, 2, 1, 0])  # True labels
-    y_pred = torch.tensor([0, 1, 1, 2, 0, 0])  # Predicted labels
-
-    # Create F1Score object for micro and macro averaging
-    f1_micro = F1Score(num_classes=3, macro_averaging=False)
-    f1_macro = F1Score(num_classes=3, macro_averaging=True)
-
-    # Update F1 score with predictions
-    f1_micro(y_true, y_pred)
-    f1_macro(y_true, y_pred)
+def _metric_combinations():
+    """
+    Yield metric-name combinations to test:
+    1. Each single metric as a one-element list
+    2. Every pair of metrics
+    3. All metrics at once
+    """
 
-    # Get F1 scores
-    micro_f1_score = f1_micro.__returnmetric__()
-    macro_f1_score = f1_macro.__returnmetric__()
-
-    # Check if outputs are tensors
-    assert isinstance(micro_f1_score, torch.Tensor), (
-        "Micro F1 score should be a tensor."
-    )
-    assert isinstance(macro_f1_score, torch.Tensor), (
-        "Macro F1 score should be a tensor."
-    )
+    # Single metrics as lists
+    for m in METRICS:
+        yield [m]
 
-    # Check that F1 scores are between 0 and 1
-    assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."
-    assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1."
+    # Pairs of metrics (2-combinations)
+    for combo in itertools.combinations(METRICS, 2):
+        yield list(combo)
 
-    print(f"Micro F1 Score: {micro_f1_score.item()}")
-    print(f"Macro F1 Score: {macro_f1_score.item()}")
+    # Also test all metrics at once (copied so a test cannot mutate the module-level list)
+    yield list(METRICS)
 
 
-def test_precision():
+@pytest.mark.parametrize("metrics", _metric_combinations())
+@pytest.mark.parametrize("num_classes", [2, 3, 5, 10])
+@pytest.mark.parametrize("macro_averaging", [True, False])
+def test_metric_wrapper(metrics, num_classes, macro_averaging):
     import numpy as np
     import torch
-    from sklearn.metrics import precision_score
 
-    C = randint(2, 10)  # number of classes
-    N = randint(2, 10 * C)  # batchsize
-    y_true = torch.randint(0, C, (N,))
-    logits = torch.randn(N, C)
-
-    # create metric objects
-    precision_micro = Precision(num_classes=C)
-    precision_macro = Precision(num_classes=C, macro_averaging=True)
-
-    # run metric object
-    precision_micro(y_true, logits)
-    precision_macro(y_true, logits)
-
-    # get metric scores
-    micro_precision_score = precision_micro.__returnmetric__()
-    macro_precision_score = precision_macro.__returnmetric__()
-
-    # check output to be tensor
-    assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
-    assert isinstance(macro_precision_score, torch.Tensor), "Tensor output is expected."
-
-    # check for non-negativity
-    assert micro_precision_score.item() >= 0, "Expected non-negative value"
-    assert macro_precision_score.item() >= 0, "Expected non-negative value"
-
-    # find predictions
-    y_pred = logits.argmax(dim=-1)
-
-    # check dimension
-    assert y_true.shape == torch.Size([N])
-    assert logits.shape == torch.Size([N, C])
-    assert y_pred.shape == torch.Size([N])
-
-    # find true values with scikit learn
-    scikit_macro_precision = precision_score(y_true, y_pred, average="macro")
-    scikit_micro_precision = precision_score(y_true, y_pred, average="micro")
-
-    # check for similarity
-    assert np.isclose(scikit_micro_precision, micro_precision_score, atol=1e-5), (
-        "Score does not match scikit's score"
-    )
-    assert np.isclose(scikit_macro_precision, macro_precision_score, atol=1e-5), (
-        "Score does not match scikit's score"
-    )
-
-
-def test_accuracy():
-    import numpy as np
-    import torch
 
+    y_true = torch.arange(num_classes, dtype=torch.int64)
+    logits = torch.rand(num_classes, num_classes)
 
-    # Test the accuracy metric
-    y_true = torch.tensor([0, 1, 2, 3, 4, 5])
-    y_pred = torch.tensor([0, 1, 2, 3, 4, 5])
-    accuracy = Accuracy(num_classes=6, macro_averaging=False)
-    accuracy(y_true, y_pred)
-    assert accuracy.__returnmetric__() == 1.0, "Expected accuracy to be 1.0"
-    accuracy.__reset__()
-    assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0"
-    y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
-    accuracy(y_true, y_pred)
-    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
-        "Expected accuracy to be 0.8333333134651184"
-    )
-    accuracy.__reset__()
-    accuracy.macro_averaging = True
-    accuracy(y_true, y_pred)
-    y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
-    y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
-    accuracy(y_true_1, y_pred_1)
-    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
-        "Expected accuracy to be 0.8333333134651186"
-    )
-    accuracy.macro_averaging = False
-    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
-        "Expected accuracy to be 0.8333333134651184"
+    mw = MetricWrapper(
+        *metrics,
+        num_classes=num_classes,
+        macro_averaging=macro_averaging,
     )
-    accuracy.__reset__()
 
+    mw(y_true, logits)
+    score = mw.getmetrics()
+    mw.resetmetric()
+    empty_score = mw.getmetrics()
-
-
-def test_entropypred():
-    import torch as th
-
-    true_lab = th.rand(6, 5)
-
-    metric = EntropyPrediction(num_classes=5)
-
-    # Test if the metric stores multiple values
-    pred_logits = th.rand(6, 5)
-    metric(true_lab, pred_logits)
-
-    pred_logits = th.rand(6, 5)
-    metric(true_lab, pred_logits)
-
-    pred_logits = th.rand(6, 5)
-    metric(true_lab, pred_logits)
-
-    assert type(metric.__returnmetric__()) == th.Tensor
 
+    assert isinstance(score, dict), "Expected a dictionary output."
+    for m in metrics:
+        assert m in score, f"Expected metric '{m}' in the output."
+        assert score[m] >= 0, "Expected a non-negative value."
 
-    # Test than an error is raised with num_class != class dimension length
-    with pytest.raises(AssertionError):
-        metric(true_lab, th.rand(6, 6))
+        assert m in empty_score, f"Expected metric '{m}' after reset."
+        assert np.isnan(empty_score[m]), "Expected NaN after reset."
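
Review note: the rewritten test_metric_wrapper only checks that each score is
non-negative, whereas the deleted test_precision also cross-checked Precision
against scikit-learn. If that coverage is worth keeping, it could live in a small
standalone test. A minimal sketch, assuming MetricWrapper exposes getmetrics() as
above and that "precision" scores argmax-ed logits with micro averaging (the test
name and constants here are hypothetical):

import numpy as np
import torch
from sklearn.metrics import precision_score

from CollaborativeCoding.load_metric import MetricWrapper


def test_precision_matches_sklearn():
    # Hypothetical sanity check; mirrors the deleted sklearn comparison.
    num_classes = 4
    y_true = torch.randint(0, num_classes, (32,))
    logits = torch.randn(32, num_classes)

    mw = MetricWrapper("precision", num_classes=num_classes, macro_averaging=False)
    mw(y_true, logits)
    score = mw.getmetrics()["precision"]

    # Reduce logits to hard predictions, the same way the metric is assumed to.
    expected = precision_score(y_true, logits.argmax(dim=-1), average="micro")
    assert np.isclose(score, expected, atol=1e-5)
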
diff --git a/tests/test_models.py b/tests/test_models.py
index 1b70987..94a0744 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,71 +1,28 @@
 import pytest
 import torch
 
-from CollaborativeCoding.models import (
-    ChristianModel,
-    JanModel,
-    JohanModel,
-    MagnusModel,
-    SolveigModel,
-)
-
-
-@pytest.mark.parametrize(
-    "image_shape, num_classes",
-    [((1, 16, 16), 6), ((3, 16, 16), 6)],
-)
-def test_christian_model(image_shape, num_classes):
-    n, c, h, w = 5, *image_shape
-
-    model = ChristianModel(image_shape, num_classes)
-
-    x = torch.randn(n, c, h, w)
-    y = model(x)
-
-    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-
-
-@pytest.mark.parametrize(
-    "image_shape, num_classes",
-    [((1, 28, 28), 4), ((3, 16, 16), 10)],
-)
-def test_jan_model(image_shape, num_classes):
-    n, c, h, w = 5, *image_shape
-
-    model = JanModel(image_shape, num_classes)
-
-    x = torch.randn(n, c, h, w)
-    y = model(x)
-
-    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-
-
-@pytest.mark.parametrize(
-    "image_shape, num_classes",
-    [((3, 16, 16), 3), ((3, 16, 16), 7)],
-)
-def test_solveig_model(image_shape, num_classes):
-    n, c, h, w = 5, *image_shape
-
-    model = SolveigModel(image_shape, num_classes)
-
-    x = torch.randn(n, c, h, w)
-    y = model(x)
-
-    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
+from CollaborativeCoding import load_model
 
 
 @pytest.mark.parametrize(
-    "image_shape, num_classes", [((3, 28, 28), 10), ((1, 16, 16), 10)]
+    "model_name",
+    [
+        "magnusmodel",
+        "christianmodel",
+        "janmodel",
+        "johanmodel",
+        "solveigmodel",
+    ],
 )
-def test_magnus_model(image_shape, num_classes):
-    import torch as th
+@pytest.mark.parametrize("image_shape", [(i, 28, 28) for i in [1, 3]])
+@pytest.mark.parametrize("num_classes", [3, 6, 10])
+def test_load_model(model_name, image_shape, num_classes):
+    model = load_model(model_name, image_shape, num_classes)
 
     n, c, h, w = 5, *image_shape
 
-    model = MagnusModel([h, w], num_classes, c)
-    x = th.rand((n, c, h, w))
-    with th.no_grad():
-        y = model(x)
+    dummy_img = torch.randn(n, c, h, w)
+    with torch.no_grad():
+        y = model(dummy_img)
 
-    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
+    assert y.shape == (n, num_classes), f"Shape: {y.shape} != {(n, num_classes)}"
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
deleted file mode 100644
index 3a6ca75..0000000
--- a/tests/test_wrappers.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from pathlib import Path
-
-import pytest
-import torch as th
-
-from CollaborativeCoding import MetricWrapper, load_data, load_model
-
-
-def test_load_model():
-    image_shape = (1, 16, 16)
-    num_classes = 4
-
-    dummy_img = th.rand((1, *image_shape))
-
-    modelnames = [
-        "magnusmodel",
-        "christianmodel",
-        "janmodel",
-        "solveigmodel",
-        "johanmodel",
-    ]
-
-    for name in modelnames:
-        print(name)
-        model = load_model(name, image_shape=image_shape, num_classes=num_classes)
-
-        with th.no_grad():
-            output = model(dummy_img)
-        assert output.size() == (1, 4), (
-            f"Model {name} returned image of size {output}. Expected (1,4)"
-        )
-
-
-def test_load_data():
-    from torchvision import transforms
-
-    dataset_names = [
-        "usps_0-6",
-        "mnist_0-3",
-        "usps_7-9",
-        "svhn",
-        "mnist_4-9",
-    ]
-
-    trans = transforms.Compose(
-        [
-            transforms.Resize((16, 16)),
-            transforms.ToTensor(),
-        ]
-    )
-
-    for name in dataset_names:
-        dataset = load_data(name, train=False, data_dir=Path.cwd() / "Data", transform=trans)
-
-        im, _ = dataset.__getitem__(0)
-
-        assert dataset.__len__() != 0
-        assert type(im) is th.Tensor and len(im.size()) == 3
-
-
-def test_load_metric():
-    metrics = ("entropy", "f1", "recall", "precision", "accuracy")
-
-    class_sizes = [3, 6, 10]
-    for class_size in class_sizes:
-        y_true = th.rand((5, class_size)).argmax(dim=1)
-        y_pred = th.rand((5, class_size))
-
-        metricwrapper = MetricWrapper(
-            *metrics,
-            num_classes=class_size,
-            macro_averaging=True if class_size % 2 == 0 else False,
-        )
-
-        metricwrapper(y_true, y_pred)
-        metric = metricwrapper.getmetrics()
-        assert metric is not None
-
-        metricwrapper.resetmetric()
-        metric2 = metricwrapper.getmetrics()
-        assert metric != metric2
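
Review note: deleting tests/test_wrappers.py removes the only test that chained a
model's output into MetricWrapper. The new per-module tests cover models and
metrics separately; if the end-to-end path should stay covered, one integration
test could replace the old file. A minimal sketch, assuming the package-level
imports the deleted file used (the model name, batch size, and metric choice are
illustrative):

import torch

from CollaborativeCoding import MetricWrapper, load_model


def test_model_output_feeds_metrics():
    # Hypothetical integration check: model logits flow into the metric wrapper.
    image_shape, num_classes = (1, 28, 28), 10
    model = load_model("janmodel", image_shape, num_classes)

    x = torch.randn(8, *image_shape)
    with torch.no_grad():
        logits = model(x)
    y_true = torch.randint(0, num_classes, (8,))

    mw = MetricWrapper("accuracy", "f1", num_classes=num_classes, macro_averaging=False)
    mw(y_true, logits)

    scores = mw.getmetrics()
    for m in ("accuracy", "f1"):
        assert m in scores, f"Expected metric '{m}' in the output."
        assert scores[m] >= 0, "Expected a non-negative value."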