From 40bb5c03612dbe60cebf0f00d9ebb8f51ef33146 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 13:36:08 +0100 Subject: [PATCH 1/5] Add ChristianModel: 2 layer CNN w/maxpooling --- main.py | 2 +- utils/load_model.py | 19 ++++--- utils/models/__init__.py | 3 +- utils/models/christian_model.py | 92 +++++++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 utils/models/christian_model.py diff --git a/main.py b/main.py index 6fab214..e3c7c45 100644 --- a/main.py +++ b/main.py @@ -66,7 +66,7 @@ def main(): "--modelname", type=str, default="MagnusModel", - choices=["MagnusModel"], + choices=["MagnusModel", "ChristianModel"], help="Model which to be trained on", ) parser.add_argument( diff --git a/utils/load_model.py b/utils/load_model.py index db242a0..7e55699 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,12 +1,15 @@ import torch.nn as nn -from .models import MagnusModel +from .models import ChristianModel, MagnusModel -def load_model(modelname: str) -> nn.Module: - if modelname == "MagnusModel": - return MagnusModel() - else: - raise ValueError( - f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling" - ) +def load_model(modelname: str, *args, **kwargs) -> nn.Module: + match modelname.lower(): + case "magnusmodel": + return MagnusModel(*args, **kwargs) + case "christianmodel": + return ChristianModel(*args, **kwargs) + case _: + raise ValueError( + f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling" + ) diff --git a/utils/models/__init__.py b/utils/models/__init__.py index ca05548..7cbae91 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["MagnusModel"] +__all__ = ["MagnusModel", "ChristianModel"] +from .christian_model import ChristianModel from .magnus_model import MagnusModel diff --git a/utils/models/christian_model.py b/utils/models/christian_model.py new file mode 100644 index 0000000..9bdd2da --- /dev/null +++ b/utils/models/christian_model.py @@ -0,0 +1,92 @@ +import pytest +import torch +import torch.nn as nn + + +class CNNBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.maxpool(x) + x = self.relu(x) + return x + + +class ChristianModel(nn.Module): + """Simple CNN model for image classification. + + Args + ---- + in_channels : int + Number of input channels. + num_classes : int + Number of classes in the dataset. + + Processing Images + ----------------- + Input: (N, C, H, W) + N: Batch size + C: Number of input channels + H: Height of the input image + W: Width of the input image + + Example: + For grayscale images, C = 1. + + Input Image Shape: (5, 1, 16, 16) + CNN1 Output Shape: (5, 50, 8, 8) + CNN2 Output Shape: (5, 100, 4, 4) + FC Output Shape: (5, num_classes) + """ + def __init__(self, in_channels, num_classes): + super().__init__() + + self.cnn1 = CNNBlock(in_channels, 50) + self.cnn2 = CNNBlock(50, 100) + + self.fc1 = nn.Linear(100 * 4 * 4, num_classes) + self.softmax = nn.Softmax(dim=1) + + def forward(self, x): + x = self.cnn1(x) + x = self.cnn2(x) + + x = x.view(x.size(0), -1) + x = self.fc1(x) + x = self.softmax(x) + + return x + + +@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)]) +def test_christian_model(in_channels, num_classes): + n, c, h, w = 5, in_channels, 16, 16 + + model = ChristianModel(c, num_classes) + + x = torch.randn(n, c, h, w) + y = model(x) + + assert y.shape == (n, num_classes), f"Shape: {y.shape}" + assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}" + + +if __name__ == "__main__": + + model = ChristianModel(3, 7) + + x = torch.randn(3, 3, 16, 16) + y = model(x) + + print(y) From fc787c2678bd01b64703d1d92f510f400109024b Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 14:08:29 +0100 Subject: [PATCH 2/5] finds number of channels based on dataset. Adds num_classes to dataset --- main.py | 20 +++++++++++++++----- utils/dataloaders/usps_0_6.py | 3 +++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 6fab214..e6330ec 100644 --- a/main.py +++ b/main.py @@ -108,7 +108,7 @@ def main(): parser.add_argument( "--device", type=str, - default="cuda", + default="cpu", choices=["cuda", "cpu", "mps"], help="Which device to run the training on.", ) @@ -124,10 +124,6 @@ def main(): device = args.device - # load model - model = load_model(args.modelname) - model.to(device) - metrics = MetricWrapper(*args.metric) # Dataset @@ -143,6 +139,20 @@ def main(): data_path=args.datafolder, ) + # Find number of channels in the dataset + if len(traindata[0][0].shape) == 2: + channels = 1 + else: + channels = traindata[0][0].shape[0] + + # load model + model = load_model( + args.modelname, + in_channels=channels, + num_classes=traindata.num_classes, + ) + model.to(device) + trainloader = DataLoader(traindata, batch_size=args.batchsize, shuffle=True, diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py index 8d5f496..7a2608f 100644 --- a/utils/dataloaders/usps_0_6.py +++ b/utils/dataloaders/usps_0_6.py @@ -36,6 +36,8 @@ class USPSDataset0_6(Dataset): A function/transform that takes in a sample and returns a transformed version. idx : numpy.ndarray Indices of samples with labels 0-6. + num_classes : int + Number of classes in the dataset Methods ------- @@ -71,6 +73,7 @@ def __init__( super().__init__() self.path = list(data_path.glob("*.h5"))[0] self.transform = transform + self.num_classes = 7 if download: raise NotImplementedError("Download functionality not implemented.") From cd1e0866228930deffd3293f7c68eb7132df0ee9 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 14:12:26 +0100 Subject: [PATCH 3/5] Add Recall metric --- utils/metrics/__init__.py | 3 +- utils/metrics/recall.py | 62 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 utils/metrics/recall.py diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py index 6e79fdc..3afeee5 100644 --- a/utils/metrics/__init__.py +++ b/utils/metrics/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["EntropyPrediction"] +__all__ = ["EntropyPrediction", "Recall"] from .EntropyPred import EntropyPrediction +from .recall import Recall diff --git a/utils/metrics/recall.py b/utils/metrics/recall.py new file mode 100644 index 0000000..4aaae43 --- /dev/null +++ b/utils/metrics/recall.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + + +def one_hot_encode(y_true, num_classes): + """One-hot encode the target tensor. + + Args + ---- + y_true : torch.Tensor + Target tensor. + num_classes : int + Number of classes in the dataset. + + Returns + ------- + torch.Tensor + One-hot encoded tensor. + """ + + y_onehot = torch.zeros(y_true.size(0), num_classes) + y_onehot.scatter_(1, y_true.unsqueeze(1), 1) + return y_onehot + + +class Recall(nn.Module): + def __init__(self, num_classes): + super().__init__() + + self.num_classes = num_classes + + def forward(self, y_true, y_pred): + true_onehot = one_hot_encode(y_true, self.num_classes) + pred_onehot = one_hot_encode(y_pred, self.num_classes) + + true_positives = (true_onehot * pred_onehot).sum() + + false_negatives = torch.sum(~pred_onehot[true_onehot.bool()].bool()) + + recall = true_positives / (true_positives + false_negatives) + + return recall + + +def test_recall(): + recall = Recall(7) + + y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6]) + y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6]) + + recall_score = recall(y_true, y_pred) + + assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), f"Recall Score: {recall_score.item()}" + + +def test_one_hot_encode(): + num_classes = 7 + + y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6]) + y_onehot = one_hot_encode(y_true, num_classes) + + assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}" From efe689465b203c21738e50d318929f9e35e8aec3 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 14:38:01 +0100 Subject: [PATCH 4/5] Fix bug where labels werent put on the device --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 1b6bb96..fe563f5 100644 --- a/main.py +++ b/main.py @@ -196,7 +196,7 @@ def main(): model.eval() with th.no_grad(): for x, y in valiloader: - x = x.to(device) + x, y = x.to(device), y.to(device) pred = model.forward(x) loss = criterion(y, pred) evalloss.append(loss.item()) From 68b5616279997b4ed33f16a8a092e2f22a275e8a Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 14:38:32 +0100 Subject: [PATCH 5/5] Onehot encode labels in dataset --- utils/dataloaders/usps_0_6.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py index 7a2608f..4e68191 100644 --- a/utils/dataloaders/usps_0_6.py +++ b/utils/dataloaders/usps_0_6.py @@ -106,6 +106,12 @@ def __getitem__(self, idx): data = data.reshape(16, 16) + # one hot encode the target + target = np.eye(self.num_classes, dtype=np.float32)[target] + + # Add channel dimension + data = np.expand_dims(data, axis=0) + if self.transform: data = self.transform(data)