diff --git a/main.py b/main.py index f9a7c98..c6230eb 100644 --- a/main.py +++ b/main.py @@ -30,9 +30,7 @@ def main(): device = args.device - if "usps" in args.dataset.lower(): - transform = transforms.Compose( [ transforms.Resize((28, 28)), @@ -47,7 +45,6 @@ def main(): data_dir=args.datafolder, transform=transform, val_size=args.val_size, - ) train_metrics = MetricWrapper( @@ -129,7 +126,6 @@ def main(): project=args.run_name, tags=[args.modelname, args.dataset], config=args, - ) wandb.watch(model) diff --git a/tests/test_models.py b/tests/test_models.py index 5266f0e..e94c805 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,7 @@ import pytest import torch -from utils.models import ChristianModel, JanModel, MagnusModel +from utils.models import ChristianModel, JanModel, MagnusModel, SolveigModel @pytest.mark.parametrize( @@ -34,6 +34,21 @@ def test_jan_model(image_shape, num_classes): assert y.shape == (n, num_classes), f"Shape: {y.shape}" +@pytest.mark.parametrize( + "image_shape, num_classes", + [((3, 16, 16), 3), ((3, 16, 16), 7)], +) +def test_solveig_model(image_shape, num_classes): + n, c, h, w = 5, *image_shape + + model = SolveigModel(image_shape, num_classes) + + x = torch.randn(n, c, h, w) + y = model(x) + + assert y.shape == (n, num_classes), f"Shape: {y.shape}" + + @pytest.mark.parametrize("image_shape", [(3, 28, 28)]) def test_magnus_model(image_shape): import torch as th diff --git a/utils/arg_parser.py b/utils/arg_parser.py index 31cfcba..240226c 100644 --- a/utils/arg_parser.py +++ b/utils/arg_parser.py @@ -33,7 +33,6 @@ def get_args(): help="Whether model should be saved or not.", ) - # Data/Model specific values parser.add_argument( "--modelname", @@ -83,7 +82,6 @@ def get_args(): "--macro_averaging", action="store_true", help="If the flag is included, the metrics will be calculated using macro averaging.", - ) # Training specific values diff --git a/utils/dataloaders/svhn.py b/utils/dataloaders/svhn.py index 14e2edb..e48b517 100644 --- a/utils/dataloaders/svhn.py +++ b/utils/dataloaders/svhn.py @@ -1,6 +1,5 @@ import os - import h5py import numpy as np from PIL import Image @@ -95,7 +94,6 @@ def __getitem__(self, index): img = Image.fromarray(h5f["images"][index]) if self.nr_channels == 1: - img = img.convert("L") if self.transforms is not None: img = self.transforms(img) diff --git a/utils/dataloaders/uspsh5_7_9.py b/utils/dataloaders/uspsh5_7_9.py index 98cbd03..4d63255 100644 --- a/utils/dataloaders/uspsh5_7_9.py +++ b/utils/dataloaders/uspsh5_7_9.py @@ -1,3 +1,5 @@ +from pathlib import Path + import h5py import numpy as np import torch @@ -30,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset): A transform function to apply to the images. """ - def __init__(self, h5_path, mode, transform=None): + def __init__(self, data_path, train=False, transform=None): super().__init__() """ Initializes the USPS dataset by loading images and labels from the given `.h5` file. @@ -43,12 +45,13 @@ def __init__(self, h5_path, mode, transform=None): transform : callable, optional, default=None A transform function to apply on images. """ - + self.filename = "usps.h5" + path = data_path if isinstance(data_path, Path) else Path(data_path) + self.filepath = path / self.filename self.transform = transform - self.mode = mode - self.h5_path = h5_path + self.mode = "train" if train else "test" # Load the dataset from the HDF5 file - with h5py.File(self.h5_path, "r") as hf: + with h5py.File(self.filepath, "r") as hf: images = hf[self.mode]["data"][:] labels = hf[self.mode]["target"][:] @@ -105,8 +108,8 @@ def main(): # Load the dataset dataset = USPSH5_Digit_7_9_Dataset( - h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5", - mode="train", + data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git", + train=False, transform=transform, ) data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True) diff --git a/utils/metrics/F1.py b/utils/metrics/F1.py index 70791c5..0833389 100644 --- a/utils/metrics/F1.py +++ b/utils/metrics/F1.py @@ -112,6 +112,7 @@ def _macro_F1(self): def forward(self, preds, target): """ + Update the True Positives, False Positives, and False Negatives, and compute the F1 score. This method computes the F1 score based on the predictions and true labels. It can compute either the