Skip to content
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__pycache__/
.ipynb_checkpoints/
Data/*
data/*
Results/*
Experiments/*
_build/
Expand All @@ -14,9 +15,7 @@ doc/autoapi

#Magnus specific
job*
env2/*
ruffian.sh
localtest.sh
local*

# Johanthings
formatting.x
Expand Down
2 changes: 1 addition & 1 deletion CollaborativeCoding/dataloaders/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _get_labels(path: Path) -> np.ndarray:

def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
def download_svhn(path, train: bool = True):
SVHN()
SVHN(path, split="train" if train else "test", download=True)

parent_path = data_dir / "SVHN"

Expand Down
4 changes: 3 additions & 1 deletion CollaborativeCoding/dataloaders/mnist_0_3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from .datasources import MNIST_SOURCE
Expand Down Expand Up @@ -87,7 +88,8 @@ def __getitem__(self, index):
28, 28
) # Read image data

image = np.expand_dims(image, axis=0) # Add channel dimension
# image = np.expand_dims(image, axis=0) # Add channel dimension
image = Image.fromarray(image.astype(np.uint8))

if self.transform:
image = self.transform(image)
Expand Down
10 changes: 7 additions & 3 deletions CollaborativeCoding/dataloaders/mnist_4_9.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from .datasources import MNIST_SOURCE
Expand Down Expand Up @@ -28,11 +29,13 @@ def __init__(
transform=None,
nr_channels: int = 1,
):
super.__init__()
super().__init__()
self.data_path = data_path
self.mnist_path = self.data_path / "MNIST"
self.samples = sample_ids
self.train = train
self.transform = transform
self.num_classes = 6

self.images_path = self.mnist_path / (
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
Expand All @@ -46,7 +49,7 @@ def __len__(self):

def __getitem__(self, idx):
with open(self.labels_path, "rb") as labelfile:
label_pos = 8 + self.sample[idx]
label_pos = 8 + self.samples[idx]
labelfile.seek(label_pos)
label = int.from_bytes(labelfile.read(1), byteorder="big")

Expand All @@ -57,7 +60,8 @@ def __getitem__(self, idx):
28, 28
)

image = np.expand_dims(image, axis=0) # Channel
# image = np.expand_dims(image, axis=0) # Channel
image = Image.fromarray(image.astype(np.uint8))

if self.transform:
image = self.transform(image)
Expand Down
13 changes: 11 additions & 2 deletions CollaborativeCoding/dataloaders/usps_0_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
sample_ids: list,
train: bool = False,
transform=None,
nr_channels=1,
):
super().__init__()

Expand All @@ -91,6 +92,7 @@ def __init__(
self.transform = transform
self.mode = "train" if train else "test"
self.sample_ids = sample_ids
self.nr_channels = nr_channels

def __len__(self):
return len(self.sample_ids)
Expand All @@ -100,11 +102,18 @@ def __getitem__(self, id):

with h5.File(self.filepath, "r") as f:
data = f[self.mode]["data"][index].astype(np.uint8)
label = f[self.mode]["target"][index]
label = int(f[self.mode]["target"][index])

data = Image.fromarray(data, mode="L")
if self.nr_channels == 1:
data = Image.fromarray(data, mode="L")
elif self.nr_channels == 3:
data = Image.fromarray(data, mode="RGB")
else:
raise ValueError("Invalid number of channels")

if self.transform:
data = self.transform(data)

# label = torch.tensor(label).long()

return data, label
7 changes: 5 additions & 2 deletions CollaborativeCoding/dataloaders/uspsh5_7_9.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
A transform function to apply to the images.
"""

def __init__(self, data_path, sample_ids, train=False, transform=None, nr_channels=1):
def __init__(
self, data_path, sample_ids, train=False, transform=None, nr_channels=1
):
super().__init__()
"""
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
Expand Down Expand Up @@ -112,7 +114,8 @@ def main():
indices = np.array([7, 8, 9])
# Load the dataset
dataset = USPSH5_Digit_7_9_Dataset(
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git", sample_ids=indices,
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
sample_ids=indices,
train=False,
transform=transform,
)
Expand Down
8 changes: 4 additions & 4 deletions CollaborativeCoding/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
case "svhn":
dataset = SVHNDataset
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
labels = np.arange(10)
labels = np.unique(train_labels)
case "mnist_4-9":
dataset = MNISTDataset4_9
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
Expand All @@ -89,23 +89,23 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
sample_ids=train_samples,
train=True,
transform=transform,
nr_channels=kwargs.get("nr_channels"),
nr_channels=kwargs.get("nr_channels", 1),
)

val = dataset(
data_path=data_dir,
sample_ids=val_samples,
train=True,
transform=transform,
nr_channels=kwargs.get("nr_channels"),
nr_channels=kwargs.get("nr_channels", 1),
)

test = dataset(
data_path=data_dir,
sample_ids=test_samples,
train=False,
transform=transform,
nr_channels=kwargs.get("nr_channels"),
nr_channels=kwargs.get("nr_channels", 1),
)

return train, val, test
25 changes: 25 additions & 0 deletions CollaborativeCoding/load_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,28 @@ def getmetrics(self, str_prefix: str = None):
def resetmetric(self):
for key in self.metrics:
self.metrics[key].__reset__()


if __name__ == "__main__":
import torch as th

metrics = ["entropy", "f1", "recall", "precision", "accuracy"]

class_sizes = [3, 6, 10]
for class_size in class_sizes:
y_true = th.rand((5, class_size)).argmax(dim=1)
y_pred = th.rand((5, class_size))

metricwrapper = MetricWrapper(
metric,
num_classes=class_size,
macro_averaging=True if class_size % 2 == 0 else False,
)

metricwrapper(y_true, y_pred)
metric = metricwrapper.getmetrics()
assert metric is not None

metricwrapper.resetmetric()
metric2 = metricwrapper.getmetrics()
assert metric != metric2
8 changes: 5 additions & 3 deletions CollaborativeCoding/metrics/F1.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,13 @@ def __returnmetric__(self):
else:
self.y_true = torch.cat(self.y_true)
self.y_pred = torch.cat(self.y_pred)
return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred)
return (
self._micro_F1(self.y_true, self.y_pred)
if not self.macro_averaging
else self._macro_F1(self.y_true, self.y_pred)
)

def __reset__(self):
self.y_true = []
self.y_pred = []
return None


42 changes: 33 additions & 9 deletions CollaborativeCoding/metrics/recall.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn as nn

Expand Down Expand Up @@ -57,26 +58,49 @@ def __init__(self, num_classes, macro_averaging=False):
self.num_classes = num_classes
self.macro_averaging = macro_averaging

self.__y_true = []
self.__y_pred = []

def forward(self, true, logits):
pred = logits.argmax(dim=-1)
y_true = one_hot_encode(true, self.num_classes)
y_pred = one_hot_encode(pred, self.num_classes)

self.__y_true.append(y_true)
self.__y_pred.append(y_pred)

def compute(self, y_true, y_pred):
if self.macro_averaging:
recall = 0
for i in range(self.num_classes):
tp = (y_true[:, i] * y_pred[:, i]).sum()
fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
recall += tp / (tp + fn)
recall /= self.num_classes
else:
recall = self.__compute(y_true, y_pred)
return self.__compute_macro_averaging(y_true, y_pred)

return self.__compute_micro_averaging(y_true, y_pred)

def __compute_macro_averaging(self, y_true, y_pred):
recall = 0
for i in range(self.num_classes):
tp = (y_true[:, i] * y_pred[:, i]).sum()
fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
recall += tp / (tp + fn)
recall /= self.num_classes

return recall

def __compute(self, y_true, y_pred):
def __compute_micro_averaging(self, y_true, y_pred):
true_positives = (y_true * y_pred).sum()
false_negatives = torch.sum(~y_pred[y_true.bool()].bool())

recall = true_positives / (true_positives + false_negatives)
return recall

def __returnmetric__(self):
if len(self.__y_true) == 0 and len(self.__y_pred) == 0:
return np.nan

y_true = torch.cat(self.__y_true, dim=0)
y_pred = torch.cat(self.__y_pred, dim=0)

return self.compute(y_true, y_pred)

def __reset__(self):
self.__y_true = []
self.__y_pred = []
36 changes: 18 additions & 18 deletions CollaborativeCoding/models/solveig_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@

def find_fc_input_shape(image_shape, model):
"""
Find the shape of the input to the fully connected layer after passing through the convolutional layers.

Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)

Args
----
image_shape : tuple(int, int, int)
Shape of the input image (C, H, W), where C is the number of channels,
H is the height, and W is the width of the image.
model : nn.Module
The CNN model containing the convolutional layers, whose output size is used to
determine the number of input features for the fully connected layer.

Returns
-------
int
The number of elements in the input to the fully connected layer.
"""
Find the shape of the input to the fully connected layer after passing through the convolutional layers.

Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)

Args
----
image_shape : tuple(int, int, int)
Shape of the input image (C, H, W), where C is the number of channels,
H is the height, and W is the width of the image.
model : nn.Module
The CNN model containing the convolutional layers, whose output size is used to
determine the number of input features for the fully connected layer.

Returns
-------
int
The number of elements in the input to the fully connected layer.
"""

dummy_img = torch.randn(1, *image_shape)
with torch.no_grad():
Expand Down
3 changes: 1 addition & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
import torch as th
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import wandb
from CollaborativeCoding import (
MetricWrapper,
createfolders,
Expand All @@ -17,7 +17,6 @@
# from wandb_api import WANDB_API



def main():
"""

Expand Down
15 changes: 7 additions & 8 deletions tests/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import numpy as np
import pytest
import torch
from PIL import Image
from torchvision import transforms

from CollaborativeCoding.dataloaders import (
MNISTDataset0_3,
MNISTDataset4_9,
SVHNDataset,
USPSDataset0_6,
USPSH5_Digit_7_9_Dataset,
)
Expand All @@ -20,11 +21,13 @@
("usps_0-6", USPSDataset0_6),
("usps_7-9", USPSH5_Digit_7_9_Dataset),
("mnist_0-3", MNISTDataset0_3),
# TODO: Add more datasets here
("svhn", SVHNDataset),
("mnist_4-9", MNISTDataset4_9),
],
)
def test_load_data(data_name, expected):
dataset = load_data(
print(data_name)
dataset, _, _ = load_data(
data_name,
data_dir=Path("data"),
transform=transforms.ToTensor(),
Expand All @@ -33,8 +36,4 @@ def test_load_data(data_name, expected):
assert len(dataset) > 0
assert isinstance(dataset[0], tuple)
assert isinstance(dataset[0][0], torch.Tensor)
assert isinstance(
dataset[0][1], (int, torch.Tensor, np.ndarray)
) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency.


assert isinstance(dataset[0][1], int)
Loading
Loading