diff --git a/CollaborativeCoding/load_metric.py b/CollaborativeCoding/load_metric.py index 3420808..c47a02a 100644 --- a/CollaborativeCoding/load_metric.py +++ b/CollaborativeCoding/load_metric.py @@ -107,7 +107,7 @@ def resetmetric(self): y_pred = th.rand((5, class_size)) metricwrapper = MetricWrapper( - metric, + *metrics, num_classes=class_size, macro_averaging=True if class_size % 2 == 0 else False, ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index bec965e..3a6ca75 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,16 +1,12 @@ from pathlib import Path -from tempfile import TemporaryDirectory import pytest -import torch -from torchvision import transforms +import torch as th from CollaborativeCoding import MetricWrapper, load_data, load_model def test_load_model(): - import torch as th - image_shape = (1, 16, 16) num_classes = 4 @@ -36,9 +32,6 @@ def test_load_model(): def test_load_data(): - from tempfile import TemporaryDirectory - - import torch as th from torchvision import transforms dataset_names = [ @@ -46,7 +39,7 @@ def test_load_data(): "mnist_0-3", "usps_7-9", "svhn", - "mnist_4-9", # Uncomment when implemented + "mnist_4-9", ] trans = transforms.Compose( @@ -56,21 +49,16 @@ def test_load_data(): ] ) - with TemporaryDirectory() as tmppath: - for name in dataset_names: - dataset, _, _ = load_data( - name, train=False, data_dir=Path(tmppath), transform=trans - ) + for name in dataset_names: + dataset = load_data(name, train=False, data_dir=Path.cwd() / "Data", transform=trans) - im, _ = dataset.__getitem__(0) + im, _ = dataset.__getitem__(0) - assert dataset.__len__() != 0 - assert type(im) == th.Tensor and len(im.size()) == 3 + assert dataset.__len__() != 0 + assert type(im) is th.Tensor and len(im.size()) == 3 def test_load_metric(): - import torch as th - metrics = ("entropy", "f1", "recall", "precision", "accuracy") class_sizes = [3, 6, 10]