Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ doc/autoapi

*.DS_Store

# Jan
job.yaml
sync.sh

#Magnus specific
job*
Expand Down
8 changes: 6 additions & 2 deletions CollaborativeCoding/dataloaders/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

class Downloader:
"""
Class to download and load the USPS dataset.
Class used to verify availability and potentially download implemented datasets.

Methods
-------
mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Download the MNIST dataset and save it as an HDF5 file to `data_dir`.
Checks the availability of the MNIST dataset. If it is not present, downloads it into the MNIST folder in `data_dir`.
svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Download the SVHN dataset and save it as an HDF5 file to `data_dir`.
usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Expand All @@ -42,6 +42,10 @@ class Downloader:
"""

def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
"""
Check the availability of the MNIST dataset. If it is not present, download it into the MNIST folder in `data_dir`.
"""

def _chech_is_downloaded(path: Path) -> bool:
path = path / "MNIST"
if path.exists():
Expand Down
12 changes: 6 additions & 6 deletions CollaborativeCoding/dataloaders/mnist_0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@
class MNISTDataset0_3(Dataset):
"""
A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
Parameters

Args
----
data_path : Path
The root directory where the MNIST data is stored.
The root directory where the MNIST folder with data is stored.
sample_ids : list
A list of indices specifying which samples to load.
train : bool, optional
If True, load training data, otherwise load test data. Default is False.
transform : callable, optional
A function/transform to apply to the images. Default is None.

Attributes
----------
data_path : Path
The root directory where the MNIST data is stored.
mnist_path : Path
The directory where the MNIST dataset is located within the root directory.
idx : list
Expand All @@ -40,6 +40,7 @@ class MNISTDataset0_3(Dataset):
The path to the label file (train or test) based on the `train` flag.
length : int
The number of samples in the dataset.

Methods
-------
__len__()
Expand All @@ -58,8 +59,7 @@ def __init__(
):
super().__init__()

self.data_path = data_path
self.mnist_path = self.data_path / "MNIST"
self.mnist_path = data_path / "MNIST"
self.idx = sample_ids
self.train = train
self.transform = transform
Expand Down
2 changes: 1 addition & 1 deletion CollaborativeCoding/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def filter_labels(samples: list, wanted_labels: list) -> list:

def load_data(dataset: str, *args, **kwargs) -> tuple:
"""
load the dataset based on the dataset name.
Load the dataset based on the dataset name.

Args
----
Expand Down
10 changes: 8 additions & 2 deletions CollaborativeCoding/load_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,32 @@ class MetricWrapper(nn.Module):
This class allows you to compute several metrics simultaneously on given
true and predicted labels. It supports a variety of common metrics and
provides methods to accumulate results and reset the state.

Args
----
num_classes : int
The number of classes in the classification task.
metrics : list[str]
A list of metric names to be evaluated.
macro_averaging : bool
Whether to compute macro-averaged metrics for multi-class classification.

Attributes
----------
metrics : dict
A dictionary mapping metric names to their corresponding functions.
num_classes : int
The number of classes for the classification task.

Methods
-------
__call__(y_true, y_pred)
Computes the specified metrics on the provided true and predicted labels.
Passes the true and predicted labels to the metric functions.
getmetrics(str_prefix: str = None)
Retrieves the computed metrics, optionally prefixed with a string.
Retrieves the dictionary of computed metrics; optionally, all keys can be prefixed with a string.
resetmetric()
Resets the state of all metric computations.

Examples
--------
>>> from CollaborativeCoding import MetricWrapperProposed
Expand Down
77 changes: 54 additions & 23 deletions CollaborativeCoding/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,45 @@


