Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def main():
"--modelname",
type=str,
default="MagnusModel",
choices=["MagnusModel"],
choices=["MagnusModel", "ChristianModel"],
help="Model which to be trained on",
)
parser.add_argument(
Expand Down Expand Up @@ -108,7 +108,7 @@ def main():
parser.add_argument(
"--device",
type=str,
default="cuda",
default="cpu",
choices=["cuda", "cpu", "mps"],
help="Which device to run the training on.",
)
Expand All @@ -124,10 +124,6 @@ def main():

device = args.device

# load model
model = load_model(args.modelname)
model.to(device)

metrics = MetricWrapper(*args.metric)

# Dataset
Expand All @@ -143,6 +139,20 @@ def main():
data_path=args.datafolder,
)

# Find number of channels in the dataset
if len(traindata[0][0].shape) == 2:
channels = 1
else:
channels = traindata[0][0].shape[0]

# load model
model = load_model(
args.modelname,
in_channels=channels,
num_classes=traindata.num_classes,
)
model.to(device)

trainloader = DataLoader(traindata,
batch_size=args.batchsize,
shuffle=True,
Expand Down Expand Up @@ -186,7 +196,7 @@ def main():
model.eval()
with th.no_grad():
for x, y in valiloader:
x = x.to(device)
x, y = x.to(device), y.to(device)
pred = model.forward(x)
loss = criterion(y, pred)
evalloss.append(loss.item())
Expand Down
9 changes: 9 additions & 0 deletions utils/dataloaders/usps_0_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class USPSDataset0_6(Dataset):
A function/transform that takes in a sample and returns a transformed version.
idx : numpy.ndarray
Indices of samples with labels 0-6.
num_classes : int
Number of classes in the dataset

Methods
-------
Expand Down Expand Up @@ -71,6 +73,7 @@ def __init__(
super().__init__()
self.path = list(data_path.glob("*.h5"))[0]
self.transform = transform
self.num_classes = 7

if download:
raise NotImplementedError("Download functionality not implemented.")
Expand Down Expand Up @@ -103,6 +106,12 @@ def __getitem__(self, idx):

data = data.reshape(16, 16)

# one hot encode the target
target = np.eye(self.num_classes, dtype=np.float32)[target]

# Add channel dimension
data = np.expand_dims(data, axis=0)

if self.transform:
data = self.transform(data)

Expand Down
19 changes: 11 additions & 8 deletions utils/load_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import torch.nn as nn

from .models import MagnusModel
from .models import ChristianModel, MagnusModel


def load_model(modelname: str) -> nn.Module:
if modelname == "MagnusModel":
return MagnusModel()
else:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
)
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
match modelname.lower():
case "magnusmodel":
return MagnusModel(*args, **kwargs)
case "christianmodel":
return ChristianModel(*args, **kwargs)
case _:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
)
3 changes: 2 additions & 1 deletion utils/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ["EntropyPrediction"]
__all__ = ["EntropyPrediction", "Recall"]

from .EntropyPred import EntropyPrediction
from .recall import Recall
62 changes: 62 additions & 0 deletions utils/metrics/recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn


def one_hot_encode(y_true, num_classes):
"""One-hot encode the target tensor.

Args
----
y_true : torch.Tensor
Target tensor.
num_classes : int
Number of classes in the dataset.

Returns
-------
torch.Tensor
One-hot encoded tensor.
"""

y_onehot = torch.zeros(y_true.size(0), num_classes)
y_onehot.scatter_(1, y_true.unsqueeze(1), 1)
return y_onehot


class Recall(nn.Module):
def __init__(self, num_classes):
super().__init__()

self.num_classes = num_classes

def forward(self, y_true, y_pred):
true_onehot = one_hot_encode(y_true, self.num_classes)
pred_onehot = one_hot_encode(y_pred, self.num_classes)

true_positives = (true_onehot * pred_onehot).sum()

false_negatives = torch.sum(~pred_onehot[true_onehot.bool()].bool())

recall = true_positives / (true_positives + false_negatives)

return recall


def test_recall():
recall = Recall(7)

y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])

recall_score = recall(y_true, y_pred)

assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), f"Recall Score: {recall_score.item()}"


def test_one_hot_encode():
num_classes = 7

y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
y_onehot = one_hot_encode(y_true, num_classes)

assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}"
3 changes: 2 additions & 1 deletion utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ["MagnusModel"]
__all__ = ["MagnusModel", "ChristianModel"]

from .christian_model import ChristianModel
from .magnus_model import MagnusModel
92 changes: 92 additions & 0 deletions utils/models/christian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
import torch
import torch.nn as nn


class CNNBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()

self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
padding=1,
)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.maxpool(x)
x = self.relu(x)
return x


class ChristianModel(nn.Module):
"""Simple CNN model for image classification.

Args
----
in_channels : int
Number of input channels.
num_classes : int
Number of classes in the dataset.

Processing Images
-----------------
Input: (N, C, H, W)
N: Batch size
C: Number of input channels
H: Height of the input image
W: Width of the input image

Example:
For grayscale images, C = 1.

Input Image Shape: (5, 1, 16, 16)
CNN1 Output Shape: (5, 50, 8, 8)
CNN2 Output Shape: (5, 100, 4, 4)
FC Output Shape: (5, num_classes)
"""
def __init__(self, in_channels, num_classes):
super().__init__()

self.cnn1 = CNNBlock(in_channels, 50)
self.cnn2 = CNNBlock(50, 100)

self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
self.softmax = nn.Softmax(dim=1)

def forward(self, x):
x = self.cnn1(x)
x = self.cnn2(x)

x = x.view(x.size(0), -1)
x = self.fc1(x)
x = self.softmax(x)

return x


@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
def test_christian_model(in_channels, num_classes):
n, c, h, w = 5, in_channels, 16, 16

model = ChristianModel(c, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"


if __name__ == "__main__":

model = ChristianModel(3, 7)

x = torch.randn(3, 3, 16, 16)
y = model(x)

print(y)