From 5430fd9b9d7d6edb19108d0f03a6aa39c4227aa6 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 20 Feb 2025 12:40:07 +0100 Subject: [PATCH 1/5] add mnist 4-9 --- tests/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index f551f52..875c0d1 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -46,7 +46,7 @@ def test_load_data(): "mnist_0-3", "usps_7-9", "svhn", - # 'mnist_4-9' #Uncomment when implemented + "mnist_4-9", ] trans = transforms.Compose( From 090a506e5f48228f79cef72dbcacd555e77eba2a Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 20 Feb 2025 12:40:58 +0100 Subject: [PATCH 2/5] Ruff --- tests/test_dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index ad33498..16c15a6 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -36,4 +36,4 @@ def test_load_data(data_name, expected): assert isinstance(dataset[0][0], torch.Tensor) assert isinstance( dataset[0][1], (int, torch.Tensor, np.ndarray) - ) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency. \ No newline at end of file + ) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency. From b8c076f8eb5419fd90506d0ad727a7cce2cd4868 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 20 Feb 2025 13:07:50 +0100 Subject: [PATCH 3/5] Fix bug in metric wrapper --- CollaborativeCoding/load_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CollaborativeCoding/load_metric.py b/CollaborativeCoding/load_metric.py index 3420808..c47a02a 100644 --- a/CollaborativeCoding/load_metric.py +++ b/CollaborativeCoding/load_metric.py @@ -107,7 +107,7 @@ def resetmetric(self): y_pred = th.rand((5, class_size)) metricwrapper = MetricWrapper( - metric, + *metrics, num_classes=class_size, macro_averaging=True if class_size % 2 == 0 else False, ) From 9f4c389b9bb2f0a3807abb6ccab93fdb316587f0 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 20 Feb 2025 13:20:20 +0100 Subject: [PATCH 4/5] Restructure tests --- tests/test_wrappers.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 875c0d1..f18f442 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,16 +1,12 @@ from pathlib import Path -from tempfile import TemporaryDirectory import pytest -import torch -from torchvision import transforms +import torch as th from CollaborativeCoding import MetricWrapper, load_data, load_model def test_load_model(): - import torch as th - image_shape = (1, 16, 16) num_classes = 4 @@ -36,9 +32,6 @@ def test_load_model(): def test_load_data(): - from tempfile import TemporaryDirectory - - import torch as th from torchvision import transforms dataset_names = [ @@ -56,21 +49,16 @@ def test_load_data(): ] ) - with TemporaryDirectory() as tmppath: - for name in dataset_names: - dataset = load_data( - name, train=False, data_dir=Path(tmppath), transform=trans - ) + for name in dataset_names: + dataset = load_data(name, train=False, data_dir=Path.cwd(), transform=trans) - im, _ = dataset.__getitem__(0) + im, _ = dataset.__getitem__(0) - assert dataset.__len__() != 0 - assert type(im) == th.Tensor and len(im.size()) == 3 + assert dataset.__len__() != 0 + assert type(im) is th.Tensor and len(im.size()) == 3 def test_load_metric(): - import torch as th - metrics = ("entropy", "f1", "recall", "precision", "accuracy") class_sizes = [3, 6, 10] From 9683fe864461296b444df1743ff78bc11183b69a Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 20 Feb 2025 13:23:01 +0100 Subject: [PATCH 5/5] Fix bug with pathing during dataloader tests --- tests/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index f18f442..3a6ca75 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -50,7 +50,7 @@ def test_load_data(): ) for name in dataset_names: - dataset = load_data(name, train=False, data_dir=Path.cwd(), transform=trans) + dataset = load_data(name, train=False, data_dir=Path.cwd() / "Data", transform=trans) im, _ = dataset.__getitem__(0)