Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CollaborativeCoding/dataloaders/mnist_4_9.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class MNISTDataset4_9(Dataset):
Array of indices spcifying which samples to load. This determines the samples used by the dataloader.
train : bool, optional
Whether to train the model or not, by default False
transorm : callable, optional
Transform to apply to the images, by default None
nr_channels : int, optional
Number of channels in the images, by default 1
"""

def __init__(
Expand Down
21 changes: 12 additions & 9 deletions CollaborativeCoding/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@


class Precision(nn.Module):
"""Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives.
"""Metric module for precision. Can calculate both the micro- and macro-averaged precision.

Parameters
----------
num_classes : int
Number of classes in the dataset.
micro_averaging : bool
Wheter to compute the micro or macro precision (default False)
macro_averaging : bool
Performs macro-averaging if True, otherwise micro-averaging.
"""

def __init__(self, num_classes: int, macro_averaging: bool = False):
Expand All @@ -23,19 +23,15 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
self.y_pred = []

def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
"""Compute precision of model
"""Add true and predicted values to the class-global lists.

Parameters
----------
y_true : torch.tensor
True labels
y_pred : torch.tensor
logits : torch.tensor
Predicted labels

Returns
-------
torch.tensor
Precision score
"""
y_pred = logits.argmax(dim=-1)

Expand Down Expand Up @@ -100,6 +96,13 @@ def _macro_avg_precision(
return torch.nanmean(tp / (tp + fp))

def __returnmetric__(self):
"""Return the micro- or macro-averaged precision.

Returns
-------
torch.tensor
Micro- or macro-averaged precision
"""
if self.y_true == [] and self.y_pred == []:
return np.nan
elif self.y_true == [] or self.y_pred == []:
Expand Down
32 changes: 18 additions & 14 deletions CollaborativeCoding/models/johan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,32 @@
Multi-layer perceptron model for image classification.
"""

# class NeuronLayer(nn.Module):
# def __init__(self, in_features, out_features):
# super().__init__()

# self.fc = nn.Linear(in_features, out_features)
# self.relu = nn.ReLU()

# def forward(self, x):
# x = self.fc(x)
# x = self.relu(x)
# return x


class JohanModel(nn.Module):
"""Small MLP model for image classification.

Parameters
----------
in_features : int
Numer of input features.
image_shape : tuple(int, int, int)
Shape of the input image (C, H, W).
num_classes : int
Number of classes in the dataset.

Processing Images
-----------------
Input: (N, C, H, W)
N: Batch size
C: Number of input channels
H: Height of the input image
W: Width of the input image

Example:
Grayscale images (like MNIST) have C = 1.
Input shape: (N, 1, 28, 28)
fc1 Output shape: (N, 77)
fc2 Output shape: (N, 77)
fc3 Output shape: (N, 77)
fc4 Output shape: (N, num_classes)
"""

def __init__(self, image_shape, num_classes):
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
import torch as th
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import wandb
from CollaborativeCoding import (
MetricWrapper,
createfolders,
Expand Down