diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
new file mode 100644
index 0000000..72993d8
--- /dev/null
+++ b/.github/workflows/format.yml
@@ -0,0 +1,35 @@
+name: Format
+
+on:
+  push:
+    paths:
+      - 'utils/**'
+  pull_request:
+    paths:
+      - 'utils/**'
+
+jobs:
+  format:
+    name: Run Ruff and isort
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.x'
+
+      - name: Install dependencies
+        run: |
+          pip install ruff isort
+
+      - name: Run Ruff check
+        run: |
+          ruff check utils/
+
+      - name: Run isort check
+        run: |
+          isort --check-only utils/
diff --git a/environment.yml b/environment.yml
index 132b539..f1df2f5 100644
--- a/environment.yml
+++ b/environment.yml
@@ -9,6 +9,14 @@ dependencies:
   - sphinx-autobuild
   - sphinx-rtd-theme
   - pip
+  - h5py
+  - black
+  - isort
+  - jupyterlab
+  - numpy
+  - pandas
   - pytest
+  - ruff
+  - scalene
 prefix: /opt/miniconda3/envs/cc-exam
diff --git a/main.py b/main.py
index 68fbb93..fe563f5 100644
--- a/main.py
+++ b/main.py
@@ -1,78 +1,158 @@
-import torch as th
-import torch.nn as nn
-from torch.utils.data import DataLoader
 import argparse
-import wandb
+from pathlib import Path
+
 import numpy as np
-from utils import MetricWrapper, load_model, load_data, createfolders
+import torch as th
+import torch.nn as nn
+import wandb
+from torch.utils.data import DataLoader
+
+from utils import MetricWrapper, createfolders, load_data, load_model
 
 
 def main():
-    '''
-
+    """
+
     Parameters
     ----------
-
+
     Returns
     -------
-
+
     Raises
     ------
-
-    '''
+
+    """
     parser = argparse.ArgumentParser(
-        prog='',
-        description='',
-        epilog='',
-    )
-    #Structuture related values
-    parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')
-    parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.')
-    parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.')
-    parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.')
-
-    parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. Might cause code to start a bit slowly.')
-
-    #Data/Model specific values
-    parser.add_argument('--modelname', type=str, default='MagnusModel',
-                        choices = ['MagnusModel'], help="Model which to be trained on")
-    parser.add_argument('--dataset', type=str, default='svhn',
-                        choices=['svhn'], help='Which dataset to train the model on.')
-
-    parser.add_argument('--EntropyPrediction', type=bool, default=True, help='Include the Entropy Prediction metric in evaluation')
-    parser.add_argument('--F1Score', type=bool, default=True, help='Include the F1Score metric in evaluation')
-    parser.add_argument('--Recall', type=bool, default=True, help='Include the Recall metric in evaluation')
-    parser.add_argument('--Precision', type=bool, default=True, help='Include the Precision metric in evaluation')
-    parser.add_argument('--Accuracy', type=bool, default=True, help='Include the Accuracy metric in evaluation')
-
-    #Training specific values
-    parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
-    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
-    parser.add_argument('--batchsize', type=int, default=64, help='Amount of training images loaded in one go')
-
+        prog="",
+        description="",
+        epilog="",
+    )
+    # Structure related values
+    parser.add_argument(
+        "--datafolder",
+        type=Path,
+        default="Data",
+        help="Path to where data will be saved during training.",
+    )
+    parser.add_argument(
+        "--resultfolder",
+        type=Path,
+        default="Results",
+        help="Path to where results will be saved during evaluation.",
+    )
+    parser.add_argument(
+        "--modelfolder",
+        type=Path,
+        default="Experiments",
+        help="Path to where model weights will be saved at the end of training.",
+    )
+    parser.add_argument(
+        "--savemodel",
+        action="store_true",
+        help="If set, save the model weights at the end of training.",
+    )
+
+    parser.add_argument(
+        "--download-data",
+        action="store_true",
+        help="Download the dataset if it is not already present. Might cause the code to start a bit slowly.",
+    )
+
+    # Data/Model specific values
+    parser.add_argument(
+        "--modelname",
+        type=str,
+        default="MagnusModel",
+        choices=["MagnusModel", "ChristianModel"],
+        help="Which model to train.",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="svhn",
+        choices=["svhn", "usps_0-6"],
+        help="Which dataset to train the model on.",
+    )
+
+    parser.add_argument(
+        "--metric",
+        type=str,
+        default=["entropy"],
+        choices=["entropy", "f1", "recall", "precision", "accuracy"],
+        nargs="+",
+        help="Which metric(s) to use for evaluation.",
+    )
+
+    # Training specific values
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=20,
+        help="Number of training epochs the model will run.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=0.001,
+        help="Learning rate parameter for model training.",
+    )
+    parser.add_argument(
+        "--batchsize",
+        type=int,
+        default=64,
+        help="Number of training images loaded in one batch.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        choices=["cuda", "cpu", "mps"],
+        help="Which device to run the training on.",
+    )
+    parser.add_argument(
+        "--dry_run",
+        action="store_true",
+        help="If set, load all components but skip the training loop.",
+    )
+
     args = parser.parse_args()
-
-
-    createfolders(args)
-
-    device = 'cuda' if th.cuda.is_available() else 'cpu'
-
-    #load model
-    model = load_model()
+
+    createfolders(args.datafolder, args.resultfolder, args.modelfolder)
+
+    device = args.device
+
+    metrics = MetricWrapper(*args.metric)
+
+    # Dataset
+    traindata = load_data(
+        args.dataset,
+        train=True,
+        data_path=args.datafolder,
+        download=args.download_data,
+    )
+    validata = load_data(
+        args.dataset,
+        train=False,
+        data_path=args.datafolder,
+    )
+
+    # Find number of channels in the dataset
+    if len(traindata[0][0].shape) == 2:
+        channels = 1
+    else:
+        channels = traindata[0][0].shape[0]
+
+    # load model
+    model = load_model(
+        args.modelname,
+        in_channels=channels,
+        num_classes=traindata.num_classes,
+    )
     model.to(device)
-
-    metrics = MetricWrapper(
-        EntropyPred = args.EntropyPrediction,
-        F1Score = args.F1Score,
-        Recall = args.Recall,
-        Precision = args.Precision,
-        Accuracy = args.Accuracy
-    )
-
-    #Dataset
-    traindata = load_data(args.dataset)
-    validata = load_data(args.dataset)
-
+
     trainloader = DataLoader(traindata,
                              batch_size=args.batchsize,
                              shuffle=True,
@@ -82,48 +162,51 @@ def main():
                             batch_size=args.batchsize,
                             shuffle=False,
                             pin_memory=True)
-
+
     criterion = nn.CrossEntropyLoss()
-    optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate)
-
-
+    optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    # This allows us to load all the components without running the training loop
+    if args.dry_run:
+        print("Dry run completed")
+        exit(0)
+
     wandb.init(project='', tags=[])
     wandb.watch(model)
-
+
     for epoch in range(args.epoch):
-
-        #Training loop start
+
+        # Training loop start
         trainingloss = []
         model.train()
-        for x, y in traindata:
+        for x, y in trainloader:
             x, y = x.to(device), y.to(device)
             pred = model.forward(x)
-
-            loss = criterion(y, pred)
+
+            # nn.CrossEntropyLoss expects (input, target), i.e. predictions first
+            loss = criterion(pred, y)
             loss.backward()
-
+
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())
-
+
         evalloss = []
-        #Eval loop start
+        # Eval loop start
         model.eval()
         with th.no_grad():
             for x, y in valiloader:
-                x = x.to(device)
+                x, y = x.to(device), y.to(device)
                 pred = model.forward(x)
-                loss = criterion(y, pred)
+                loss = criterion(pred, y)
                 evalloss.append(loss.item())
-
+
         wandb.log({
             'Epoch': epoch,
             'Train loss': np.mean(trainingloss),
             'Evaluation Loss': np.mean(evalloss)
         })
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
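A note on the boolean flags: `argparse`'s `type=bool` converts any non-empty string, including `"False"`, to `True`, so a flag declared that way can never be switched off from the command line. That is why `--savemodel` and `--download-data` are sketched above with `action="store_true"`, matching the existing `--dry_run` flag. A minimal demonstration of the pitfall:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--broken", type=bool, default=False)  # pitfall: bool() on a string
parser.add_argument("--works", action="store_true")        # idiomatic on/off flag

args = parser.parse_args(["--broken", "False"])
print(args.broken)  # True -- bool("False") is truthy, so the flag cannot be disabled

args = parser.parse_args(["--works"])
print(args.works)   # True; simply omitting --works yields False
```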
diff --git a/test.ipynb b/test.ipynb
deleted file mode 100644
index 0b50544..0000000
--- a/test.ipynb
+++ /dev/null
@@ -1,76 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import argparse"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "usage: [-h] [--datafolder DATAFOLDER]\n",
-      ": error: unrecognized arguments: --f=/home/magnus/.local/share/jupyter/runtime/kernel-v3fc3d3b04bd8d83becf1be5eacf19e7bf46887012.json\n"
-     ]
-    },
-    {
-     "ename": "SystemExit",
-     "evalue": "2",
-     "output_type": "error",
-     "traceback": [
-      "An exception has occurred, use %tb to see the full traceback.\n",
-      "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
-     ]
-    }
-   ],
-   "source": [
-    "parser = argparse.ArgumentParser(\n",
-    "    prog='',\n",
-    "    description='',\n",
-    "    epilog='',\n",
-    "    )\n",
-    "parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')\n",
-    "args = parser.parse_args()\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "print(args)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "env",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/tests/test_createfolders.py b/tests/test_createfolders.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/__init__.py b/utils/__init__.py
index f1c5c8b..6ea6cde 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,4 +1,6 @@
+__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper"]
+
+from .createfolders import createfolders
+from .load_data import load_data
 from .load_metric import MetricWrapper
 from .load_model import load_model
-from .load_data import load_data
-from .createfolders import createfolders
\ No newline at end of file
diff --git a/utils/createfolders.py b/utils/createfolders.py
index f8de995..cdc3d4b 100644
--- a/utils/createfolders.py
+++ b/utils/createfolders.py
@@ -1,41 +1,62 @@
-import os
-from tempfile import TemporaryDirectory
 import argparse
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+
+def createfolders(*dirs: Path) -> None:
+    """
+    Creates folders for storing data, results, and model weights.
+
+    Parameters
+    ----------
+    *dirs : pathlib.Path
+        Paths of the directories to create.
+    """
+
+    for directory in dirs:
+        directory.mkdir(parents=True, exist_ok=True)
 
-def createfolders(args) -> None:
-    '''
-    Creates folders for storing data, results, model weights.
-
-    Parameters
-    ----------
-    args
-        ArgParse object containing string paths to be created
-
-    '''
-
-    if not os.path.exists(args.datafolder):
-        os.makedirs(args.datafolder)
-        print(f'Created a folder at {args.datafolder}')
-
-    if not os.path.exists(args.resultfolder):
-        os.makedirs(args.resultfolder)
-        print(f'Created a folder at {args.resultfolder}')
-
-    if not os.path.exists(args.modelfolder):
-        os.makedirs(args.modelfolder)
-        print(f'Created a folder at {args.modelfolder}')
 
 def test_createfolders():
-    with TemporaryDirectory(dir = 'tmp/') as temp_dir:
+    with TemporaryDirectory() as temp_dir:
+        temp_dir = Path(temp_dir)
+
         parser = argparse.ArgumentParser()
-        #Structuture related values
-        parser.add_argument('--datafolder', type=str, default=os.path.join(temp_dir, 'Data/'), help='Path to where data will be saved during training.')
-        parser.add_argument('--resultfolder', type=str, default=os.path.join(temp_dir, 'Results/'), help='Path to where results will be saved during evaluation.')
-        parser.add_argument('--modelfolder', type=str, default=os.path.join(temp_dir, 'Experiments/'), help='Path to where model weights will be saved at the end of training.')
-
-        args = parser.parse_args()
-        createfolders(args)
-
-        return
+
+        # Structure related values
+        parser.add_argument(
+            "--datafolder",
+            type=Path,
+            default=temp_dir / "Data",
+            help="Path to where data will be saved during training.",
+        )
+        parser.add_argument(
+            "--resultfolder",
+            type=Path,
+            default=temp_dir / "Results",
+            help="Path to where results will be saved during evaluation.",
+        )
+        parser.add_argument(
+            "--modelfolder",
+            type=Path,
+            default=temp_dir / "Experiments",
+            help="Path to where model weights will be saved at the end of training.",
+        )
+
+        # parse_args expects strings, so the Path objects are converted explicitly
+        args = parser.parse_args(
+            [
+                "--datafolder",
+                str(temp_dir / "Data"),
+                "--resultfolder",
+                str(temp_dir / "Results"),
+                "--modelfolder",
+                str(temp_dir / "Experiments"),
+            ]
+        )
+
+        createfolders(args.datafolder, args.resultfolder, args.modelfolder)
+
+        assert (temp_dir / "Data").exists()
+        assert (temp_dir / "Results").exists()
+        assert (temp_dir / "Experiments").exists()
diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py
new file mode 100644
index 0000000..df404f7
--- /dev/null
+++ b/utils/dataloaders/__init__.py
@@ -0,0 +1,3 @@
+__all__ = ["USPSDataset0_6"]
+
+from .usps_0_6 import USPSDataset0_6
diff --git a/utils/dataloaders/svhn.py b/utils/dataloaders/svhn.py
index 3a38f28..be9d09d 100644
--- a/utils/dataloaders/svhn.py
+++ b/utils/dataloaders/svhn.py
@@ -1,11 +1,12 @@
 from torch.utils.data import Dataset
 
+
 class SVHN(Dataset):
     def __init__(self):
         super().__init__()
-
+
     def __len__(self):
-        return
-
+        return
+
     def __getitem__(self, index):
-        return
\ No newline at end of file
+        return
diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py
new file mode 100644
index 0000000..4e68191
--- /dev/null
+++ b/utils/dataloaders/usps_0_6.py
@@ -0,0 +1,134 @@
+"""
+Dataset class for USPS dataset with labels 0-6.
+
+This module contains the Dataset class for the USPS dataset with labels 0-6.
+"""
+
+from pathlib import Path
+
+import h5py as h5
+import numpy as np
+from torch.utils.data import Dataset
+
+
+class USPSDataset0_6(Dataset):
+    """
+    Dataset class for USPS dataset with labels 0-6.
+
+    Args
+    ----
+    data_path : pathlib.Path
+        Path to the directory containing the USPS dataset file.
+    train : bool, optional
+        If True, use the training split; otherwise use the test split.
+    transform : callable, optional
+        A function/transform that takes in a sample and returns a transformed version.
+    download : bool, optional
+        Whether to download the dataset.
+
+    Attributes
+    ----------
+    path : pathlib.Path
+        Path to the USPS dataset file.
+    mode : str
+        Mode of the dataset, either train or test.
+    transform : callable
+        A function/transform that takes in a sample and returns a transformed version.
+    idx : numpy.ndarray
+        Indices of samples with labels 0-6.
+    num_classes : int
+        Number of classes in the dataset.
+
+    Methods
+    -------
+    _index()
+        Get indices of samples with labels 0-6.
+    _load_data(idx)
+        Load data and target label from the dataset.
+    __len__()
+        Get the number of samples in the dataset.
+    __getitem__(idx)
+        Get a sample from the dataset.
+
+    Examples
+    --------
+    >>> from pathlib import Path
+    >>> from utils.dataloaders import USPSDataset0_6
+    >>> dataset = USPSDataset0_6(data_path=Path("data"), train=True)
+    >>> len(dataset)
+    5460
+    >>> data, target = dataset[0]
+    >>> data.shape
+    (1, 16, 16)
+    >>> target
+    array([0., 0., 0., 0., 0., 0., 1.], dtype=float32)
+    """
+
+    def __init__(
+        self,
+        data_path: Path,
+        train: bool = False,
+        transform=None,
+        download: bool = False,
+    ):
+        super().__init__()
+        self.path = list(data_path.glob("*.h5"))[0]
+        self.transform = transform
+        self.num_classes = 7
+
+        if download:
+            raise NotImplementedError("Download functionality not implemented.")
+
+        self.mode = "train" if train else "test"
+        self.idx = self._index()
+
+    def _index(self):
+        with h5.File(self.path, "r") as f:
+            labels = f[self.mode]["target"][:]
+
+        # Get indices of samples with labels 0-6
+        mask = labels <= 6
+        idx = np.where(mask)[0]
+
+        return idx
+
+    def _load_data(self, idx):
+        with h5.File(self.path, "r") as f:
+            data = f[self.mode]["data"][idx]
+            label = f[self.mode]["target"][idx]
+
+        return data, label
+
+    def __len__(self):
+        return len(self.idx)
+
+    def __getitem__(self, idx):
+        data, target = self._load_data(self.idx[idx])
+
+        data = data.reshape(16, 16)
+
+        # one hot encode the target
+        target = np.eye(self.num_classes, dtype=np.float32)[target]
+
+        # Add channel dimension
+        data = np.expand_dims(data, axis=0)
+
+        if self.transform:
+            data = self.transform(data)
+
+        return data, target
+
+
+def test_uspsdataset0_6():
+    import pytest
+
+    # The constructor globs for *.h5 files, so point it at the directory
+    datapath = Path("data/USPS")
+
+    dataset = USPSDataset0_6(data_path=datapath, train=True)
+    assert len(dataset) == 5460
+    data, target = dataset[0]
+    assert data.shape == (1, 16, 16)
+    assert target.argmax() == 6
+
+    # download=True is not implemented yet
+    with pytest.raises(NotImplementedError):
+        USPSDataset0_6(data_path=datapath, train=True, download=True)
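To see the new dataset class end to end, here is a usage sketch. It assumes a local `data/` directory that already contains the USPS `usps.h5` file, since `download=True` is not implemented yet:

```python
from pathlib import Path

from torch.utils.data import DataLoader

from utils.dataloaders import USPSDataset0_6

# The constructor globs for *.h5 inside data_path, so pass the directory.
dataset = USPSDataset0_6(data_path=Path("data"), train=True)

loader = DataLoader(dataset, batch_size=16, shuffle=True)
images, targets = next(iter(loader))

print(images.shape)   # (16, 1, 16, 16) -- batch, channel, height, width
print(targets.shape)  # (16, 7) -- one-hot over the seven classes (digits 0-6)
```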
diff --git a/utils/dataloaders/uspsh5_7_9.py b/utils/dataloaders/uspsh5_7_9.py
new file mode 100644
index 0000000..a343554
--- /dev/null
+++ b/utils/dataloaders/uspsh5_7_9.py
@@ -0,0 +1,116 @@
+import h5py
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+class USPSH5_Digit_7_9_Dataset(Dataset):
+    """
+    Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
+
+    Parameters
+    ----------
+    h5_path : str
+        Path to the USPS `.h5` file.
+
+    mode : str
+        Which split of the file to load, e.g. "train" or "test".
+
+    transform : callable, optional, default=None
+        A transform function to apply on images. If None, no transformation is applied.
+
+    Attributes
+    ----------
+    images : numpy.ndarray
+        The filtered images corresponding to digits 7-9.
+
+    labels : numpy.ndarray
+        The filtered labels corresponding to digits 7-9.
+
+    transform : callable, optional
+        A transform function to apply to the images.
+    """
+
+    def __init__(self, h5_path, mode, transform=None):
+        """
+        Initializes the USPS dataset by loading images and labels from the given `.h5` file.
+
+        Parameters
+        ----------
+        h5_path : str
+            Path to the USPS `.h5` file.
+
+        mode : str
+            Which split of the file to load, e.g. "train" or "test".
+
+        transform : callable, optional, default=None
+            A transform function to apply on images.
+        """
+        super().__init__()
+
+        self.transform = transform
+        self.mode = mode
+        self.h5_path = h5_path
+        # Load the dataset from the HDF5 file
+        with h5py.File(self.h5_path, "r") as hf:
+            images = hf[self.mode]["data"][:]
+            labels = hf[self.mode]["target"][:]
+
+        # Filter only digits 7, 8, and 9
+        mask = np.isin(labels, [7, 8, 9])
+        self.images = images[mask]
+        self.labels = labels[mask]
+
+    def __len__(self):
+        """
+        Returns the total number of samples in the dataset.
+
+        Returns
+        -------
+        int
+            The number of images in the dataset.
+        """
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        """
+        Returns a sample from the dataset given an index.
+
+        Parameters
+        ----------
+        idx : int
+            The index of the sample to retrieve.
+
+        Returns
+        -------
+        tuple
+            - image (PIL Image): The image at the specified index.
+            - label (int): The label corresponding to the image.
+        """
+        # Convert to PIL Image (USPS images are typically grayscale 16x16)
+        image = Image.fromarray(self.images[idx].astype(np.uint8), mode="L")
+        label = int(self.labels[idx])  # Convert label to integer
+
+        if self.transform:
+            image = self.transform(image)
+
+        return image, label
+
+
+def main():
+    # Example Usage:
+    transform = transforms.Compose([
+        transforms.Resize((16, 16)),  # Ensure images are 16x16
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
+    ])
+
+    # Load the dataset
+    dataset = USPSH5_Digit_7_9_Dataset(
+        h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5",
+        mode="train",
+        transform=transform,
+    )
+    data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
+    batch = next(iter(data_loader))  # grab a batch from the dataloader
+    img, label = batch
+    print(img.shape)
+    print(label.shape)
+
+    # Check dataset size
+    print(f"Dataset size: {len(dataset)}")
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/utils/load_data.py b/utils/load_data.py
index b664764..ac1bcfd 100644
--- a/utils/load_data.py
+++ b/utils/load_data.py
@@ -1,5 +1,11 @@
 from torch.utils.data import Dataset
 
-def load_data(dataset:str) -> Dataset:
-
-    raise ValueError(f'Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling')
\ No newline at end of file
+from .dataloaders import USPSDataset0_6
+
+
+def load_data(dataset: str, *args, **kwargs) -> Dataset:
+    match dataset.lower():
+        case "usps_0-6":
+            return USPSDataset0_6(*args, **kwargs)
+        case _:
+            raise ValueError(f"Dataset: {dataset} not implemented.")
diff --git a/utils/load_metric.py b/utils/load_metric.py
index cdce0b7..f166c25 100644
--- a/utils/load_metric.py
+++ b/utils/load_metric.py
@@ -1,47 +1,59 @@
-import copy
-import numpy as np
-import torch.nn as nn
-from metrics import EntropyPrediction
+import copy
+
+import numpy as np
+import torch.nn as nn
+
+from .metrics import EntropyPrediction
 
 
 class MetricWrapper(nn.Module):
-    def __init__(self,
-                 EntropyPred:bool = True,
-                 F1Score:bool = True,
-                 Recall:bool = True,
-                 Precision:bool = True,
-                 Accuracy:bool = True):
+    def __init__(self, *metrics):
         super().__init__()
         self.metrics = {}
-
-        if EntropyPred:
-            self.metrics['Entropy of Predictions'] = EntropyPrediction()
-
-        if F1Score:
-            self.metrics['F1 Score'] = None
-
-        if Recall:
-            self.metrics['Recall'] = None
-
-        if Precision:
-            self.metrics['Precision'] = None
-
-        if Accuracy:
-            self.metrics['Accuracy'] = None
-
+
+        for metric in metrics:
+            self.metrics[metric] = self._get_metric(metric)
+
         self.tmp_scores = copy.deepcopy(self.metrics)
         for key in self.tmp_scores:
             self.tmp_scores[key] = []
 
+    def _get_metric(self, key):
+        """
+        Get the metric function based on the key.
+
+        Args
+        ----
+        key (str): metric name
+
+        Returns
+        -------
+        metric (callable): metric function
+        """
+
+        match key.lower():
+            case "entropy":
+                return EntropyPrediction()
+            case "f1":
+                raise NotImplementedError("F1 score not implemented yet")
+            case "recall":
+                raise NotImplementedError("Recall score not implemented yet")
+            case "precision":
+                raise NotImplementedError("Precision score not implemented yet")
+            case "accuracy":
+                raise NotImplementedError("Accuracy score not implemented yet")
+            case _:
+                raise ValueError(f"Metric {key} not supported")
+
     def __call__(self, y_true, y_pred):
         for key in self.metrics:
             self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
-
+
     def __getmetrics__(self):
         return_metrics = {}
         for key in self.metrics:
             return_metrics[key] = np.mean(self.tmp_scores[key])
-
+
         return return_metrics
 
     def __resetvalues__(self):
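The reworked `MetricWrapper` takes metric names as variadic string arguments, mirroring the new `--metric` CLI flag, and dispatches on them in `_get_metric`. A sketch of the intended call pattern (only `"entropy"` is wired up so far, and `EntropyPrediction` itself is still a stub returning `None`, so the calls are left commented out):

```python
from utils import MetricWrapper

# Any name other than "entropy" currently raises NotImplementedError.
metrics = MetricWrapper("entropy")

# During training, each call appends per-batch scores...
# metrics(y_true, y_pred)

# ...and the aggregated means are read out and reset between epochs.
# scores = metrics.__getmetrics__()
# metrics.__resetvalues__()
```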
diff --git a/utils/load_model.py b/utils/load_model.py
index 93d0491..7e55699 100644
--- a/utils/load_model.py
+++ b/utils/load_model.py
@@ -1,9 +1,15 @@
-import torch.nn as nn
-from models import MagnusModel
+import torch.nn as nn
 
-def load_model(modelname:str) -> nn.Module:
-
-    if modelname == 'MagnusModel':
-        return MagnusModel()
-    else:
-        raise ValueError(f'Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')
\ No newline at end of file
+from .models import ChristianModel, MagnusModel
+
+
+def load_model(modelname: str, *args, **kwargs) -> nn.Module:
+    match modelname.lower():
+        case "magnusmodel":
+            return MagnusModel(*args, **kwargs)
+        case "christianmodel":
+            return ChristianModel(*args, **kwargs)
+        case _:
+            raise ValueError(
+                f"Model: {modelname} has not been implemented. "
+                "Check the documentation for implemented models, or check your spelling."
+            )
diff --git a/utils/metrics/EntropyPred.py b/utils/metrics/EntropyPred.py
index 08e4406..97058e7 100644
--- a/utils/metrics/EntropyPred.py
+++ b/utils/metrics/EntropyPred.py
@@ -1,10 +1,9 @@
-import torch.nn as nn
+import torch.nn as nn
 
 
 class EntropyPrediction(nn.Module):
     def __init__(self):
         super().__init__()
-
-    def __call__(self, y_true, y_false):
-
-        return
\ No newline at end of file
+
+    def __call__(self, y_true, y_pred):
+        return
diff --git a/utils/metrics/F1.py b/utils/metrics/F1.py
new file mode 100644
index 0000000..16c87f8
--- /dev/null
+++ b/utils/metrics/F1.py
@@ -0,0 +1,96 @@
+import torch
+import torch.nn as nn
+
+
+class F1Score(nn.Module):
+    """
+    F1 Score implementation with direct averaging inside the compute method.
+
+    Parameters
+    ----------
+    num_classes : int
+        Number of classes.
+
+    Attributes
+    ----------
+    num_classes : int
+        The number of classes.
+
+    tp : torch.Tensor
+        Tensor for True Positives (TP) for each class.
+
+    fp : torch.Tensor
+        Tensor for False Positives (FP) for each class.
+
+    fn : torch.Tensor
+        Tensor for False Negatives (FN) for each class.
+    """
+
+    def __init__(self, num_classes):
+        """
+        Initializes the F1Score object, setting up the necessary state variables.
+
+        Parameters
+        ----------
+        num_classes : int
+            The number of classes in the classification task.
+        """
+
+        super().__init__()
+
+        self.num_classes = num_classes
+
+        # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
+        self.tp = torch.zeros(num_classes)
+        self.fp = torch.zeros(num_classes)
+        self.fn = torch.zeros(num_classes)
+
+    def update(self, preds, target):
+        """
+        Update the variables with predictions and true labels.
+
+        Parameters
+        ----------
+        preds : torch.Tensor
+            Predicted logits (shape: [batch_size, num_classes]).
+
+        target : torch.Tensor
+            True labels (shape: [batch_size]).
+        """
+        preds = torch.argmax(preds, dim=1)
+
+        # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
+        for i in range(self.num_classes):
+            self.tp[i] += torch.sum((preds == i) & (target == i)).float()
+            self.fp[i] += torch.sum((preds == i) & (target != i)).float()
+            self.fn[i] += torch.sum((preds != i) & (target == i)).float()
+
+    def compute(self):
+        """
+        Compute the F1 score.
+
+        Returns
+        -------
+        torch.Tensor
+            The computed F1 score.
+        """
+
+        # Micro-averaged F1: pool TP/FP/FN over all classes before combining
+        f1_score = 2 * torch.sum(self.tp) / (
+            2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn)
+        )
+
+        return f1_score
+
+
+def test_f1score():
+    f1_metric = F1Score(num_classes=3)
+    preds = torch.tensor([[0.8, 0.1, 0.1],
+                          [0.2, 0.7, 0.1],
+                          [0.2, 0.3, 0.5],
+                          [0.1, 0.2, 0.7]])
+
+    target = torch.tensor([0, 1, 0, 2])
+
+    f1_metric.update(preds, target)
+    assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
+    assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
+    assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
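As a sanity check on the micro-averaged formula: the tensors in `test_f1score` give argmax predictions `[0, 1, 2, 2]` against targets `[0, 1, 0, 2]`, i.e. three true positives, one false positive (class 2 on the third sample), and one false negative (class 0 on the same sample), so `compute()` should return 2*3 / (2*3 + 1 + 1) = 0.75. The test could be tightened to pin that value down (the import path is assumed from the new module location, since `F1Score` is not re-exported yet):

```python
import torch

from utils.metrics.F1 import F1Score  # assumed import path

f1_metric = F1Score(num_classes=3)
preds = torch.tensor([[0.8, 0.1, 0.1],
                      [0.2, 0.7, 0.1],
                      [0.2, 0.3, 0.5],
                      [0.1, 0.2, 0.7]])
target = torch.tensor([0, 1, 0, 2])

f1_metric.update(preds, target)
assert torch.isclose(f1_metric.compute(), torch.tensor(0.75))
```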
diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py
index dd0baac..3afeee5 100644
--- a/utils/metrics/__init__.py
+++ b/utils/metrics/__init__.py
@@ -1 +1,4 @@
-from .EntropyPred import EntropyPrediction
\ No newline at end of file
+__all__ = ["EntropyPrediction", "Recall"]
+
+from .EntropyPred import EntropyPrediction
+from .recall import Recall
diff --git a/utils/metrics/recall.py b/utils/metrics/recall.py
new file mode 100644
index 0000000..4aaae43
--- /dev/null
+++ b/utils/metrics/recall.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+
+
+def one_hot_encode(y_true, num_classes):
+    """One-hot encode the target tensor.
+
+    Args
+    ----
+    y_true : torch.Tensor
+        Target tensor.
+    num_classes : int
+        Number of classes in the dataset.
+
+    Returns
+    -------
+    torch.Tensor
+        One-hot encoded tensor.
+    """
+
+    y_onehot = torch.zeros(y_true.size(0), num_classes)
+    y_onehot.scatter_(1, y_true.unsqueeze(1), 1)
+    return y_onehot
+
+
+class Recall(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+
+        self.num_classes = num_classes
+
+    def forward(self, y_true, y_pred):
+        true_onehot = one_hot_encode(y_true, self.num_classes)
+        pred_onehot = one_hot_encode(y_pred, self.num_classes)
+
+        true_positives = (true_onehot * pred_onehot).sum()
+
+        false_negatives = torch.sum(~pred_onehot[true_onehot.bool()].bool())
+
+        recall = true_positives / (true_positives + false_negatives)
+
+        return recall
+
+
+def test_recall():
+    recall = Recall(7)
+
+    y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
+    y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])
+
+    recall_score = recall(y_true, y_pred)
+
+    # 5 of the 7 true classes are recovered, so the expected recall is 5/7
+    assert recall_score.allclose(torch.tensor(5 / 7)), f"Recall Score: {recall_score.item()}"
+
+
+def test_one_hot_encode():
+    num_classes = 7
+
+    y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
+    y_onehot = one_hot_encode(y_true, num_classes)
+
+    assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}"
diff --git a/utils/models/__init__.py b/utils/models/__init__.py
index d58bcdb..7cbae91 100644
--- a/utils/models/__init__.py
+++ b/utils/models/__init__.py
@@ -1 +1,4 @@
-from .magnus_model import MagnusModel
\ No newline at end of file
+__all__ = ["MagnusModel", "ChristianModel"]
+
+from .christian_model import ChristianModel
+from .magnus_model import MagnusModel
diff --git a/utils/models/christian_model.py b/utils/models/christian_model.py
new file mode 100644
index 0000000..9bdd2da
--- /dev/null
+++ b/utils/models/christian_model.py
@@ -0,0 +1,92 @@
+import pytest
+import torch
+import torch.nn as nn
+
+
+class CNNBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+        )
+        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.maxpool(x)
+        x = self.relu(x)
+        return x
+
+
+class ChristianModel(nn.Module):
+    """Simple CNN model for image classification.
+
+    Args
+    ----
+    in_channels : int
+        Number of input channels.
+    num_classes : int
+        Number of classes in the dataset.
+
+    Processing Images
+    -----------------
+    Input: (N, C, H, W)
+        N: Batch size
+        C: Number of input channels
+        H: Height of the input image
+        W: Width of the input image
+
+    Example:
+        For grayscale images, C = 1.
+
+        Input Image Shape: (5, 1, 16, 16)
+        CNN1 Output Shape: (5, 50, 8, 8)
+        CNN2 Output Shape: (5, 100, 4, 4)
+        FC Output Shape:   (5, num_classes)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+
+        self.cnn1 = CNNBlock(in_channels, 50)
+        self.cnn2 = CNNBlock(50, 100)
+
+        self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
+        self.softmax = nn.Softmax(dim=1)
+
+    def forward(self, x):
+        x = self.cnn1(x)
+        x = self.cnn2(x)
+
+        x = x.view(x.size(0), -1)
+        x = self.fc1(x)
+        x = self.softmax(x)
+
+        return x
+
+
+@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
+def test_christian_model(in_channels, num_classes):
+    n, c, h, w = 5, in_channels, 16, 16
+
+    model = ChristianModel(c, num_classes)
+
+    x = torch.randn(n, c, h, w)
+    y = model(x)
+
+    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
+    assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"
+
+
+if __name__ == "__main__":
+    model = ChristianModel(3, 7)
+
+    x = torch.randn(3, 3, 16, 16)
+    y = model(x)
+
+    print(y)
diff --git a/utils/models/magnus_model.py b/utils/models/magnus_model.py
index cbc1d77..6117a94 100644
--- a/utils/models/magnus_model.py
+++ b/utils/models/magnus_model.py
@@ -1,8 +1,9 @@
-import torch.nn as nn
+import torch.nn as nn
 
 
 class MagnusModel(nn.Module):
     def __init__(self):
         super().__init__()
-
+
     def forward(self, x):
-        return
\ No newline at end of file
+        return
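`MagnusModel` is still a stub: its `forward` returns `None`, and `load_model` now passes `in_channels` and `num_classes` keyword arguments that the no-argument `__init__` would reject. Until the real architecture lands, a minimal placeholder consistent with that calling convention might look like the sketch below. This is purely an assumption about the intended interface (16x16 inputs, logits of shape `(N, num_classes)`), not the author's design:

```python
import torch
import torch.nn as nn


class MagnusModel(nn.Module):
    """Placeholder sketch assuming the load_model(in_channels, num_classes) interface."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        # Flatten a 16x16 input and map it directly to class scores (logits).
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * 16 * 16, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


if __name__ == "__main__":
    model = MagnusModel(in_channels=1, num_classes=7)
    out = model(torch.randn(5, 1, 16, 16))
    print(out.shape)  # torch.Size([5, 7])
```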