Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def main():

device = args.device


if "usps" in args.dataset.lower():

transform = transforms.Compose(
[
transforms.Resize((28, 28)),
Expand All @@ -47,7 +45,6 @@ def main():
data_dir=args.datafolder,
transform=transform,
val_size=args.val_size,

)

train_metrics = MetricWrapper(
Expand Down Expand Up @@ -129,7 +126,6 @@ def main():
project=args.run_name,
tags=[args.modelname, args.dataset],
config=args,

)
wandb.watch(model)

Expand Down
17 changes: 16 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from utils.models import ChristianModel, JanModel, MagnusModel
from utils.models import ChristianModel, JanModel, MagnusModel, SolveigModel


@pytest.mark.parametrize(
Expand Down Expand Up @@ -34,6 +34,21 @@ def test_jan_model(image_shape, num_classes):
assert y.shape == (n, num_classes), f"Shape: {y.shape}"


@pytest.mark.parametrize(
"image_shape, num_classes",
[((3, 16, 16), 3), ((3, 16, 16), 7)],
)
def test_solveig_model(image_shape, num_classes):
n, c, h, w = 5, *image_shape

model = SolveigModel(image_shape, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"


@pytest.mark.parametrize("image_shape", [(3, 28, 28)])
def test_magnus_model(image_shape):
import torch as th
Expand Down
2 changes: 0 additions & 2 deletions utils/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def get_args():
help="Whether model should be saved or not.",
)


# Data/Model specific values
parser.add_argument(
"--modelname",
Expand Down Expand Up @@ -83,7 +82,6 @@ def get_args():
"--macro_averaging",
action="store_true",
help="If the flag is included, the metrics will be calculated using macro averaging.",

)

# Training specific values
Expand Down
2 changes: 0 additions & 2 deletions utils/dataloaders/svhn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os


import h5py
import numpy as np
from PIL import Image
Expand Down Expand Up @@ -95,7 +94,6 @@ def __getitem__(self, index):
img = Image.fromarray(h5f["images"][index])

if self.nr_channels == 1:

img = img.convert("L")
if self.transforms is not None:
img = self.transforms(img)
Expand Down
17 changes: 10 additions & 7 deletions utils/dataloaders/uspsh5_7_9.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import h5py
import numpy as np
import torch
Expand Down Expand Up @@ -30,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
A transform function to apply to the images.
"""

def __init__(self, h5_path, mode, transform=None):
def __init__(self, data_path, train=False, transform=None):
super().__init__()
"""
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
Expand All @@ -43,12 +45,13 @@ def __init__(self, h5_path, mode, transform=None):
transform : callable, optional, default=None
A transform function to apply on images.
"""

self.filename = "usps.h5"
path = data_path if isinstance(data_path, Path) else Path(data_path)
self.filepath = path / self.filename
self.transform = transform
self.mode = mode
self.h5_path = h5_path
self.mode = "train" if train else "test"
# Load the dataset from the HDF5 file
with h5py.File(self.h5_path, "r") as hf:
with h5py.File(self.filepath, "r") as hf:
images = hf[self.mode]["data"][:]
labels = hf[self.mode]["target"][:]

Expand Down Expand Up @@ -105,8 +108,8 @@ def main():

# Load the dataset
dataset = USPSH5_Digit_7_9_Dataset(
h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5",
mode="train",
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
train=False,
transform=transform,
)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
Expand Down
1 change: 1 addition & 0 deletions utils/metrics/F1.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _macro_F1(self):

def forward(self, preds, target):
"""

Update the True Positives, False Positives, and False Negatives, and compute the F1 score.

This method computes the F1 score based on the predictions and true labels. It can compute either the
Expand Down
Loading