diff --git a/main.py b/main.py index ff7664f..b77ea43 100644 --- a/main.py +++ b/main.py @@ -73,7 +73,7 @@ def main(): "--dataset", type=str, default="svhn", - choices=["svhn", "usps_0-6", "uspsh5_7_9"], + choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"], help="Which dataset to train the model on.", ) @@ -139,29 +139,29 @@ def main(): data_path=args.datafolder, ) - # Find number of channels in the dataset - if len(traindata[0][0].shape) == 2: - channels = 1 - else: - channels = traindata[0][0].shape[0] + # Find the shape of the data, if is 2D, add a channel dimension + data_shape = traindata[0][0].shape + if len(data_shape) == 2: + data_shape = (1, *data_shape) # load model model = load_model( args.modelname, - in_channels=channels, + image_shape=data_shape, num_classes=traindata.num_classes, ) model.to(device) - trainloader = DataLoader(traindata, - batch_size=args.batchsize, - shuffle=True, - pin_memory=True, - drop_last=True) - valiloader = DataLoader(validata, - batch_size=args.batchsize, - shuffle=False, - pin_memory=True) + trainloader = DataLoader( + traindata, + batch_size=args.batchsize, + shuffle=True, + pin_memory=True, + drop_last=True, + ) + valiloader = DataLoader( + validata, batch_size=args.batchsize, shuffle=False, pin_memory=True + ) criterion = nn.CrossEntropyLoss() optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate) @@ -171,12 +171,10 @@ def main(): print("Dry run completed") exit(0) - wandb.init(project='', - tags=[]) + wandb.init(project="", tags=[]) wandb.watch(model) for epoch in range(args.epoch): - # Training loop start trainingloss = [] model.train() @@ -201,12 +199,14 @@ def main(): loss = criterion(y, pred) evalloss.append(loss.item()) - wandb.log({ - 'Epoch': epoch, - 'Train loss': np.mean(trainingloss), - 'Evaluation Loss': np.mean(evalloss) - }) + wandb.log( + { + "Epoch": epoch, + "Train loss": np.mean(trainingloss), + "Evaluation Loss": np.mean(evalloss), + } + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1650e01..b7e1baa 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,4 +1,4 @@ -from utils.metrics import Recall, F1Score +from utils.metrics import F1Score, Recall def test_recall(): diff --git a/tests/test_models.py b/tests/test_models.py index 4747490..15a7504 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,11 +4,14 @@ from utils.models import ChristianModel -@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)]) -def test_christian_model(in_channels, num_classes): - n, c, h, w = 5, in_channels, 16, 16 +@pytest.mark.parametrize( + "image_shape, num_classes", + [((1, 16, 16), 6), ((3, 16, 16), 6)], +) +def test_christian_model(image_shape, num_classes): + n, c, h, w = 5, *image_shape - model = ChristianModel(c, num_classes) + model = ChristianModel(image_shape, num_classes) x = torch.randn(n, c, h, w) y = model(x) diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py index bb97adc..1f506e6 100644 --- a/utils/dataloaders/__init__.py +++ b/utils/dataloaders/__init__.py @@ -1,4 +1,5 @@ -__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"] +__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"] +from .mnist_0_3 import MNISTDataset0_3 from .usps_0_6 import USPSDataset0_6 -from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset \ No newline at end of file +from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py new file mode 100644 index 0000000..5e5a935 --- /dev/null +++ b/utils/dataloaders/mnist_0_3.py @@ -0,0 +1,151 @@ +import gzip +import os +import urllib.request +from pathlib import Path + +import numpy as np +from torch.utils.data import Dataset + + +class MNISTDataset0_3(Dataset): + """ + A custom dataset class for loading MNIST data, specifically for digits 0 through 3. + Parameters + ---------- + data_path : Path + The root directory where the MNIST data is stored or will be downloaded. + train : bool, optional + If True, loads the training data, otherwise loads the test data. Default is False. + transform : callable, optional + A function/transform that takes in an image and returns a transformed version. Default is None. + download : bool, optional + If True, downloads the dataset if it is not already present in the specified data_path. Default is False. + Attributes + ---------- + data_path : Path + The root directory where the MNIST data is stored. + mnist_path : Path + The directory where the MNIST data files are stored. + train : bool + Indicates whether the training data or test data is being used. + transform : callable + A function/transform that takes in an image and returns a transformed version. + download : bool + Indicates whether the dataset should be downloaded if not present. + images_path : Path + The path to the image file (training or test) based on the `train` flag. + labels_path : Path + The path to the label file (training or test) based on the `train` flag. + idx : numpy.ndarray + Indices of the labels that are less than 4. + length : int + The number of samples in the dataset. + Methods + ------- + _parse_labels(train) + Parses the labels from the label file. + _chech_is_downloaded() + Checks if the dataset is already downloaded. + _download_data() + Downloads and extracts the MNIST dataset. + __len__() + Returns the number of samples in the dataset. + __getitem__(index) + Returns the image and label at the specified index. + """ + + def __init__( + self, + data_path: Path, + train: bool = False, + transform=None, + download: bool = False, + ): + super().__init__() + + self.data_path = data_path + self.mnist_path = self.data_path / "MNIST" + self.train = train + self.transform = transform + self.download = download + self.num_classes = 4 + + if not self.download and not self._chech_is_downloaded(): + raise ValueError( + "Data not found. Set --download-data=True to download the data." + ) + if self.download and not self._chech_is_downloaded(): + self._download_data() + + self.images_path = self.mnist_path / ( + "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte" + ) + self.labels_path = self.mnist_path / ( + "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte" + ) + + labels = self._parse_labels(train=self.train) + + self.idx = np.where(labels < 4)[0] + + self.length = len(self.idx) + + def _parse_labels(self, train): + with open(self.labels_path, "rb") as f: + data = np.frombuffer(f.read(), dtype=np.uint8, offset=8) + return data + + def _chech_is_downloaded(self): + if self.mnist_path.exists(): + required_files = [ + "train-images-idx3-ubyte", + "train-labels-idx1-ubyte", + "t10k-images-idx3-ubyte", + "t10k-labels-idx1-ubyte", + ] + if all([(self.mnist_path / file).exists() for file in required_files]): + print("MNIST Dataset already downloaded.") + return True + else: + return False + else: + self.mnist_path.mkdir(parents=True, exist_ok=True) + return False + + def _download_data(self): + urls = { + "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", + "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", + "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", + "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", + } + + for name, url in urls.items(): + file_path = os.path.join(self.mnist_path, url.split("/")[-1]) + if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading + urllib.request.urlretrieve(url, file_path) + with gzip.open(file_path, "rb") as f_in: + with open(file_path.replace(".gz", ""), "wb") as f_out: + f_out.write(f_in.read()) + os.remove(file_path) # Remove compressed file + + def __len__(self): + return self.length + + def __getitem__(self, index): + with open(self.labels_path, "rb") as f: + f.seek(8 + index) # Jump to the label position + label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label + + with open(self.images_path, "rb") as f: + f.seek(16 + index * 28 * 28) # Jump to image position + image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape( + 28, 28 + ) # Read image data + + image = np.expand_dims(image, axis=0) # Add channel dimension + + if self.transform: + image = self.transform(image) + + return image, label diff --git a/utils/load_data.py b/utils/load_data.py index f54e94a..4d27b65 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,13 +1,16 @@ from torch.utils.data import Dataset -from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset +from .dataloaders import (MNISTDataset0_3, USPSDataset0_6, + USPSH5_Digit_7_9_Dataset) def load_data(dataset: str, *args, **kwargs) -> Dataset: match dataset.lower(): case "usps_0-6": return USPSDataset0_6(*args, **kwargs) + case "mnist_0-3": + return MNISTDataset0_3(*args, **kwargs) case "usps_7-9": - return USPSH5_Digit_7_9_Dataset(*args, **kwargs) + return USPSH5_Digit_7_9_Dataset(*args, **kwargs) case _: raise ValueError(f"Dataset: {dataset} not implemented.") diff --git a/utils/load_model.py b/utils/load_model.py index 21d9d03..b8f96e2 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,6 +1,6 @@ import torch.nn as nn -from .models import ChristianModel, MagnusModel, SolveigModel +from .models import ChristianModel, JanModel, MagnusModel, SolveigModel def load_model(modelname: str, *args, **kwargs) -> nn.Module: @@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module: return MagnusModel(*args, **kwargs) case "christianmodel": return ChristianModel(*args, **kwargs) + case "janmodel": + return JanModel(*args, **kwargs) case "solveigmodel": return SolveigModel(*args, **kwargs) case _: diff --git a/utils/metrics/F1.py b/utils/metrics/F1.py index 36e5e34..1e0e795 100644 --- a/utils/metrics/F1.py +++ b/utils/metrics/F1.py @@ -84,4 +84,3 @@ def compute(self): ) return f1_score - diff --git a/utils/models/__init__.py b/utils/models/__init__.py index 8573991..9ff7ced 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,5 +1,6 @@ -__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"] +__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel"] from .christian_model import ChristianModel +from .jan_model import JanModel from .magnus_model import MagnusModel from .solveig_model import SolveigModel diff --git a/utils/models/christian_model.py b/utils/models/christian_model.py index 1adb76e..a277b33 100644 --- a/utils/models/christian_model.py +++ b/utils/models/christian_model.py @@ -27,8 +27,8 @@ class ChristianModel(nn.Module): Args ---- - in_channels : int - Number of input channels. + image_shape : tuple(int, int, int) + Shape of the input image (C, H, W). num_classes : int Number of classes in the dataset. @@ -49,10 +49,12 @@ class ChristianModel(nn.Module): FC Output Shape: (5, num_classes) """ - def __init__(self, in_channels, num_classes): + def __init__(self, image_shape, num_classes): super().__init__() - self.cnn1 = CNNBlock(in_channels, 50) + C, *_ = image_shape + + self.cnn1 = CNNBlock(C, 50) self.cnn2 = CNNBlock(50, 100) self.fc1 = nn.Linear(100 * 4 * 4, num_classes) diff --git a/utils/models/jan_model.py b/utils/models/jan_model.py new file mode 100644 index 0000000..ef4ac66 --- /dev/null +++ b/utils/models/jan_model.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn + + +class JanModel(nn.Module): + """A simple MLP network model for image classification tasks. + + Args + ---- + in_channels : int + Number of input channels. + num_classes : int + Number of classes in the dataset. + + Processing Images + ----------------- + Input: (N, C, H, W) + N: Batch size + C: Number of input channels + H: Height of the input image + W: Width of the input image + + Example: + For grayscale images, C = 1. + + Input Image Shape: (5, 1, 28, 28) + flatten Output Shape: (5, 784) + fc1 Output Shape: (5, 100) + fc2 Output Shape: (5, 100) + out Output Shape: (5, num_classes) + """ + + def __init__(self, image_shape, num_classes): + super().__init__() + + self.in_channels = image_shape[0] + self.height = image_shape[1] + self.width = image_shape[2] + self.num_classes = num_classes + + self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100) + + self.fc2 = nn.Linear(100, 100) + + self.out = nn.Linear(100, num_classes) + + self.leaky_relu = nn.LeakyReLU() + + self.flatten = nn.Flatten() + + def forward(self, x): + x = self.flatten(x) + x = self.fc1(x) + x = self.leaky_relu(x) + x = self.fc2(x) + x = self.leaky_relu(x) + x = self.out(x) + return x + + +if __name__ == "__main__": + model = JanModel(2, 4) + + x = torch.randn(3, 2, 28, 28) + y = model(x) + + print(y) diff --git a/utils/models/solveig_model.py b/utils/models/solveig_model.py index c16dbaf..d04094b 100644 --- a/utils/models/solveig_model.py +++ b/utils/models/solveig_model.py @@ -4,26 +4,26 @@ class SolveigModel(nn.Module): """ - A Convolutional Neural Network model for classification. - - Args - ---- - image_shape : tuple(int, int, int) - Shape of the input image (C, H, W). - num_classes : int - Number of classes in the dataset. - - Attributes: - ----------- - conv_block1 : nn.Sequential - First convolutional block containing a convolutional layer, ReLU activation, and max-pooling. - conv_block2 : nn.Sequential - Second convolutional block containing a convolutional layer and ReLU activation. - conv_block3 : nn.Sequential - Third convolutional block containing a convolutional layer and ReLU activation. - fc1 : nn.Linear - Fully connected layer that outputs the final classification scores. - """ + A Convolutional Neural Network model for classification. + + Args + ---- + image_shape : tuple(int, int, int) + Shape of the input image (C, H, W). + num_classes : int + Number of classes in the dataset. + + Attributes: + ----------- + conv_block1 : nn.Sequential + First convolutional block containing a convolutional layer, ReLU activation, and max-pooling. + conv_block2 : nn.Sequential + Second convolutional block containing a convolutional layer and ReLU activation. + conv_block3 : nn.Sequential + Third convolutional block containing a convolutional layer and ReLU activation. + fc1 : nn.Linear + Fully connected layer that outputs the final classification scores. + """ def __init__(self, image_shape, num_classes): super().__init__() @@ -34,19 +34,19 @@ def __init__(self, image_shape, num_classes): self.conv_block1 = nn.Sequential( nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1), nn.ReLU(), - nn.MaxPool2d(kernel_size=2, stride=2) + nn.MaxPool2d(kernel_size=2, stride=2), ) # Define the second convolutional block (conv + relu) self.conv_block2 = nn.Sequential( nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1), - nn.ReLU() + nn.ReLU(), ) # Define the third convolutional block (conv + relu) self.conv_block3 = nn.Sequential( nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1), - nn.ReLU() + nn.ReLU(), ) self.fc1 = nn.Linear(100 * 8 * 8, num_classes) @@ -64,8 +64,7 @@ def forward(self, x): if __name__ == "__main__": - - x = torch.randn(1,3, 16, 16) + x = torch.randn(1, 3, 16, 16) model = SolveigModel(x.shape[1:], 3)