From d88ef49125008d3831198b5b50f142b2f6510b44 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 14:55:51 +0100 Subject: [PATCH 01/30] Dont track (personal) binary directory --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8593055..281bb0a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ Data/ Results/ Experiments/ _build/ +bin/ From 1b20dd4accfbfff7c11818bcc9ba702fd92b0be0 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 16:06:25 +0100 Subject: [PATCH 02/30] Add a dataset class loading USPS images with labels 0 to 6 This commit comes bundled with a class for loading USPS images and labels between 0 to 6. It is complete with a test, checking that it works as intended --- src/datahandlers/__init__.py | 3 + src/datahandlers/usps_0_6.py | 117 +++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 src/datahandlers/__init__.py create mode 100644 src/datahandlers/usps_0_6.py diff --git a/src/datahandlers/__init__.py b/src/datahandlers/__init__.py new file mode 100644 index 0000000..df404f7 --- /dev/null +++ b/src/datahandlers/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["USPSDataset0_6"] + +from .usps_0_6 import USPSDataset0_6 diff --git a/src/datahandlers/usps_0_6.py b/src/datahandlers/usps_0_6.py new file mode 100644 index 0000000..29fe0da --- /dev/null +++ b/src/datahandlers/usps_0_6.py @@ -0,0 +1,117 @@ +""" +Dataset class for USPS dataset with labels 0-6. + +This module contains the Dataset class for the USPS dataset with labels 0-6. +""" + +from pathlib import Path + +import h5py as h5 +import numpy as np +from torch.utils.data import Dataset + + +class USPSDataset0_6(Dataset): + """ + Dataset class for USPS dataset with labels 0-6. + + Args + ---- + path : pathlib.Path + Path to the USPS dataset file. + mode : str + Mode of the dataset. Must be either 'train' or 'test'. 
+ transform : callable, optional + A function/transform that takes in a sample and returns a transformed version. + + Attributes + ---------- + path : pathlib.Path + Path to the USPS dataset file. + mode : str + Mode of the dataset. + transform : callable + A function/transform that takes in a sample and returns a transformed version. + idx : numpy.ndarray + Indices of samples with labels 0-6. + + Methods + ------- + _index() + Get indices of samples with labels 0-6. + _load_data(idx) + Load data and target label from the dataset. + __len__() + Get the number of samples in the dataset. + __getitem__(idx) + Get a sample from the dataset. + + Examples + -------- + >>> from src.datahandlers import USPSDataset0_6 + >>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train") + >>> len(dataset) + 5460 + >>> data, target = dataset[0] + >>> data.shape + (16, 16) + >>> target + 6 + """ + + def __init__(self, path: Path, mode: str = "train", transform=None): + super().__init__() + self.path = path + self.mode = mode + self.transform = transform + + if self.mode not in ["train", "test"]: + raise ValueError("Invalid mode. 
Must be either 'train' or 'test'") + + self.idx = self._index() + + def _index(self): + with h5.File(self.path, "r") as f: + labels = f[self.mode]["target"][:] + + # Get indices of samples with labels 0-6 + mask = labels <= 6 + idx = np.where(mask)[0] + + return idx + + def _load_data(self, idx): + with h5.File(self.path, "r") as f: + data = f[self.mode]["data"][idx] + label = f[self.mode]["target"][idx] + + return data, label + + def __len__(self): + return len(self.idx) + + def __getitem__(self, idx): + data, target = self._load_data(self.idx[idx]) + + data = data.reshape(16, 16) + + if self.transform: + data = self.transform(data) + + return data, target + + +def test_uspsdataset0_6(): + import pytest + + datapath = Path("data/USPS/usps.h5") + + dataset = USPSDataset0_6(path=datapath, mode="train") + assert len(dataset) == 5460 + data, target = dataset[0] + assert data.shape == (16, 16) + assert target == 6 + + # Test for an invalid mode + with pytest.raises(ValueError): + USPSDataset0_6(path=datapath, mode="inference") From 1e69248531d1a0a9f6229b91a3cae5dca469ed69 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 16:06:38 +0100 Subject: [PATCH 03/30] dont track pycache and notebook checkpoints --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 281bb0a..df0fcc4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +__pycache__/ +.ipynb_checkpoints/ Data/ Results/ Experiments/ From 726c9e03ef698b5f523ea01789e6de6ae685aab9 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 16:07:11 +0100 Subject: [PATCH 04/30] add jupyter, formatters/lsp/lint pytest and h5py --- environment.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/environment.yml b/environment.yml index 9e607e2..88d3cd6 100644 --- a/environment.yml +++ b/environment.yml @@ -9,4 +9,13 @@ dependencies: - sphinx-autobuild - sphinx-rtd-theme - pip + - h5py + - black + - isort + - jupyterlab + - numpy + - pandas + - 
pytest + - ruff + - scalene prefix: /opt/miniconda3/envs/cc-exam From d13126684735a922fce53c14b27871abba968b45 Mon Sep 17 00:00:00 2001 From: Solveig Date: Thu, 30 Jan 2025 18:21:50 +0100 Subject: [PATCH 05/30] Add USPS HDF5 dataloader and F1 metric implementation --- utils/dataloaders/__init__.py | 0 utils/dataloaders/uspsh5_7_9.py | 116 ++++++++++++++++++++++++++++++++ utils/metrics/F1.py | 96 ++++++++++++++++++++++++++ 3 files changed, 212 insertions(+) create mode 100644 utils/dataloaders/__init__.py create mode 100644 utils/dataloaders/uspsh5_7_9.py create mode 100644 utils/metrics/F1.py diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/dataloaders/uspsh5_7_9.py b/utils/dataloaders/uspsh5_7_9.py new file mode 100644 index 0000000..a343554 --- /dev/null +++ b/utils/dataloaders/uspsh5_7_9.py @@ -0,0 +1,116 @@ +from torch.utils.data import Dataset +import numpy as np +import h5py +from torchvision import transforms +from PIL import Image +import torch + + +class USPSH5_Digit_7_9_Dataset(Dataset): + """ + Custom USPS dataset class that loads images with digits 7-9 from an .h5 file. + + Parameters + ---------- + h5_path : str + Path to the USPS `.h5` file. + + transform : callable, optional, default=None + A transform function to apply on images. If None, no transformation is applied. + + Attributes + ---------- + images : numpy.ndarray + The filtered images corresponding to digits 7-9. + + labels : numpy.ndarray + The filtered labels corresponding to digits 7-9. + + transform : callable, optional + A transform function to apply to the images. + """ + + def __init__(self, h5_path, mode, transform=None): + super().__init__() + """ + Initializes the USPS dataset by loading images and labels from the given `.h5` file. + + Parameters + ---------- + h5_path : str + Path to the USPS `.h5` file. + + transform : callable, optional, default=None + A transform function to apply on images. 
+ """ + + self.transform = transform + self.mode = mode + self.h5_path = h5_path + # Load the dataset from the HDF5 file + with h5py.File(self.h5_path, "r") as hf: + images = hf[self.mode]["data"][:] + labels = hf[self.mode]["target"][:] + + # Filter only digits 7, 8, and 9 + mask = np.isin(labels, [7, 8, 9]) + self.images = images[mask] + self.labels = labels[mask] + + def __len__(self): + """ + Returns the total number of samples in the dataset. + + Returns + ------- + int + The number of images in the dataset. + """ + return len(self.images) + + def __getitem__(self, id): + """ + Returns a sample from the dataset given an index. + + Parameters + ---------- + idx : int + The index of the sample to retrieve. + + Returns + ------- + tuple + - image (PIL Image): The image at the specified index. + - label (int): The label corresponding to the image. + """ + # Convert to PIL Image (USPS images are typically grayscale 16x16) + image = Image.fromarray(self.images[id].astype(np.uint8), mode="L") + label = int(self.labels[id]) # Convert label to integer + + if self.transform: + image = self.transform(image) + + return image, label + + +def main(): + # Example Usage: + transform = transforms.Compose([ + transforms.Resize((16, 16)), # Ensure images are 16x16 + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1] + ]) + + # Load the dataset + dataset = USPSH5_Digit_7_9_Dataset(h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5", mode="train", transform=transform) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True) + batch = next(iter(data_loader)) # grab a batch from the dataloader + img, label = batch + print(img.shape) + print(label.shape) + + # Check dataset size + print(f"Dataset size: {len(dataset)}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/utils/metrics/F1.py b/utils/metrics/F1.py new file mode 100644 index 0000000..16c87f8 --- /dev/null +++ 
b/utils/metrics/F1.py @@ -0,0 +1,96 @@ +import torch.nn as nn +import torch + + +class F1Score(nn.Module): + """ + F1 Score implementation with direct averaging inside the compute method. + + Parameters + ---------- + num_classes : int + Number of classes. + + Attributes + ---------- + num_classes : int + The number of classes. + + tp : torch.Tensor + Tensor for True Positives (TP) for each class. + + fp : torch.Tensor + Tensor for False Positives (FP) for each class. + + fn : torch.Tensor + Tensor for False Negatives (FN) for each class. + """ + def __init__(self, num_classes): + """ + Initializes the F1Score object, setting up the necessary state variables. + + Parameters + ---------- + num_classes : int + The number of classes in the classification task. + + """ + + super().__init__() + + self.num_classes = num_classes + + # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN) + self.tp = torch.zeros(num_classes) + self.fp = torch.zeros(num_classes) + self.fn = torch.zeros(num_classes) + + def update(self, preds, target): + """ + Update the variables with predictions and true labels. + + Parameters + ---------- + preds : torch.Tensor + Predicted logits (shape: [batch_size, num_classes]). + + target : torch.Tensor + True labels (shape: [batch_size]). + """ + preds = torch.argmax(preds, dim=1) + + # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class + for i in range(self.num_classes): + self.tp[i] += torch.sum((preds == i) & (target == i)).float() + self.fp[i] += torch.sum((preds == i) & (target != i)).float() + self.fn[i] += torch.sum((preds != i) & (target == i)).float() + + def compute(self): + """ + Compute the F1 score. + + Returns + ------- + torch.Tensor + The computed F1 score. 
+ """ + + # Compute F1 score based on the specified averaging method + f1_score = 2 * torch.sum(self.tp) / (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn)) + + return f1_score + + +def test_f1score(): + f1_metric = F1Score(num_classes=3) + preds = torch.tensor([[0.8, 0.1, 0.1], + [0.2, 0.7, 0.1], + [0.2, 0.3, 0.5], + [0.1, 0.2, 0.7]]) + + target = torch.tensor([0, 1, 0, 2]) + + f1_metric.update(preds, target) + assert f1_metric.tp.sum().item() > 0, "Expected some true positives." + assert f1_metric.fp.sum().item() > 0, "Expected some false positives." + assert f1_metric.fn.sum().item() > 0, "Expected some false negatives." From 61b0e53adcec6dd74c05c19edc849c157e588f15 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 19:25:40 +0100 Subject: [PATCH 06/30] Remove notebook --- test.ipynb | 76 ------------------------------------------------------ 1 file changed, 76 deletions(-) delete mode 100644 test.ipynb diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 0b50544..0000000 --- a/test.ipynb +++ /dev/null @@ -1,76 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "import argparse" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "usage: [-h] [--datafolder DATAFOLDER]\n", - ": error: unrecognized arguments: --f=/home/magnus/.local/share/jupyter/runtime/kernel-v3fc3d3b04bd8d83becf1be5eacf19e7bf46887012.json\n" - ] - }, - { - "ename": "SystemExit", - "evalue": "2", - "output_type": "error", - "traceback": [ - "An exception has occurred, use %tb to see the full traceback.\n", - "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n" - ] - } - ], - "source": [ - "parser = argparse.ArgumentParser(\n", - " prog='',\n", - " description='',\n", - " epilog='',\n", - " )\n", - "parser.add_argument('--datafolder', type=str, default='Data/', 
help='Path to where data will be saved during training.')\n", - "args = parser.parse_args()\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(args)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "env", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From cba9b80d25ebf8204f38ee6914f4c196d6ff2d41 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 19:41:32 +0100 Subject: [PATCH 07/30] Simplify how metrics are parsed Use a switch statement to handle the different cases. --- main.py | 18 ++++------------ utils/load_metric.py | 51 +++++++++++++++++++++++++++----------------- 2 files changed, 35 insertions(+), 34 deletions(-) diff --git a/main.py b/main.py index 68fbb93..74e6bce 100644 --- a/main.py +++ b/main.py @@ -39,12 +39,8 @@ def main(): parser.add_argument('--dataset', type=str, default='svhn', choices=['svhn'], help='Which dataset to train the model on.') - parser.add_argument('--EntropyPrediction', type=bool, default=True, help='Include the Entropy Prediction metric in evaluation') - parser.add_argument('--F1Score', type=bool, default=True, help='Include the F1Score metric in evaluation') - parser.add_argument('--Recall', type=bool, default=True, help='Include the Recall metric in evaluation') - parser.add_argument('--Precision', type=bool, default=True, help='Include the Precision metric in evaluation') - parser.add_argument('--Accuracy', type=bool, default=True, help='Include the Accuracy metric in evaluation') - + parser.add_argument("--metric", type=str, default="entropy", choices=['entropy', 'f1', 'recall', 'precision', 'accuracy'], nargs="+", help='Which 
metric to use for evaluation') + #Training specific values parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.') parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.') @@ -61,13 +57,7 @@ def main(): model = load_model() model.to(device) - metrics = MetricWrapper( - EntropyPred = args.EntropyPrediction, - F1Score = args.F1Score, - Recall = args.Recall, - Precision = args.Precision, - Accuracy = args.Accuracy - ) + metrics = MetricWrapper(*args.metric) #Dataset traindata = load_data(args.dataset) @@ -126,4 +116,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/utils/load_metric.py b/utils/load_metric.py index cdce0b7..e142fc8 100644 --- a/utils/load_metric.py +++ b/utils/load_metric.py @@ -5,34 +5,45 @@ class MetricWrapper(nn.Module): - def __init__(self, - EntropyPred:bool = True, - F1Score:bool = True, - Recall:bool = True, - Precision:bool = True, - Accuracy:bool = True): + def __init__(self, *metrics): super().__init__() self.metrics = {} - - if EntropyPred: - self.metrics['Entropy of Predictions'] = EntropyPrediction() - if F1Score: - self.metrics['F1 Score'] = None - - if Recall: - self.metrics['Recall'] = None + for metric in metrics: + self.metrics[metric] = self._get_metric(metric) - if Precision: - self.metrics['Precision'] = None - - if Accuracy: - self.metrics['Accuracy'] = None - self.tmp_scores = copy.deepcopy(self.metrics) for key in self.tmp_scores: self.tmp_scores[key] = [] + + def _get_metric(self, key): + """ + Get the metric function based on the key + + Args + ---- + key (str): metric name + + Returns + ------- + metric (callable): metric function + """ + + match key.lower(): + case 'entropy': + return EntropyPrediction() + case 'f1': + raise NotImplementedError("F1 score not implemented yet") + case 'recall': + raise NotImplementedError("Recall score not implemented yet") + case 
'precision': + raise NotImplementedError("Precision score not implemented yet") + case 'accuracy': + raise NotImplementedError("Accuracy score not implemented yet") + case _: + raise ValueError(f"Metric {key} not supported") + def __call__(self, y_true, y_pred): for key in self.metrics: self.tmp_scores[key].append(self.metrics[key](y_true, y_pred)) From 380116b66f0a48be09e5aba77591bc4fca343a11 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 21:31:07 +0100 Subject: [PATCH 08/30] Add automatic formatting with ruff and import sorting (isort) --- .github/workflows/format.yml | 47 ++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 .github/workflows/format.yml diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..748ebe1 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,47 @@ +name: Format + +on: + push: + branches: + - main + paths: + - 'utils/**' + pull_request: + branches: + - main + paths: + - 'utils/**' + +jobs: + format: + name: Run Ruff and isort + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + pip install ruff isort + + - name: Run Ruff formatter + run: | + ruff format utils/ + + - name: Run isort + run: | + isort utils/ + + - name: Commit and push changes + run: | + git config --global user.name "github-actions[bot]" + git config --global user.email "github-actions[bot]@users.noreply.github.com" + git add utils/ + git commit -m "Auto-format: Applied ruff format and isort" || exit 0 + git push From 53d23e31500b1eb27985daa02f5f15c5105ad409 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 30 Jan 2025 20:34:19 +0000 Subject: [PATCH 09/30] Auto-format: Applied ruff format and isort --- utils/__init__.py | 4 +-- utils/createfolders.py | 61 
++++++++++++++++++++++-------------- utils/dataloaders/svhn.py | 9 +++--- utils/load_data.py | 8 +++-- utils/load_metric.py | 24 +++++++------- utils/load_model.py | 12 ++++--- utils/metrics/EntropyPred.py | 7 ++--- utils/metrics/__init__.py | 2 +- utils/models/__init__.py | 2 +- utils/models/magnus_model.py | 7 +++-- 10 files changed, 78 insertions(+), 58 deletions(-) diff --git a/utils/__init__.py b/utils/__init__.py index f1c5c8b..ee1a0ca 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,4 @@ +from .createfolders import createfolders +from .load_data import load_data from .load_metric import MetricWrapper from .load_model import load_model -from .load_data import load_data -from .createfolders import createfolders \ No newline at end of file diff --git a/utils/createfolders.py b/utils/createfolders.py index f8de995..fbdaabb 100644 --- a/utils/createfolders.py +++ b/utils/createfolders.py @@ -1,41 +1,56 @@ -import os -from tempfile import TemporaryDirectory import argparse +import os +from tempfile import TemporaryDirectory + + +def createfolders(args) -> None: + """ + Creates folders for storing data, results, model weights. -def createfolders(args) -> None: - ''' - Creates folders for storing data, results, model weights. 
- Parameters ---------- args ArgParse object containing string paths to be created - - ''' - + + """ + if not os.path.exists(args.datafolder): os.makedirs(args.datafolder) - print(f'Created a folder at {args.datafolder}') - + print(f"Created a folder at {args.datafolder}") + if not os.path.exists(args.resultfolder): os.makedirs(args.resultfolder) - print(f'Created a folder at {args.resultfolder}') - + print(f"Created a folder at {args.resultfolder}") + if not os.path.exists(args.modelfolder): os.makedirs(args.modelfolder) - print(f'Created a folder at {args.modelfolder}') - + print(f"Created a folder at {args.modelfolder}") def test_createfolders(): - with TemporaryDirectory(dir = 'tmp/') as temp_dir: + with TemporaryDirectory(dir="tmp/") as temp_dir: parser = argparse.ArgumentParser() - #Structuture related values - parser.add_argument('--datafolder', type=str, default=os.path.join(temp_dir, 'Data/'), help='Path to where data will be saved during training.') - parser.add_argument('--resultfolder', type=str, default=os.path.join(temp_dir, 'Results/'), help='Path to where results will be saved during evaluation.') - parser.add_argument('--modelfolder', type=str, default=os.path.join(temp_dir, 'Experiments/'), help='Path to where model weights will be saved at the end of training.') - + # Structuture related values + parser.add_argument( + "--datafolder", + type=str, + default=os.path.join(temp_dir, "Data/"), + help="Path to where data will be saved during training.", + ) + parser.add_argument( + "--resultfolder", + type=str, + default=os.path.join(temp_dir, "Results/"), + help="Path to where results will be saved during evaluation.", + ) + parser.add_argument( + "--modelfolder", + type=str, + default=os.path.join(temp_dir, "Experiments/"), + help="Path to where model weights will be saved at the end of training.", + ) + args = parser.parse_args() createfolders(args) - - return \ No newline at end of file + + return diff --git a/utils/dataloaders/svhn.py 
b/utils/dataloaders/svhn.py index 3a38f28..be9d09d 100644 --- a/utils/dataloaders/svhn.py +++ b/utils/dataloaders/svhn.py @@ -1,11 +1,12 @@ from torch.utils.data import Dataset + class SVHN(Dataset): def __init__(self): super().__init__() - + def __len__(self): - return - + return + def __getitem__(self, index): - return \ No newline at end of file + return diff --git a/utils/load_data.py b/utils/load_data.py index b664764..272383c 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,5 +1,7 @@ from torch.utils.data import Dataset -def load_data(dataset:str) -> Dataset: - - raise ValueError(f'Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling') \ No newline at end of file + +def load_data(dataset: str) -> Dataset: + raise ValueError( + f"Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling" + ) diff --git a/utils/load_metric.py b/utils/load_metric.py index e142fc8..489131f 100644 --- a/utils/load_metric.py +++ b/utils/load_metric.py @@ -1,6 +1,7 @@ -import copy -import numpy as np -import torch.nn as nn +import copy + +import numpy as np +import torch.nn as nn from metrics import EntropyPrediction @@ -11,12 +12,11 @@ def __init__(self, *metrics): for metric in metrics: self.metrics[metric] = self._get_metric(metric) - + self.tmp_scores = copy.deepcopy(self.metrics) for key in self.tmp_scores: self.tmp_scores[key] = [] - def _get_metric(self, key): """ Get the metric function based on the key @@ -31,15 +31,15 @@ def _get_metric(self, key): """ match key.lower(): - case 'entropy': + case "entropy": return EntropyPrediction() - case 'f1': + case "f1": raise NotImplementedError("F1 score not implemented yet") - case 'recall': + case "recall": raise NotImplementedError("Recall score not implemented yet") - case 'precision': + case "precision": raise NotImplementedError("Precision score not implemented yet") - case 'accuracy': + case "accuracy": 
raise NotImplementedError("Accuracy score not implemented yet") case _: raise ValueError(f"Metric {key} not supported") @@ -47,12 +47,12 @@ def _get_metric(self, key): def __call__(self, y_true, y_pred): for key in self.metrics: self.tmp_scores[key].append(self.metrics[key](y_true, y_pred)) - + def __getmetrics__(self): return_metrics = {} for key in self.metrics: return_metrics[key] = np.mean(self.tmp_scores[key]) - + return return_metrics def __resetvalues__(self): diff --git a/utils/load_model.py b/utils/load_model.py index 93d0491..1f04fb4 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,9 +1,11 @@ -import torch.nn as nn +import torch.nn as nn from models import MagnusModel -def load_model(modelname:str) -> nn.Module: - - if modelname == 'MagnusModel': + +def load_model(modelname: str) -> nn.Module: + if modelname == "MagnusModel": return MagnusModel() else: - raise ValueError(f'Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling') \ No newline at end of file + raise ValueError( + f"Model: {modelname} has not been implemented. 
\nCheck the documentation for implemented metrics, or check your spelling" + ) diff --git a/utils/metrics/EntropyPred.py b/utils/metrics/EntropyPred.py index 08e4406..97058e7 100644 --- a/utils/metrics/EntropyPred.py +++ b/utils/metrics/EntropyPred.py @@ -1,10 +1,9 @@ -import torch.nn as nn +import torch.nn as nn class EntropyPrediction(nn.Module): def __init__(self): super().__init__() - + def __call__(self, y_true, y_false): - - return \ No newline at end of file + return diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py index dd0baac..f010ece 100644 --- a/utils/metrics/__init__.py +++ b/utils/metrics/__init__.py @@ -1 +1 @@ -from .EntropyPred import EntropyPrediction \ No newline at end of file +from .EntropyPred import EntropyPrediction diff --git a/utils/models/__init__.py b/utils/models/__init__.py index d58bcdb..649519c 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1 +1 @@ -from .magnus_model import MagnusModel \ No newline at end of file +from .magnus_model import MagnusModel diff --git a/utils/models/magnus_model.py b/utils/models/magnus_model.py index cbc1d77..6117a94 100644 --- a/utils/models/magnus_model.py +++ b/utils/models/magnus_model.py @@ -1,8 +1,9 @@ -import torch.nn as nn +import torch.nn as nn + class MagnusModel(nn.Module): def __init__(self): super().__init__() - + def forward(self, x): - return \ No newline at end of file + return From 8a50e7cbf02a8b690eeabad8c2f233c1154aaedc Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 08:42:12 +0100 Subject: [PATCH 10/30] Move code to utils folder --- src/datahandlers/__init__.py | 3 - src/datahandlers/usps_0_6.py | 117 ----------------------------------- 2 files changed, 120 deletions(-) delete mode 100644 src/datahandlers/__init__.py delete mode 100644 src/datahandlers/usps_0_6.py diff --git a/src/datahandlers/__init__.py b/src/datahandlers/__init__.py deleted file mode 100644 index df404f7..0000000 --- a/src/datahandlers/__init__.py +++ 
/dev/null @@ -1,3 +0,0 @@ -__all__ = ["USPSDataset0_6"] - -from .usps_0_6 import USPSDataset0_6 diff --git a/src/datahandlers/usps_0_6.py b/src/datahandlers/usps_0_6.py deleted file mode 100644 index 29fe0da..0000000 --- a/src/datahandlers/usps_0_6.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Dataset class for USPS dataset with labels 0-6. - -This module contains the Dataset class for the USPS dataset with labels 0-6. -""" - -from pathlib import Path - -import h5py as h5 -import numpy as np -from torch.utils.data import Dataset - - -class USPSDataset0_6(Dataset): - """ - Dataset class for USPS dataset with labels 0-6. - - Args - ---- - path : pathlib.Path - Path to the USPS dataset file. - mode : str - Mode of the dataset. Must be either 'train' or 'test'. - transform : callable, optional - A function/transform that takes in a sample and returns a transformed version. - - Attributes - ---------- - path : pathlib.Path - Path to the USPS dataset file. - mode : str - Mode of the dataset. - transform : callable - A function/transform that takes in a sample and returns a transformed version. - idx : numpy.ndarray - Indices of samples with labels 0-6. - - Methods - ------- - _index() - Get indices of samples with labels 0-6. - _load_data(idx) - Load data and target label from the dataset. - __len__() - Get the number of samples in the dataset. - __getitem__(idx) - Get a sample from the dataset. - - Examples - -------- - >>> from src.datahandlers import USPSDataset0_6 - >>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train") - >>> len(dataset) - 5460 - >>> data, target = dataset[0] - >>> data.shape - (16, 16) - >>> target - 6 - """ - - def __init__(self, path: Path, mode: str = "train", transform=None): - super().__init__() - self.path = path - self.mode = mode - self.transform = transform - - if self.mode not in ["train", "test"]: - raise ValueError("Invalid mode. 
Must be either 'train' or 'test'") - - self.idx = self._index() - - def _index(self): - with h5.File(self.path, "r") as f: - labels = f[self.mode]["target"][:] - - # Get indices of samples with labels 0-6 - mask = labels <= 6 - idx = np.where(mask)[0] - - return idx - - def _load_data(self, idx): - with h5.File(self.path, "r") as f: - data = f[self.mode]["data"][idx] - label = f[self.mode]["target"][idx] - - return data, label - - def __len__(self): - return len(self.idx) - - def __getitem__(self, idx): - data, target = self._load_data(self.idx[idx]) - - data = data.reshape(16, 16) - - if self.transform: - data = self.transform(data) - - return data, target - - -def test_uspsdataset0_6(): - import pytest - - datapath = Path("data/USPS/usps.h5") - - dataset = USPSDataset0_6(path=datapath, mode="train") - assert len(dataset) == 5460 - data, target = dataset[0] - assert data.shape == (16, 16) - assert target == 6 - - # Test for an invalid mode - with pytest.raises(ValueError): - USPSDataset0_6(path=datapath, mode="inference") From 7579ba34768cf6855de1e76b033ade38fd3fe928 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 08:42:41 +0100 Subject: [PATCH 11/30] USPS dataloader for 0-6 digits --- utils/dataloaders/usps_0_6.py | 117 ++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 utils/dataloaders/usps_0_6.py diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py new file mode 100644 index 0000000..29fe0da --- /dev/null +++ b/utils/dataloaders/usps_0_6.py @@ -0,0 +1,117 @@ +""" +Dataset class for USPS dataset with labels 0-6. + +This module contains the Dataset class for the USPS dataset with labels 0-6. +""" + +from pathlib import Path + +import h5py as h5 +import numpy as np +from torch.utils.data import Dataset + + +class USPSDataset0_6(Dataset): + """ + Dataset class for USPS dataset with labels 0-6. + + Args + ---- + path : pathlib.Path + Path to the USPS dataset file. 
+ mode : str + Mode of the dataset. Must be either 'train' or 'test'. + transform : callable, optional + A function/transform that takes in a sample and returns a transformed version. + + Attributes + ---------- + path : pathlib.Path + Path to the USPS dataset file. + mode : str + Mode of the dataset. + transform : callable + A function/transform that takes in a sample and returns a transformed version. + idx : numpy.ndarray + Indices of samples with labels 0-6. + + Methods + ------- + _index() + Get indices of samples with labels 0-6. + _load_data(idx) + Load data and target label from the dataset. + __len__() + Get the number of samples in the dataset. + __getitem__(idx) + Get a sample from the dataset. + + Examples + -------- + >>> from src.datahandlers import USPSDataset0_6 + >>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train") + >>> len(dataset) + 5460 + >>> data, target = dataset[0] + >>> data.shape + (16, 16) + >>> target + 6 + """ + + def __init__(self, path: Path, mode: str = "train", transform=None): + super().__init__() + self.path = path + self.mode = mode + self.transform = transform + + if self.mode not in ["train", "test"]: + raise ValueError("Invalid mode. 
Must be either 'train' or 'test'") + + self.idx = self._index() + + def _index(self): + with h5.File(self.path, "r") as f: + labels = f[self.mode]["target"][:] + + # Get indices of samples with labels 0-6 + mask = labels <= 6 + idx = np.where(mask)[0] + + return idx + + def _load_data(self, idx): + with h5.File(self.path, "r") as f: + data = f[self.mode]["data"][idx] + label = f[self.mode]["target"][idx] + + return data, label + + def __len__(self): + return len(self.idx) + + def __getitem__(self, idx): + data, target = self._load_data(self.idx[idx]) + + data = data.reshape(16, 16) + + if self.transform: + data = self.transform(data) + + return data, target + + +def test_uspsdataset0_6(): + import pytest + + datapath = Path("data/USPS/usps.h5") + + dataset = USPSDataset0_6(path=datapath, mode="train") + assert len(dataset) == 5460 + data, target = dataset[0] + assert data.shape == (16, 16) + assert target == 6 + + # Test for an invalid mode + with pytest.raises(ValueError): + USPSDataset0_6(path=datapath, mode="inference") From 6dfd94dd2aa70c8e482710641facfe12f1ad7826 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 08:42:56 +0100 Subject: [PATCH 12/30] Make dataloaders module --- utils/dataloaders/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 utils/dataloaders/__init__.py diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py new file mode 100644 index 0000000..df404f7 --- /dev/null +++ b/utils/dataloaders/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["USPSDataset0_6"] + +from .usps_0_6 import USPSDataset0_6 From 6ad365c280481afc1a46a0322db3a511abcfbe24 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 08:43:12 +0100 Subject: [PATCH 13/30] Add usps dataloader as alternative --- utils/load_data.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/utils/load_data.py b/utils/load_data.py index 272383c..9905f89 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ 
-1,7 +1,10 @@ +from dataloaders import USPS_0_6 from torch.utils.data import Dataset def load_data(dataset: str) -> Dataset: - raise ValueError( - f"Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling" - ) + match dataset.lower(): + case "usps_0-6": + return USPS_0_6 + case _: + raise ValueError(f"Dataset: {dataset} not implemented.") From 1947e82dbfc07ff3e37438a099bc6ff3df913e4a Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 08:43:41 +0100 Subject: [PATCH 14/30] Format using ruff --- main.py | 148 ++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 102 insertions(+), 46 deletions(-) diff --git a/main.py b/main.py index 74e6bce..92b9efa 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,12 @@ -import torch as th -import torch.nn as nn -from torch.utils.data import DataLoader import argparse -import wandb + import numpy as np -from utils import MetricWrapper, load_model, load_data, createfolders +import torch as th +import torch.nn as nn +import wandb +from torch.utils.data import DataLoader + +from utils import MetricWrapper, createfolders, load_data, load_model def main(): @@ -25,44 +27,100 @@ def main(): description='', epilog='', ) - #Structuture related values - parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.') - parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.') - parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.') - parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.') - - parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. 
Might cause code to start a bit slowly.') - - #Data/Model specific values - parser.add_argument('--modelname', type=str, default='MagnusModel', - choices = ['MagnusModel'], help="Model which to be trained on") - parser.add_argument('--dataset', type=str, default='svhn', - choices=['svhn'], help='Which dataset to train the model on.') - - parser.add_argument("--metric", type=str, default="entropy", choices=['entropy', 'f1', 'recall', 'precision', 'accuracy'], nargs="+", help='Which metric to use for evaluation') + # Structuture related values + parser.add_argument( + "--datafolder", + type=str, + default="Data/", + help="Path to where data will be saved during training.", + ) + parser.add_argument( + "--resultfolder", + type=str, + default="Results/", + help="Path to where results will be saved during evaluation.", + ) + parser.add_argument( + "--modelfolder", + type=str, + default="Experiments/", + help="Path to where model weights will be saved at the end of training.", + ) + parser.add_argument( + "--savemodel", + type=bool, + default=False, + help="Whether model should be saved or not.", + ) + + parser.add_argument( + "--download-data", + type=bool, + default=False, + help="Whether the data should be downloaded or not. 
Might cause code to start a bit slowly.", + ) + + # Data/Model specific values + parser.add_argument( + "--modelname", + type=str, + default="MagnusModel", + choices=["MagnusModel"], + help="Model which to be trained on", + ) + parser.add_argument( + "--dataset", + type=str, + default="svhn", + choices=["svhn"], + help="Which dataset to train the model on.", + ) + + parser.add_argument( + "--metric", + type=str, + default="entropy", + choices=["entropy", "f1", "recall", "precision", "accuracy"], + nargs="+", + help="Which metric to use for evaluation", + ) + + # Training specific values + parser.add_argument( + "--epoch", + type=int, + default=20, + help="Amount of training epochs the model will do.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="Learning rate parameter for model training.", + ) + parser.add_argument( + "--batchsize", + type=int, + default=64, + help="Amount of training images loaded in one go", + ) - #Training specific values - parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.') - parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.') - parser.add_argument('--batchsize', type=int, default=64, help='Amount of training images loaded in one go') - args = parser.parse_args() - createfolders(args) - + device = 'cuda' if th.cuda.is_available() else 'cpu' - - #load model + + # load model model = load_model() model.to(device) - + metrics = MetricWrapper(*args.metric) - - #Dataset + + # Dataset traindata = load_data(args.dataset) validata = load_data(args.dataset) - + trainloader = DataLoader(traindata, batch_size=args.batchsize, shuffle=True, @@ -72,33 +130,32 @@ def main(): batch_size=args.batchsize, shuffle=False, pin_memory=True) - + criterion = nn.CrossEntropyLoss() - optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate) - - + optimizer = th.optim.Adam(model.parameters(), 
lr=args.learning_rate) + wandb.init(project='', tags=[]) wandb.watch(model) - + for epoch in range(args.epoch): - - #Training loop start + + # Training loop start trainingloss = [] model.train() for x, y in traindata: x, y = x.to(device), y.to(device) pred = model.forward(x) - + loss = criterion(y, pred) loss.backward() - + optimizer.step() optimizer.zero_grad(set_to_none=True) trainingloss.append(loss.item()) - + evalloss = [] - #Eval loop start + # Eval loop start model.eval() with th.no_grad(): for x, y in valiloader: @@ -106,13 +163,12 @@ def main(): pred = model.forward(x) loss = criterion(y, pred) evalloss.append(loss.item()) - + wandb.log({ 'Epoch': epoch, 'Train loss': np.mean(trainingloss), 'Evaluation Loss': np.mean(evalloss) }) - if __name__ == '__main__': From 8b358bf355368e981d3e7fe31055c622b8e80551 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 30 Jan 2025 16:22:21 +0100 Subject: [PATCH 15/30] Use recommended python gitignore template Template can be found [here](https://github.com/github/gitignore/blob/297239c101dcfdfae7e75757ed17ed993df0b4eb/Python.gitignore) --- .gitignore | 175 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) diff --git a/.gitignore b/.gitignore index df0fcc4..29fa5e6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,178 @@ Results/ Experiments/ _build/ bin/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. 
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc From eae20cc942adc61eba07eaa80bb4e74bdf6f9a54 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 10:50:56 +0100 Subject: [PATCH 16/30] add __all__ to __init__.py --- utils/__init__.py | 2 ++ utils/metrics/__init__.py | 2 ++ utils/models/__init__.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/utils/__init__.py b/utils/__init__.py index ee1a0ca..050fe5c 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,3 +1,5 @@ +__all__ = ['createfolders', 'load_data', 'load_model', 'MetricWrapper'] + from .createfolders import createfolders from .load_data import load_data from .load_metric import MetricWrapper diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py index f010ece..094ff91 100644 --- a/utils/metrics/__init__.py +++ b/utils/metrics/__init__.py @@ -1 +1,3 @@ +__all__ = ['EntropyPrediction'] + from .EntropyPred import EntropyPrediction diff --git a/utils/models/__init__.py b/utils/models/__init__.py index 649519c..15131af 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1 +1,3 @@ +__all__ = ['MagnusModel'] + from .magnus_model import MagnusModel From 781362e02b04a8214b54b977acfe7642dd216f45 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 10:51:24 +0100 Subject: [PATCH 17/30] fix relative imports --- utils/load_metric.py | 3 ++- utils/load_model.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/utils/load_metric.py b/utils/load_metric.py index 489131f..f166c25 100644 --- a/utils/load_metric.py +++ b/utils/load_metric.py @@ -2,7 +2,8 @@ import numpy as np import torch.nn as nn -from metrics import EntropyPrediction + +from .metrics import EntropyPrediction class MetricWrapper(nn.Module): diff --git a/utils/load_model.py b/utils/load_model.py index 1f04fb4..db242a0 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,5 +1,6 @@ import torch.nn as nn -from models import 
MagnusModel + +from .models import MagnusModel def load_model(modelname: str) -> nn.Module: From 8ef502fe6ab69d9aee73ee4e88776cfd2a57b8c4 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 10:53:11 +0100 Subject: [PATCH 18/30] Modernize to use pathlib instead of os.path --- main.py | 15 +++++++------- utils/createfolders.py | 47 +++++++++++++++++++++--------------------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index 92b9efa..e1e8774 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path import numpy as np import torch as th @@ -30,20 +31,20 @@ def main(): # Structuture related values parser.add_argument( "--datafolder", - type=str, - default="Data/", + type=Path, + default="Data", help="Path to where data will be saved during training.", ) parser.add_argument( "--resultfolder", - type=str, - default="Results/", + type=Path, + default="Results", help="Path to where results will be saved during evaluation.", ) parser.add_argument( "--modelfolder", - type=str, - default="Experiments/", + type=Path, + default="Experiments", help="Path to where model weights will be saved at the end of training.", ) parser.add_argument( @@ -107,7 +108,7 @@ def main(): args = parser.parse_args() - createfolders(args) + createfolders(args.datafolder, args.resultfolder, args.modelfolder) device = 'cuda' if th.cuda.is_available() else 'cpu' diff --git a/utils/createfolders.py b/utils/createfolders.py index fbdaabb..b73bc9f 100644 --- a/utils/createfolders.py +++ b/utils/createfolders.py @@ -1,9 +1,9 @@ import argparse -import os +from pathlib import Path from tempfile import TemporaryDirectory -def createfolders(args) -> None: +def createfolders(*dirs: Path) -> None: """ Creates folders for storing data, results, model weights. 
@@ -14,43 +14,44 @@ def createfolders(args) -> None: """ - if not os.path.exists(args.datafolder): - os.makedirs(args.datafolder) - print(f"Created a folder at {args.datafolder}") - - if not os.path.exists(args.resultfolder): - os.makedirs(args.resultfolder) - print(f"Created a folder at {args.resultfolder}") - - if not os.path.exists(args.modelfolder): - os.makedirs(args.modelfolder) - print(f"Created a folder at {args.modelfolder}") + for dir in dirs: + dir.mkdir(parents=True, exist_ok=True) def test_createfolders(): - with TemporaryDirectory(dir="tmp/") as temp_dir: + with TemporaryDirectory() as temp_dir: + temp_dir = Path(temp_dir) + parser = argparse.ArgumentParser() + # Structuture related values parser.add_argument( "--datafolder", - type=str, - default=os.path.join(temp_dir, "Data/"), + type=Path, + default=temp_dir / "Data", help="Path to where data will be saved during training.", ) parser.add_argument( "--resultfolder", - type=str, - default=os.path.join(temp_dir, "Results/"), + type=Path, + default=temp_dir / "Results", help="Path to where results will be saved during evaluation.", ) parser.add_argument( "--modelfolder", - type=str, - default=os.path.join(temp_dir, "Experiments/"), + type=Path, + default=temp_dir / "Experiments", help="Path to where model weights will be saved at the end of training.", ) - args = parser.parse_args() - createfolders(args) + args = parser.parse_args([ + "--datafolder", temp_dir / "Data", + "--resultfolder", temp_dir / "Results", + "--modelfolder", temp_dir / "Experiments" + ]) + + createfolders(args.datafolder, args.resultfolder, args.modelfolder) - return + assert (temp_dir / "Data").exists() + assert (temp_dir / "Results").exists() + assert (temp_dir / "Experiments").exists() From faac193c2c460a3c6c624c8bdae7be7106c1f5a2 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 10:54:44 +0100 Subject: [PATCH 19/30] load_data now gives arguments to the datasets --- main.py | 17 +++++++++++++---- 
utils/dataloaders/usps_0_6.py | 26 +++++++++++++++++--------- utils/load_data.py | 7 ++++--- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index e1e8774..3783e74 100644 --- a/main.py +++ b/main.py @@ -73,7 +73,7 @@ def main(): "--dataset", type=str, default="svhn", - choices=["svhn"], + choices=["svhn", "usps_0-6"], help="Which dataset to train the model on.", ) @@ -119,8 +119,17 @@ def main(): metrics = MetricWrapper(*args.metric) # Dataset - traindata = load_data(args.dataset) - validata = load_data(args.dataset) + traindata = load_data( + args.dataset, + train=True, + data_path=args.datafolder, + download=args.download_data, + ) + validata = load_data( + args.dataset, + train=False, + data_path=args.datafolder, + ) trainloader = DataLoader(traindata, batch_size=args.batchsize, @@ -144,7 +153,7 @@ def main(): # Training loop start trainingloss = [] model.train() - for x, y in traindata: + for x, y in trainloader: x, y = x.to(device), y.to(device) pred = model.forward(x) diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py index 29fe0da..8d5f496 100644 --- a/utils/dataloaders/usps_0_6.py +++ b/utils/dataloaders/usps_0_6.py @@ -17,19 +17,21 @@ class USPSDataset0_6(Dataset): Args ---- - path : pathlib.Path + data_path : pathlib.Path Path to the USPS dataset file. - mode : str - Mode of the dataset. Must be either 'train' or 'test'. + train : bool, optional + Mode of the dataset. transform : callable, optional A function/transform that takes in a sample and returns a transformed version. + download : bool, optional + Whether to download the Dataset. Attributes ---------- path : pathlib.Path Path to the USPS dataset file. mode : str - Mode of the dataset. + Mode of the dataset, either train or test. transform : callable A function/transform that takes in a sample and returns a transformed version. 
idx : numpy.ndarray @@ -59,15 +61,21 @@ class USPSDataset0_6(Dataset): 6 """ - def __init__(self, path: Path, mode: str = "train", transform=None): + def __init__( + self, + data_path: Path, + train: bool = False, + transform=None, + download: bool = False, + ): super().__init__() - self.path = path - self.mode = mode + self.path = list(data_path.glob("*.h5"))[0] self.transform = transform - if self.mode not in ["train", "test"]: - raise ValueError("Invalid mode. Must be either 'train' or 'test'") + if download: + raise NotImplementedError("Download functionality not implemented.") + self.mode = "train" if train else "test" self.idx = self._index() def _index(self): diff --git a/utils/load_data.py b/utils/load_data.py index 9905f89..ac1bcfd 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,10 +1,11 @@ -from dataloaders import USPS_0_6 from torch.utils.data import Dataset +from .dataloaders import USPSDataset0_6 -def load_data(dataset: str) -> Dataset: + +def load_data(dataset: str, *args, **kwargs) -> Dataset: match dataset.lower(): case "usps_0-6": - return USPS_0_6 + return USPSDataset0_6(*args, **kwargs) case _: raise ValueError(f"Dataset: {dataset} not implemented.") From f7c2058ad83672d5d57bb6e229d2a45277745f29 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 10:55:25 +0100 Subject: [PATCH 20/30] Fix bug where string was treated as a list in argparse due to nargs="+" --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 3783e74..fa3dcac 100644 --- a/main.py +++ b/main.py @@ -80,7 +80,7 @@ def main(): parser.add_argument( "--metric", type=str, - default="entropy", + default=["entropy"], choices=["entropy", "f1", "recall", "precision", "accuracy"], nargs="+", help="Which metric to use for evaluation", From d045a2a5bbb7803821ab2a9d69fa873bf5f0c3bd Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 10:57:09 +0100 Subject: [PATCH 21/30] Add option for setting device to mps 
(for mac) and a dry_run parameter - The mps option is necessary to accelerate gpu ops for mac - --dry_run now checks that models/datasets/metrics are loaded before starting training --- main.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index fa3dcac..6fab214 100644 --- a/main.py +++ b/main.py @@ -11,22 +11,22 @@ def main(): - ''' - + """ + Parameters ---------- - + Returns ------- - + Raises ------ - - ''' + + """ parser = argparse.ArgumentParser( - prog='', - description='', - epilog='', + prog="", + description="", + epilog="", ) # Structuture related values parser.add_argument( @@ -105,15 +105,27 @@ def main(): default=64, help="Amount of training images loaded in one go", ) + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu", "mps"], + help="Which device to run the training on.", + ) + parser.add_argument( + "--dry_run", + action="store_true", + help="If true, the code will not run the training loop.", + ) args = parser.parse_args() createfolders(args.datafolder, args.resultfolder, args.modelfolder) - device = 'cuda' if th.cuda.is_available() else 'cpu' + device = args.device # load model - model = load_model() + model = load_model(args.modelname) model.to(device) metrics = MetricWrapper(*args.metric) @@ -144,6 +156,11 @@ def main(): criterion = nn.CrossEntropyLoss() optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate) + # This allows us to load all the components without running the training loop + if args.dry_run: + print("Dry run completed") + exit(0) + wandb.init(project='', tags=[]) wandb.watch(model) From e5aafb00390d8a26411538daa450c3a897903cf3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 31 Jan 2025 10:00:31 +0000 Subject: [PATCH 22/30] Auto-format: Applied ruff format and isort --- utils/__init__.py | 2 +- utils/createfolders.py | 15 ++++++++++----- utils/metrics/__init__.py | 2 +- 
utils/models/__init__.py | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/utils/__init__.py b/utils/__init__.py index 050fe5c..6ea6cde 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,4 @@ -__all__ = ['createfolders', 'load_data', 'load_model', 'MetricWrapper'] +__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper"] from .createfolders import createfolders from .load_data import load_data diff --git a/utils/createfolders.py b/utils/createfolders.py index b73bc9f..cdc3d4b 100644 --- a/utils/createfolders.py +++ b/utils/createfolders.py @@ -44,11 +44,16 @@ def test_createfolders(): help="Path to where model weights will be saved at the end of training.", ) - args = parser.parse_args([ - "--datafolder", temp_dir / "Data", - "--resultfolder", temp_dir / "Results", - "--modelfolder", temp_dir / "Experiments" - ]) + args = parser.parse_args( + [ + "--datafolder", + temp_dir / "Data", + "--resultfolder", + temp_dir / "Results", + "--modelfolder", + temp_dir / "Experiments", + ] + ) createfolders(args.datafolder, args.resultfolder, args.modelfolder) diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py index 094ff91..6e79fdc 100644 --- a/utils/metrics/__init__.py +++ b/utils/metrics/__init__.py @@ -1,3 +1,3 @@ -__all__ = ['EntropyPrediction'] +__all__ = ["EntropyPrediction"] from .EntropyPred import EntropyPrediction diff --git a/utils/models/__init__.py b/utils/models/__init__.py index 15131af..ca05548 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,3 +1,3 @@ -__all__ = ['MagnusModel'] +__all__ = ["MagnusModel"] from .magnus_model import MagnusModel From 4f981ecedf4eef0197d41b336a00b0e0b57b92a1 Mon Sep 17 00:00:00 2001 From: Solveig Date: Fri, 31 Jan 2025 12:32:20 +0100 Subject: [PATCH 23/30] Created the folder for our tests --- tests/test_createfolders.py | 0 tests/test_dataloaders.py | 0 tests/test_metrics.py | 0 tests/test_models.py | 0 4 files changed, 0 insertions(+), 0 
deletions(-) create mode 100644 tests/test_createfolders.py create mode 100644 tests/test_dataloaders.py create mode 100644 tests/test_metrics.py create mode 100644 tests/test_models.py diff --git a/tests/test_createfolders.py b/tests/test_createfolders.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..e69de29 From afeae2a3d9cea4c6517fe1035057c355a2c17964 Mon Sep 17 00:00:00 2001 From: Solveig Thrun <144994301+sot176@users.noreply.github.com> Date: Fri, 31 Jan 2025 12:34:27 +0100 Subject: [PATCH 24/30] Delete .idea directory --- .idea/Collaborative-Coding-Exam.iml | 8 -- .idea/inspectionProfiles/Project_Default.xml | 55 ------------- .../inspectionProfiles/profiles_settings.xml | 6 -- .idea/misc.xml | 4 - .idea/modules.xml | 8 -- .idea/vcs.xml | 6 -- .idea/workspace.xml | 81 ------------------- 7 files changed, 168 deletions(-) delete mode 100644 .idea/Collaborative-Coding-Exam.iml delete mode 100644 .idea/inspectionProfiles/Project_Default.xml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml delete mode 100644 .idea/workspace.xml diff --git a/.idea/Collaborative-Coding-Exam.iml b/.idea/Collaborative-Coding-Exam.iml deleted file mode 100644 index d0876a7..0000000 --- a/.idea/Collaborative-Coding-Exam.iml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 457d578..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,55 +0,0 @@ - - - - \ No newline at end of file diff 
--git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index d806dc0..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 56260d0..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 94a25f7..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml deleted file mode 100644 index 67ac41f..0000000 --- a/.idea/workspace.xml +++ /dev/null @@ -1,81 +0,0 @@ - - - - - - - - - - - - - - - - { - "keyToString": { - "RunOnceActivity.OpenProjectViewOnStart": "true", - "RunOnceActivity.ShowReadmeOnStart": "true", - "last_opened_file_path": "C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/Collaborative-Coding-Exam", - "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" - } -} - - - - - - - - - - - - - - - - - - - - - 1738244511415 - - - - \ No newline at end of file From 40bb5c03612dbe60cebf0f00d9ebb8f51ef33146 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 13:36:08 +0100 Subject: [PATCH 25/30] Add ChristianModel: 2 layer CNN w/maxpooling --- main.py | 2 +- utils/load_model.py | 19 ++++--- utils/models/__init__.py | 3 +- utils/models/christian_model.py | 92 +++++++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 utils/models/christian_model.py diff --git a/main.py b/main.py index 6fab214..e3c7c45 100644 --- a/main.py +++ 
b/main.py @@ -66,7 +66,7 @@ def main(): "--modelname", type=str, default="MagnusModel", - choices=["MagnusModel"], + choices=["MagnusModel", "ChristianModel"], help="Model which to be trained on", ) parser.add_argument( diff --git a/utils/load_model.py b/utils/load_model.py index db242a0..7e55699 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,12 +1,15 @@ import torch.nn as nn -from .models import MagnusModel +from .models import ChristianModel, MagnusModel -def load_model(modelname: str) -> nn.Module: - if modelname == "MagnusModel": - return MagnusModel() - else: - raise ValueError( - f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling" - ) +def load_model(modelname: str, *args, **kwargs) -> nn.Module: + match modelname.lower(): + case "magnusmodel": + return MagnusModel(*args, **kwargs) + case "christianmodel": + return ChristianModel(*args, **kwargs) + case _: + raise ValueError( + f"Model: {modelname} has not been implemented. 
\nCheck the documentation for implemented metrics, or check your spelling" + ) diff --git a/utils/models/__init__.py b/utils/models/__init__.py index ca05548..7cbae91 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["MagnusModel"] +__all__ = ["MagnusModel", "ChristianModel"] +from .christian_model import ChristianModel from .magnus_model import MagnusModel diff --git a/utils/models/christian_model.py b/utils/models/christian_model.py new file mode 100644 index 0000000..9bdd2da --- /dev/null +++ b/utils/models/christian_model.py @@ -0,0 +1,92 @@ +import pytest +import torch +import torch.nn as nn + + +class CNNBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.maxpool(x) + x = self.relu(x) + return x + + +class ChristianModel(nn.Module): + """Simple CNN model for image classification. + + Args + ---- + in_channels : int + Number of input channels. + num_classes : int + Number of classes in the dataset. + + Processing Images + ----------------- + Input: (N, C, H, W) + N: Batch size + C: Number of input channels + H: Height of the input image + W: Width of the input image + + Example: + For grayscale images, C = 1. 
+
+    Input Image Shape: (5, 1, 16, 16)
+    CNN1 Output Shape: (5, 50, 8, 8)
+    CNN2 Output Shape: (5, 100, 4, 4)
+    FC Output Shape: (5, num_classes)
+    """
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+
+        self.cnn1 = CNNBlock(in_channels, 50)
+        self.cnn2 = CNNBlock(50, 100)
+
+        self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
+        self.softmax = nn.Softmax(dim=1)
+
+    def forward(self, x):
+        x = self.cnn1(x)
+        x = self.cnn2(x)
+
+        x = x.view(x.size(0), -1)
+        x = self.fc1(x)
+        x = self.softmax(x)
+
+        return x
+
+
+@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
+def test_christian_model(in_channels, num_classes):
+    n, c, h, w = 5, in_channels, 16, 16
+
+    model = ChristianModel(c, num_classes)
+
+    x = torch.randn(n, c, h, w)
+    y = model(x)
+
+    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
+    assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"
+
+
+if __name__ == "__main__":
+
+    model = ChristianModel(3, 7)
+
+    x = torch.randn(3, 3, 16, 16)
+    y = model(x)
+
+    print(y)

From 4159b78ef4e3eb4b2540ff1783597ab9e3272bbe Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Fri, 31 Jan 2025 14:04:48 +0100
Subject: [PATCH 26/30] Instead of formatting changes, fail the test if code
 needs formatting

There were some bugs with the autoformatter, so instead now the user has to
manually format their code before committing their changes. If not, the
GitHub Action will fail.
--- .github/workflows/format.yml | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 748ebe1..72993d8 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -2,13 +2,9 @@ name: Format on: push: - branches: - - main paths: - 'utils/**' pull_request: - branches: - - main paths: - 'utils/**' @@ -30,18 +26,10 @@ jobs: run: | pip install ruff isort - - name: Run Ruff formatter + - name: Run Ruff check run: | - ruff format utils/ + ruff check utils/ - - name: Run isort + - name: Run isort check run: | - isort utils/ - - - name: Commit and push changes - run: | - git config --global user.name "github-actions[bot]" - git config --global user.email "github-actions[bot]@users.noreply.github.com" - git add utils/ - git commit -m "Auto-format: Applied ruff format and isort" || exit 0 - git push + isort --check-only utils/ From fc787c2678bd01b64703d1d92f510f400109024b Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 14:08:29 +0100 Subject: [PATCH 27/30] finds number of channels based on dataset. 
Adds num_classes to dataset --- main.py | 20 +++++++++++++++----- utils/dataloaders/usps_0_6.py | 3 +++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 6fab214..e6330ec 100644 --- a/main.py +++ b/main.py @@ -108,7 +108,7 @@ def main(): parser.add_argument( "--device", type=str, - default="cuda", + default="cpu", choices=["cuda", "cpu", "mps"], help="Which device to run the training on.", ) @@ -124,10 +124,6 @@ def main(): device = args.device - # load model - model = load_model(args.modelname) - model.to(device) - metrics = MetricWrapper(*args.metric) # Dataset @@ -143,6 +139,20 @@ def main(): data_path=args.datafolder, ) + # Find number of channels in the dataset + if len(traindata[0][0].shape) == 2: + channels = 1 + else: + channels = traindata[0][0].shape[0] + + # load model + model = load_model( + args.modelname, + in_channels=channels, + num_classes=traindata.num_classes, + ) + model.to(device) + trainloader = DataLoader(traindata, batch_size=args.batchsize, shuffle=True, diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py index 8d5f496..7a2608f 100644 --- a/utils/dataloaders/usps_0_6.py +++ b/utils/dataloaders/usps_0_6.py @@ -36,6 +36,8 @@ class USPSDataset0_6(Dataset): A function/transform that takes in a sample and returns a transformed version. idx : numpy.ndarray Indices of samples with labels 0-6. 
+ num_classes : int + Number of classes in the dataset Methods ------- @@ -71,6 +73,7 @@ def __init__( super().__init__() self.path = list(data_path.glob("*.h5"))[0] self.transform = transform + self.num_classes = 7 if download: raise NotImplementedError("Download functionality not implemented.") From cd1e0866228930deffd3293f7c68eb7132df0ee9 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Fri, 31 Jan 2025 14:12:26 +0100 Subject: [PATCH 28/30] Add Recall metric --- utils/metrics/__init__.py | 3 +- utils/metrics/recall.py | 62 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 utils/metrics/recall.py diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py index 6e79fdc..3afeee5 100644 --- a/utils/metrics/__init__.py +++ b/utils/metrics/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["EntropyPrediction"] +__all__ = ["EntropyPrediction", "Recall"] from .EntropyPred import EntropyPrediction +from .recall import Recall diff --git a/utils/metrics/recall.py b/utils/metrics/recall.py new file mode 100644 index 0000000..4aaae43 --- /dev/null +++ b/utils/metrics/recall.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + + +def one_hot_encode(y_true, num_classes): + """One-hot encode the target tensor. + + Args + ---- + y_true : torch.Tensor + Target tensor. + num_classes : int + Number of classes in the dataset. + + Returns + ------- + torch.Tensor + One-hot encoded tensor. 
+    """
+
+    y_onehot = torch.zeros(y_true.size(0), num_classes)
+    y_onehot.scatter_(1, y_true.unsqueeze(1), 1)
+    return y_onehot
+
+
+class Recall(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+
+        self.num_classes = num_classes
+
+    def forward(self, y_true, y_pred):
+        true_onehot = one_hot_encode(y_true, self.num_classes)
+        pred_onehot = one_hot_encode(y_pred, self.num_classes)
+
+        true_positives = (true_onehot * pred_onehot).sum()
+
+        false_negatives = torch.sum(~pred_onehot[true_onehot.bool()].bool())
+
+        recall = true_positives / (true_positives + false_negatives)
+
+        return recall
+
+
+def test_recall():
+    recall = Recall(7)
+
+    y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
+    y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])
+
+    recall_score = recall(y_true, y_pred)
+
+    assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), f"Recall Score: {recall_score.item()}"
+
+
+def test_one_hot_encode():
+    num_classes = 7
+
+    y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
+    y_onehot = one_hot_encode(y_true, num_classes)
+
+    assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}"

From efe689465b203c21738e50d318929f9e35e8aec3 Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Fri, 31 Jan 2025 14:38:01 +0100
Subject: [PATCH 29/30] Fix bug where labels weren't put on the device

---
 main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/main.py b/main.py
index 1b6bb96..fe563f5 100644
--- a/main.py
+++ b/main.py
@@ -196,7 +196,7 @@ def main():
         model.eval()
         with th.no_grad():
             for x, y in valiloader:
-                x = x.to(device)
+                x, y = x.to(device), y.to(device)
                 pred = model.forward(x)
                 loss = criterion(y, pred)
                 evalloss.append(loss.item())

From 68b5616279997b4ed33f16a8a092e2f22a275e8a Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Fri, 31 Jan 2025 14:38:32 +0100
Subject: [PATCH 30/30] Onehot encode labels in dataset

---
 utils/dataloaders/usps_0_6.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/utils/dataloaders/usps_0_6.py 
b/utils/dataloaders/usps_0_6.py index 7a2608f..4e68191 100644 --- a/utils/dataloaders/usps_0_6.py +++ b/utils/dataloaders/usps_0_6.py @@ -106,6 +106,12 @@ def __getitem__(self, idx): data = data.reshape(16, 16) + # one hot encode the target + target = np.eye(self.num_classes, dtype=np.float32)[target] + + # Add channel dimension + data = np.expand_dims(data, axis=0) + if self.transform: data = self.transform(data)