class Accuracy(nn.Module):
"""
Computes the accuracy of a model's predictions.

Args
----
num_classes : int
The number of classes in the classification task.
macro_averaging : bool, optional
If True, computes macro-average accuracy. Otherwise, computes micro-average accuracy. Default is False.


Methods
-------
forward(y_true, y_pred)
Stores the true and predicted labels. Typically called for each batch during the forward pass of a model.
_macro_acc()
Computes the macro-average accuracy.
_micro_acc()
Computes the micro-average accuracy.
__returnmetric__()
Returns the computed accuracy based on the averaging method for all stored predictions.
__reset__()
Resets the stored true and predicted labels.

Examples
--------
>>> y_true = torch.tensor([0, 1, 2, 3, 3])
>>> y_pred = torch.tensor([0, 1, 2, 3, 0])
>>> accuracy = Accuracy(num_classes=4)
>>> accuracy(y_true, y_pred)
>>> accuracy.__returnmetric__()
0.8
>>> accuracy.__reset__()
>>> accuracy.macro_averaging = True
>>> accuracy(y_true, y_pred)
>>> accuracy.__returnmetric__()
0.875
"""

def __init__(self, num_classes, macro_averaging=False):
super().__init__()
self.num_classes = num_classes
Expand All @@ -13,19 +52,14 @@ def __init__(self, num_classes, macro_averaging=False):

def forward(self, y_true, y_pred):
"""
Compute the accuracy of the model.
Store the true and predicted labels.

Parameters
----------
y_true : torch.Tensor
True labels.
y_pred : torch.Tensor
Predicted labels.

Returns
-------
float
Accuracy score.
Predicted labels. Either a 1D tensor of shape (batch_size,) or a 2D tensor of shape (batch_size, num_classes).
"""
if y_pred.dim() > 1:
y_pred = y_pred.argmax(dim=1)
Expand All @@ -34,14 +68,7 @@ def forward(self, y_true, y_pred):

def _macro_acc(self):
"""
Compute the macro-average accuracy.

Parameters
----------
y_true : torch.Tensor
True labels.
y_pred : torch.Tensor
Predicted labels.
Compute the macro-average accuracy on the stored predictions.

Returns
-------
Expand All @@ -63,14 +90,7 @@ def _macro_acc(self):

def _micro_acc(self):
"""
Compute the micro-average accuracy.

Parameters
----------
y_true : torch.Tensor
True labels.
y_pred : torch.Tensor
Predicted labels.
Compute the micro-average accuracy on the stored predictions.

Returns
-------
Expand All @@ -80,6 +100,14 @@ def _micro_acc(self):
return (self.y_true == self.y_pred).float().mean().item()

def __returnmetric__(self):
"""
Return the computed accuracy based on the averaging method for all stored predictions.

Returns
-------
float
Computed accuracy score.
"""
if self.y_true == [] or self.y_pred == []:
return np.nan
if isinstance(self.y_true, list):
Expand All @@ -92,6 +120,9 @@ def __returnmetric__(self):
return self._micro_acc() if not self.macro_averaging else self._macro_acc()

def __reset__(self):
"""
Reset the stored true and predicted labels.
"""
self.y_true = []
self.y_pred = []
return None
6 changes: 3 additions & 3 deletions CollaborativeCoding/models/jan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@


class JanModel(nn.Module):
"""A simple MLP network model for image classification tasks.
"""A simple MLP network model for image classification tasks. Two hidden layers with 100 neurons.

Args
----
in_channels : int
Number of input channels.
image_shape : tuple(int, int, int)
Shape of the input image (C, H, W).
num_classes : int
Number of classes in the dataset.

Expand Down
6 changes: 5 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
import torch as th
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import wandb
from CollaborativeCoding import (
MetricWrapper,
createfolders,
Expand Down Expand Up @@ -132,6 +132,7 @@ def main():
wandb.init(
entity="ColabCode",
project=args.run_name,
dir=args.resultfolder,
tags=[args.modelname, args.dataset],
config=args,
)
Expand Down Expand Up @@ -178,6 +179,9 @@ def main():
train_metrics.resetmetric()
val_metrics.resetmetric()

if args.savemodel:
th.save(model, args.modelfolder / f"{args.modelname}_run:{args.run_name}.pth")

testloss = []
model.eval()
with th.no_grad():
Expand Down
Loading