From 0043e11b80e8f6937d5ef9d0c8f34346c81e698f Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Mon, 3 Feb 2025 17:59:05 +0100 Subject: [PATCH 1/9] Wrote the dataset, linked it to main, not tested --- environment.yml | 1 + main.py | 2 +- utils/dataloaders/__init__.py | 1 + utils/dataloaders/mnist_0_3.py | 130 +++++++++++++++++++++++++++++++++ utils/load_data.py | 4 +- 5 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 utils/dataloaders/mnist_0_3.py diff --git a/environment.yml b/environment.yml index 214e5fc..65a0cec 100644 --- a/environment.yml +++ b/environment.yml @@ -18,6 +18,7 @@ dependencies: - pytest - ruff - scalene + - pickle - pip: - torch - torchvision diff --git a/main.py b/main.py index fe563f5..956a5cc 100644 --- a/main.py +++ b/main.py @@ -73,7 +73,7 @@ def main(): "--dataset", type=str, default="svhn", - choices=["svhn", "usps_0-6"], + choices=["svhn", "usps_0-6", "mnist_0-3"], help="Which dataset to train the model on.", ) diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py index df404f7..b95c192 100644 --- a/utils/dataloaders/__init__.py +++ b/utils/dataloaders/__init__.py @@ -1,3 +1,4 @@ __all__ = ["USPSDataset0_6"] from .usps_0_6 import USPSDataset0_6 +from .mnist_0_3 import MNISTDataset0_3 \ No newline at end of file diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py new file mode 100644 index 0000000..63da48d --- /dev/null +++ b/utils/dataloaders/mnist_0_3.py @@ -0,0 +1,130 @@ +from pathlib import Path + +from torch.utils.data import Dataset +import numpy as np +import urllib.request +import gzip +import os + + + +class MNISTDataset0_3(Dataset): + """ + A custom dataset class for loading MNIST data, specifically for digits 0 through 3. + Parameters + ---------- + data_path : Path + The root directory where the MNIST data is stored or will be downloaded. + train : bool, optional + If True, loads the training data, otherwise loads the test data. Default is False. + transform : callable, optional + A function/transform that takes in an image and returns a transformed version. Default is None. + download : bool, optional + If True, downloads the dataset if it is not already present in the specified data_path. Default is False. + Attributes + ---------- + data_path : Path + The root directory where the MNIST data is stored. + mnist_path : Path + The directory where the MNIST data files are stored. + train : bool + Indicates whether the training data or test data is being used. + transform : callable + A function/transform that takes in an image and returns a transformed version. + download : bool + Indicates whether the dataset should be downloaded if not present. + images_path : Path + The path to the image file (training or test) based on the `train` flag. + labels_path : Path + The path to the label file (training or test) based on the `train` flag. + idx : numpy.ndarray + Indices of the labels that are less than 4. + length : int + The number of samples in the dataset. + Methods + ------- + _parse_labels(train) + Parses the labels from the label file. + _chech_is_downloaded() + Checks if the dataset is already downloaded. + _download_data() + Downloads and extracts the MNIST dataset. + __len__() + Returns the number of samples in the dataset. + __getitem__(index) + Returns the image and label at the specified index. + """ + def __init__(self, data_path: Path, train: bool = False, transform=None, download: bool = False,): + super().__init__() + + self.data_path = data_path + self.mnist_path = self.data_path / "MNIST" + self.train = train + self.transform = transform + self.download = download + + if self.download and not self._chech_is_downloaded(): + self._download_data() + + self.images_path = self.mnist_path / ("train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte") + self.labels_path = self.mnist_path / ("train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte") + + labels = self._parse_labels(train=self.train) + + self.idx = np.where(labels < 4)[0] + + self.length = len(self.idx) + + + def _parse_labels(self, train): + with open(self.labels_path, "rb") as f: + data = np.frombuffer(f.read(), dtype=np.uint8, offset=8) + return data + + def _chech_is_downloaded(self): + if self.mnist_path.exists(): + required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"] + if all([(self.mnist_path / file).exists() for file in required_files]): + print("Data already downloaded.") + return True + else: + return False + else: + self.mnist_path.mkdir(parents=True, exist_ok=True) + return False + + + def _download_data(self): + urls = { + "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", + "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", + "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", + "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", + } + + for name, url in urls.items(): + file_path = os.path.join(self.mnist_path, url.split("/")[-1]) + if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading + urllib.request.urlretrieve(url, file_path) + with gzip.open(file_path, 'rb') as f_in: + with open(file_path.replace(".gz", ""), 'wb') as f_out: + f_out.write(f_in.read()) + os.remove(file_path) # Remove compressed file + + + def __len__(self): + return self.length + + def __getitem__(self, index): + with open(self.labels_path, "rb") as f: + f.seek(8 + index) # Jump to the label position + label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label + + with open(self.images_path, "rb") as f: + f.seek(16 + index * 28) # Jump to image position + image = np.frombuffer(f.read(28), dtype=np.uint8).reshape(28, 28) # Read image data + + if self.transform: + image = self.transform(image) + + return image, label \ No newline at end of file diff --git a/utils/load_data.py b/utils/load_data.py index ac1bcfd..e71f27e 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,11 +1,13 @@ from torch.utils.data import Dataset -from .dataloaders import USPSDataset0_6 +from .dataloaders import USPSDataset0_6, MNISTDataset0_3 def load_data(dataset: str, *args, **kwargs) -> Dataset: match dataset.lower(): case "usps_0-6": return USPSDataset0_6(*args, **kwargs) + case "mnist_0-3": + return MNISTDataset0_3(*args, **kwargs) case _: raise ValueError(f"Dataset: {dataset} not implemented.") From 5af2c6190aae718874046c29253b8d354702daea Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Mon, 3 Feb 2025 18:11:33 +0100 Subject: [PATCH 2/9] fixed minor bugs discovered during testing the dataloader --- utils/dataloaders/mnist_0_3.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py index 63da48d..bd3c92f 100644 --- a/utils/dataloaders/mnist_0_3.py +++ b/utils/dataloaders/mnist_0_3.py @@ -62,7 +62,10 @@ def __init__(self, data_path: Path, train: bool = False, transform=None, downloa self.train = train self.transform = transform self.download = download + self.num_classes = 4 + if not self.download and not self._chech_is_downloaded(): + raise ValueError("Data not found. Set --download-data=True to download the data.") if self.download and not self._chech_is_downloaded(): self._download_data() @@ -121,8 +124,8 @@ def __getitem__(self, index): label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label with open(self.images_path, "rb") as f: - f.seek(16 + index * 28) # Jump to image position - image = np.frombuffer(f.read(28), dtype=np.uint8).reshape(28, 28) # Read image data + f.seek(16 + index * 28*28) # Jump to image position + image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data if self.transform: image = self.transform(image) From ecb6db44660eac853f1376dd22e64b1f995bc0a9 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Tue, 4 Feb 2025 10:56:56 +0100 Subject: [PATCH 3/9] Changed the input of load_model to enable models to process all datasets --- main.py | 11 ++-- utils/dataloaders/__init__.py | 2 +- utils/dataloaders/mnist_0_3.py | 6 ++- utils/load_model.py | 4 +- utils/models/__init__.py | 3 +- utils/models/jan_model.py | 96 ++++++++++++++++++++++++++++++++++ 6 files changed, 111 insertions(+), 11 deletions(-) create mode 100644 utils/models/jan_model.py diff --git a/main.py b/main.py index 956a5cc..b9a340d 100644 --- a/main.py +++ b/main.py @@ -139,16 +139,15 @@ def main(): data_path=args.datafolder, ) - # Find number of channels in the dataset - if len(traindata[0][0].shape) == 2: - channels = 1 - else: - channels = traindata[0][0].shape[0] + # Find the shape of the data, if is 2D, add a channel dimension + data_shape = traindata[0][0].shape + if len(data_shape) == 2: + data_shape = (1, *data_shape) # load model model = load_model( args.modelname, - in_channels=channels, + image_shape=data_shape, num_classes=traindata.num_classes, ) model.to(device) diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py index b95c192..842431a 100644 --- a/utils/dataloaders/__init__.py +++ b/utils/dataloaders/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["USPSDataset0_6"] +__all__ = ["USPSDataset0_6","MNISTDataset0_3"] from .usps_0_6 import USPSDataset0_6 from .mnist_0_3 import MNISTDataset0_3 \ No newline at end of file diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py index bd3c92f..df7214c 100644 --- a/utils/dataloaders/mnist_0_3.py +++ b/utils/dataloaders/mnist_0_3.py @@ -88,7 +88,7 @@ def _chech_is_downloaded(self): if self.mnist_path.exists(): required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"] if all([(self.mnist_path / file).exists() for file in required_files]): - print("Data already downloaded.") + print("MNIST Dataset already downloaded.") return True else: return False @@ -126,7 +126,9 @@ def __getitem__(self, index): with open(self.images_path, "rb") as f: f.seek(16 + index * 28*28) # Jump to image position image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data - + + image = np.expand_dims(image, axis=0) # Add channel dimension + if self.transform: image = self.transform(image) diff --git a/utils/load_model.py b/utils/load_model.py index 7e55699..601e3c2 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,6 +1,6 @@ import torch.nn as nn -from .models import ChristianModel, MagnusModel +from .models import ChristianModel, MagnusModel, JanModel def load_model(modelname: str, *args, **kwargs) -> nn.Module: @@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module: return MagnusModel(*args, **kwargs) case "christianmodel": return ChristianModel(*args, **kwargs) + case "janmodel": + return JanModel(*args, **kwargs) case _: raise ValueError( f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling" diff --git a/utils/models/__init__.py b/utils/models/__init__.py index 7cbae91..64706b0 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,4 +1,5 @@ -__all__ = ["MagnusModel", "ChristianModel"] +__all__ = ["MagnusModel", "ChristianModel", "JanModel"] from .christian_model import ChristianModel from .magnus_model import MagnusModel +from .jan_model import JanModel diff --git a/utils/models/jan_model.py b/utils/models/jan_model.py new file mode 100644 index 0000000..8f7ab4f --- /dev/null +++ b/utils/models/jan_model.py @@ -0,0 +1,96 @@ +import torch +""" +A simple neural network model for classification tasks. +Parameters +---------- +in_channels : int + Number of input channels. +num_classes : int + Number of output classes. +Attributes +---------- +in_channels : int + Number of input channels. +num_classes : int + Number of output classes. +fc1 : nn.Linear + First fully connected layer. +fc2 : nn.Linear + Second fully connected layer. +out : nn.Linear + Output fully connected layer. +leaky_relu : nn.LeakyReLU + Leaky ReLU activation function. +flatten : nn.Flatten + Flatten layer to reshape input tensor. +Methods +------- +forward(x) + Defines the forward pass of the model. +""" +import torch.nn as nn + + + +class JanModel(nn.Module): + """A simple MLP network model for image classification tasks. + + Args + ---- + in_channels : int + Number of input channels. + num_classes : int + Number of classes in the dataset. + + Processing Images + ----------------- + Input: (N, C, H, W) + N: Batch size + C: Number of input channels + H: Height of the input image + W: Width of the input image + + Example: + For grayscale images, C = 1. + + Input Image Shape: (5, 1, 28, 28) + flatten Output Shape: (5, 784) + fc1 Output Shape: (5, 100) + fc2 Output Shape: (5, 100) + out Output Shape: (5, num_classes) + """ + def __init__(self, image_shape, num_classes): + super().__init__() + + self.in_channels = image_shape[0] + self.height = image_shape[1] + self.width = image_shape[2] + self.num_classes = num_classes + + self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100) + + self.fc2 = nn.Linear(100, 100) + + self.out = nn.Linear(100, num_classes) + + self.leaky_relu = nn.LeakyReLU() + + self.flatten = nn.Flatten() + + def forward(self, x): + x = self.flatten(x) + x = self.fc1(x) + x = self.leaky_relu(x) + x = self.fc2(x) + x = self.leaky_relu(x) + x = self.out(x) + return x + + +if __name__ == "__main__": + model = JanModel(2, 4) + + x = torch.randn(3, 2, 28, 28) + y = model(x) + + print(y) From d1c3839faa2f9425a919296e05b0af8ccb390dd6 Mon Sep 17 00:00:00 2001 From: Jan Zavadil <79144013+hzavadil98@users.noreply.github.com> Date: Tue, 4 Feb 2025 10:58:53 +0100 Subject: [PATCH 4/9] -pickle Co-authored-by: Christian Salomonsen <55956280+salomaestro@users.noreply.github.com> --- environment.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/environment.yml b/environment.yml index 65a0cec..214e5fc 100644 --- a/environment.yml +++ b/environment.yml @@ -18,7 +18,6 @@ dependencies: - pytest - ruff - scalene - - pickle - pip: - torch - torchvision From 1285d36eeb92ca97db578e40e821af0b00d33acd Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Tue, 4 Feb 2025 11:11:11 +0100 Subject: [PATCH 5/9] ran ruff and isort --- utils/dataloaders/__init__.py | 4 +- utils/dataloaders/mnist_0_3.py | 92 ++++++++++++++++++++-------------- utils/load_data.py | 2 +- utils/load_model.py | 2 +- utils/models/__init__.py | 2 +- utils/models/jan_model.py | 19 +++---- 6 files changed, 69 insertions(+), 52 deletions(-) diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py index 842431a..1eca302 100644 --- a/utils/dataloaders/__init__.py +++ b/utils/dataloaders/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["USPSDataset0_6","MNISTDataset0_3"] +__all__ = ["USPSDataset0_6", "MNISTDataset0_3"] +from .mnist_0_3 import MNISTDataset0_3 from .usps_0_6 import USPSDataset0_6 -from .mnist_0_3 import MNISTDataset0_3 \ No newline at end of file diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py index df7214c..5e5a935 100644 --- a/utils/dataloaders/mnist_0_3.py +++ b/utils/dataloaders/mnist_0_3.py @@ -1,11 +1,10 @@ -from pathlib import Path - -from torch.utils.data import Dataset -import numpy as np -import urllib.request import gzip import os +import urllib.request +from pathlib import Path +import numpy as np +from torch.utils.data import Dataset class MNISTDataset0_3(Dataset): @@ -54,39 +53,56 @@ class MNISTDataset0_3(Dataset): __getitem__(index) Returns the image and label at the specified index. """ - def __init__(self, data_path: Path, train: bool = False, transform=None, download: bool = False,): + + def __init__( + self, + data_path: Path, + train: bool = False, + transform=None, + download: bool = False, + ): super().__init__() - + self.data_path = data_path self.mnist_path = self.data_path / "MNIST" self.train = train self.transform = transform self.download = download self.num_classes = 4 - + if not self.download and not self._chech_is_downloaded(): - raise ValueError("Data not found. Set --download-data=True to download the data.") + raise ValueError( + "Data not found. Set --download-data=True to download the data." + ) if self.download and not self._chech_is_downloaded(): self._download_data() - - self.images_path = self.mnist_path / ("train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte") - self.labels_path = self.mnist_path / ("train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte") - + + self.images_path = self.mnist_path / ( + "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte" + ) + self.labels_path = self.mnist_path / ( + "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte" + ) + labels = self._parse_labels(train=self.train) - - self.idx = np.where(labels < 4)[0] - + + self.idx = np.where(labels < 4)[0] + self.length = len(self.idx) - - + def _parse_labels(self, train): with open(self.labels_path, "rb") as f: data = np.frombuffer(f.read(), dtype=np.uint8, offset=8) return data - + def _chech_is_downloaded(self): if self.mnist_path.exists(): - required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"] + required_files = [ + "train-images-idx3-ubyte", + "train-labels-idx1-ubyte", + "t10k-images-idx3-ubyte", + "t10k-labels-idx1-ubyte", + ] if all([(self.mnist_path / file).exists() for file in required_files]): print("MNIST Dataset already downloaded.") return True @@ -95,26 +111,24 @@ def _chech_is_downloaded(self): else: self.mnist_path.mkdir(parents=True, exist_ok=True) return False - - + def _download_data(self): urls = { - "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", - "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", - "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", - "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", - } - + "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", + "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", + "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", + "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", + } + for name, url in urls.items(): file_path = os.path.join(self.mnist_path, url.split("/")[-1]) if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading urllib.request.urlretrieve(url, file_path) - with gzip.open(file_path, 'rb') as f_in: - with open(file_path.replace(".gz", ""), 'wb') as f_out: + with gzip.open(file_path, "rb") as f_in: + with open(file_path.replace(".gz", ""), "wb") as f_out: f_out.write(f_in.read()) os.remove(file_path) # Remove compressed file - def __len__(self): return self.length @@ -124,12 +138,14 @@ def __getitem__(self, index): label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label with open(self.images_path, "rb") as f: - f.seek(16 + index * 28*28) # Jump to image position - image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data - - image = np.expand_dims(image, axis=0) # Add channel dimension - + f.seek(16 + index * 28 * 28) # Jump to image position + image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape( + 28, 28 + ) # Read image data + + image = np.expand_dims(image, axis=0) # Add channel dimension + if self.transform: image = self.transform(image) - - return image, label \ No newline at end of file + + return image, label diff --git a/utils/load_data.py b/utils/load_data.py index e71f27e..7252e4d 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,6 +1,6 @@ from torch.utils.data import Dataset -from .dataloaders import USPSDataset0_6, MNISTDataset0_3 +from .dataloaders import MNISTDataset0_3, USPSDataset0_6 def load_data(dataset: str, *args, **kwargs) -> Dataset: diff --git a/utils/load_model.py b/utils/load_model.py index 601e3c2..8c76959 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,6 +1,6 @@ import torch.nn as nn -from .models import ChristianModel, MagnusModel, JanModel +from .models import ChristianModel, JanModel, MagnusModel def load_model(modelname: str, *args, **kwargs) -> nn.Module: diff --git a/utils/models/__init__.py b/utils/models/__init__.py index 64706b0..eb09d1d 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,5 +1,5 @@ __all__ = ["MagnusModel", "ChristianModel", "JanModel"] from .christian_model import ChristianModel -from .magnus_model import MagnusModel from .jan_model import JanModel +from .magnus_model import MagnusModel diff --git a/utils/models/jan_model.py b/utils/models/jan_model.py index 8f7ab4f..4b4c3d1 100644 --- a/utils/models/jan_model.py +++ b/utils/models/jan_model.py @@ -1,4 +1,5 @@ import torch + """ A simple neural network model for classification tasks. Parameters @@ -31,7 +32,6 @@ import torch.nn as nn - class JanModel(nn.Module): """A simple MLP network model for image classification tasks. @@ -59,22 +59,23 @@ class JanModel(nn.Module): fc2 Output Shape: (5, 100) out Output Shape: (5, num_classes) """ + def __init__(self, image_shape, num_classes): super().__init__() - + self.in_channels = image_shape[0] self.height = image_shape[1] self.width = image_shape[2] self.num_classes = num_classes - + self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100) - + self.fc2 = nn.Linear(100, 100) - + self.out = nn.Linear(100, num_classes) - + self.leaky_relu = nn.LeakyReLU() - + self.flatten = nn.Flatten() def forward(self, x): @@ -85,8 +86,8 @@ def forward(self, x): x = self.leaky_relu(x) x = self.out(x) return x - - + + if __name__ == "__main__": model = JanModel(2, 4) From 9ad01e4939e53b0f7111e71d0c556428b7dabd3d Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Tue, 4 Feb 2025 11:18:11 +0100 Subject: [PATCH 6/9] fixed jan_model.py to pass test --- utils/models/jan_model.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/utils/models/jan_model.py b/utils/models/jan_model.py index 4b4c3d1..ef4ac66 100644 --- a/utils/models/jan_model.py +++ b/utils/models/jan_model.py @@ -1,34 +1,4 @@ import torch - -""" -A simple neural network model for classification tasks. -Parameters ----------- -in_channels : int - Number of input channels. -num_classes : int - Number of output classes. -Attributes ----------- -in_channels : int - Number of input channels. -num_classes : int - Number of output classes. -fc1 : nn.Linear - First fully connected layer. -fc2 : nn.Linear - Second fully connected layer. -out : nn.Linear - Output fully connected layer. -leaky_relu : nn.LeakyReLU - Leaky ReLU activation function. -flatten : nn.Flatten - Flatten layer to reshape input tensor. -Methods -------- -forward(x) - Defines the forward pass of the model. -""" import torch.nn as nn From f4e5591447278f5fa2d454fe0ea14bdd2fdb3be7 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Tue, 4 Feb 2025 11:15:13 +0100 Subject: [PATCH 7/9] Update model to input `image_shape` rather than `in_channels` --- utils/models/christian_model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/utils/models/christian_model.py b/utils/models/christian_model.py index 1adb76e..a277b33 100644 --- a/utils/models/christian_model.py +++ b/utils/models/christian_model.py @@ -27,8 +27,8 @@ class ChristianModel(nn.Module): Args ---- - in_channels : int - Number of input channels. + image_shape : tuple(int, int, int) + Shape of the input image (C, H, W). num_classes : int Number of classes in the dataset. @@ -49,10 +49,12 @@ class ChristianModel(nn.Module): FC Output Shape: (5, num_classes) """ - def __init__(self, in_channels, num_classes): + def __init__(self, image_shape, num_classes): super().__init__() - self.cnn1 = CNNBlock(in_channels, 50) + C, *_ = image_shape + + self.cnn1 = CNNBlock(C, 50) self.cnn2 = CNNBlock(50, 100) self.fc1 = nn.Linear(100 * 4 * 4, num_classes) From d911e4afaad787263b49b63225fccc21219c4d01 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Tue, 4 Feb 2025 11:15:58 +0100 Subject: [PATCH 8/9] Update tests to accept new model input --- tests/test_models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 4747490..15a7504 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,11 +4,14 @@ from utils.models import ChristianModel -@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)]) -def test_christian_model(in_channels, num_classes): - n, c, h, w = 5, in_channels, 16, 16 +@pytest.mark.parametrize( + "image_shape, num_classes", + [((1, 16, 16), 6), ((3, 16, 16), 6)], +) +def test_christian_model(image_shape, num_classes): + n, c, h, w = 5, *image_shape - model = ChristianModel(c, num_classes) + model = ChristianModel(image_shape, num_classes) x = torch.randn(n, c, h, w) y = model(x) From 5ae4d22988f1b16a7f218f411a4a235f63ef4b81 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Tue, 4 Feb 2025 13:34:01 +0100 Subject: [PATCH 9/9] Format after conflicts resolved --- main.py | 37 +++++++++++++------------- tests/test_metrics.py | 2 +- utils/dataloaders/__init__.py | 2 +- utils/load_data.py | 5 ++-- utils/metrics/F1.py | 1 - utils/models/solveig_model.py | 49 +++++++++++++++++------------------ 6 files changed, 48 insertions(+), 48 deletions(-) diff --git a/main.py b/main.py index badb174..b77ea43 100644 --- a/main.py +++ b/main.py @@ -152,15 +152,16 @@ def main(): ) model.to(device) - trainloader = DataLoader(traindata, - batch_size=args.batchsize, - shuffle=True, - pin_memory=True, - drop_last=True) - valiloader = DataLoader(validata, - batch_size=args.batchsize, - shuffle=False, - pin_memory=True) + trainloader = DataLoader( + traindata, + batch_size=args.batchsize, + shuffle=True, + pin_memory=True, + drop_last=True, + ) + valiloader = DataLoader( + validata, batch_size=args.batchsize, shuffle=False, pin_memory=True + ) criterion = nn.CrossEntropyLoss() optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate) @@ -170,12 +171,10 @@ def main(): print("Dry run completed") exit(0) - wandb.init(project='', - tags=[]) + wandb.init(project="", tags=[]) wandb.watch(model) for epoch in range(args.epoch): - # Training loop start trainingloss = [] model.train() @@ -200,12 +199,14 @@ def main(): loss = criterion(y, pred) evalloss.append(loss.item()) - wandb.log({ - 'Epoch': epoch, - 'Train loss': np.mean(trainingloss), - 'Evaluation Loss': np.mean(evalloss) - }) + wandb.log( + { + "Epoch": epoch, + "Train loss": np.mean(trainingloss), + "Evaluation Loss": np.mean(evalloss), + } + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1650e01..b7e1baa 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,4 +1,4 @@ -from utils.metrics import Recall, F1Score +from utils.metrics import F1Score, Recall def test_recall(): diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py index e986d44..1f506e6 100644 --- a/utils/dataloaders/__init__.py +++ b/utils/dataloaders/__init__.py @@ -2,4 +2,4 @@ from .mnist_0_3 import MNISTDataset0_3 from .usps_0_6 import USPSDataset0_6 -from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset \ No newline at end of file +from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset diff --git a/utils/load_data.py b/utils/load_data.py index cb7b3f5..4d27b65 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,6 +1,7 @@ from torch.utils.data import Dataset -from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset +from .dataloaders import (MNISTDataset0_3, USPSDataset0_6, + USPSH5_Digit_7_9_Dataset) def load_data(dataset: str, *args, **kwargs) -> Dataset: @@ -10,6 +11,6 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset: case "mnist_0-3": return MNISTDataset0_3(*args, **kwargs) case "usps_7-9": - return USPSH5_Digit_7_9_Dataset(*args, **kwargs) + return USPSH5_Digit_7_9_Dataset(*args, **kwargs) case _: raise ValueError(f"Dataset: {dataset} not implemented.") diff --git a/utils/metrics/F1.py b/utils/metrics/F1.py index 36e5e34..1e0e795 100644 --- a/utils/metrics/F1.py +++ b/utils/metrics/F1.py @@ -84,4 +84,3 @@ def compute(self): ) return f1_score - diff --git a/utils/models/solveig_model.py b/utils/models/solveig_model.py index c16dbaf..d04094b 100644 --- a/utils/models/solveig_model.py +++ b/utils/models/solveig_model.py @@ -4,26 +4,26 @@ class SolveigModel(nn.Module): """ - A Convolutional Neural Network model for classification. - - Args - ---- - image_shape : tuple(int, int, int) - Shape of the input image (C, H, W). - num_classes : int - Number of classes in the dataset. - - Attributes: - ----------- - conv_block1 : nn.Sequential - First convolutional block containing a convolutional layer, ReLU activation, and max-pooling. - conv_block2 : nn.Sequential - Second convolutional block containing a convolutional layer and ReLU activation. - conv_block3 : nn.Sequential - Third convolutional block containing a convolutional layer and ReLU activation. - fc1 : nn.Linear - Fully connected layer that outputs the final classification scores. - """ + A Convolutional Neural Network model for classification. + + Args + ---- + image_shape : tuple(int, int, int) + Shape of the input image (C, H, W). + num_classes : int + Number of classes in the dataset. + + Attributes: + ----------- + conv_block1 : nn.Sequential + First convolutional block containing a convolutional layer, ReLU activation, and max-pooling. + conv_block2 : nn.Sequential + Second convolutional block containing a convolutional layer and ReLU activation. + conv_block3 : nn.Sequential + Third convolutional block containing a convolutional layer and ReLU activation. + fc1 : nn.Linear + Fully connected layer that outputs the final classification scores. + """ def __init__(self, image_shape, num_classes): super().__init__() @@ -34,19 +34,19 @@ def __init__(self, image_shape, num_classes): self.conv_block1 = nn.Sequential( nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1), nn.ReLU(), - nn.MaxPool2d(kernel_size=2, stride=2) + nn.MaxPool2d(kernel_size=2, stride=2), ) # Define the second convolutional block (conv + relu) self.conv_block2 = nn.Sequential( nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1), - nn.ReLU() + nn.ReLU(), ) # Define the third convolutional block (conv + relu) self.conv_block3 = nn.Sequential( nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1), - nn.ReLU() + nn.ReLU(), ) self.fc1 = nn.Linear(100 * 8 * 8, num_classes) @@ -64,8 +64,7 @@ def forward(self, x): if __name__ == "__main__": - - x = torch.randn(1,3, 16, 16) + x = torch.randn(1, 3, 16, 16) model = SolveigModel(x.shape[1:], 3)