diff --git a/main.py b/main.py index fe563f5..ff7664f 100644 --- a/main.py +++ b/main.py @@ -66,14 +66,14 @@ def main(): "--modelname", type=str, default="MagnusModel", - choices=["MagnusModel", "ChristianModel"], + choices=["MagnusModel", "ChristianModel", "SolveigModel"], help="Model which to be trained on", ) parser.add_argument( "--dataset", type=str, default="svhn", - choices=["svhn", "usps_0-6"], + choices=["svhn", "usps_0-6", "uspsh5_7_9"], help="Which dataset to train the model on.", ) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index c25d861..1650e01 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,4 +1,4 @@ -from utils.metrics import Recall +from utils.metrics import Recall, F1Score def test_recall(): @@ -14,3 +14,19 @@ def test_recall(): assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), ( f"Recall Score: {recall_score.item()}" ) + + +def test_f1score(): + import torch + + f1_metric = F1Score(num_classes=3) + preds = torch.tensor( + [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]] + ) + + target = torch.tensor([0, 1, 0, 2]) + + f1_metric.update(preds, target) + assert f1_metric.tp.sum().item() > 0, "Expected some true positives." + assert f1_metric.fp.sum().item() > 0, "Expected some false positives." + assert f1_metric.fn.sum().item() > 0, "Expected some false negatives." diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py index df404f7..bb97adc 100644 --- a/utils/dataloaders/__init__.py +++ b/utils/dataloaders/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["USPSDataset0_6"] +__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"] from .usps_0_6 import USPSDataset0_6 +from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset \ No newline at end of file diff --git a/utils/load_data.py b/utils/load_data.py index ac1bcfd..f54e94a 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,11 +1,13 @@ from torch.utils.data import Dataset -from .dataloaders import USPSDataset0_6 +from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset def load_data(dataset: str, *args, **kwargs) -> Dataset: match dataset.lower(): case "usps_0-6": return USPSDataset0_6(*args, **kwargs) + case "usps_7-9": + return USPSH5_Digit_7_9_Dataset(*args, **kwargs) case _: raise ValueError(f"Dataset: {dataset} not implemented.") diff --git a/utils/load_metric.py b/utils/load_metric.py index f166c25..cc687c3 100644 --- a/utils/load_metric.py +++ b/utils/load_metric.py @@ -3,7 +3,7 @@ import numpy as np import torch.nn as nn -from .metrics import EntropyPrediction +from .metrics import EntropyPrediction, F1Score class MetricWrapper(nn.Module): @@ -35,7 +35,7 @@ def _get_metric(self, key): case "entropy": return EntropyPrediction() case "f1": - raise NotImplementedError("F1 score not implemented yet") + raise F1Score() case "recall": raise NotImplementedError("Recall score not implemented yet") case "precision": diff --git a/utils/load_model.py b/utils/load_model.py index 7e55699..21d9d03 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,6 +1,6 @@ import torch.nn as nn -from .models import ChristianModel, MagnusModel +from .models import ChristianModel, MagnusModel, SolveigModel def load_model(modelname: str, *args, **kwargs) -> nn.Module: @@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module: return MagnusModel(*args, **kwargs) case "christianmodel": return ChristianModel(*args, **kwargs) + case "solveigmodel": + return SolveigModel(*args, **kwargs) case _: raise ValueError( f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling" diff --git a/utils/metrics/F1.py b/utils/metrics/F1.py index d13bddb..36e5e34 100644 --- a/utils/metrics/F1.py +++ b/utils/metrics/F1.py @@ -85,16 +85,3 @@ def compute(self): return f1_score - -def test_f1score(): - f1_metric = F1Score(num_classes=3) - preds = torch.tensor( - [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]] - ) - - target = torch.tensor([0, 1, 0, 2]) - - f1_metric.update(preds, target) - assert f1_metric.tp.sum().item() > 0, "Expected some true positives." - assert f1_metric.fp.sum().item() > 0, "Expected some false positives." - assert f1_metric.fn.sum().item() > 0, "Expected some false negatives." diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py index 3afeee5..f623943 100644 --- a/utils/metrics/__init__.py +++ b/utils/metrics/__init__.py @@ -1,4 +1,5 @@ -__all__ = ["EntropyPrediction", "Recall"] +__all__ = ["EntropyPrediction", "Recall", "F1Score"] from .EntropyPred import EntropyPrediction +from .F1 import F1Score from .recall import Recall diff --git a/utils/models/__init__.py b/utils/models/__init__.py index 7cbae91..8573991 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,4 +1,5 @@ -__all__ = ["MagnusModel", "ChristianModel"] +__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"] from .christian_model import ChristianModel from .magnus_model import MagnusModel +from .solveig_model import SolveigModel diff --git a/utils/models/solveig_model.py b/utils/models/solveig_model.py new file mode 100644 index 0000000..c16dbaf --- /dev/null +++ b/utils/models/solveig_model.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + + +class SolveigModel(nn.Module): + """ + A Convolutional Neural Network model for classification. + + Args + ---- + image_shape : tuple(int, int, int) + Shape of the input image (C, H, W). + num_classes : int + Number of classes in the dataset. + + Attributes: + ----------- + conv_block1 : nn.Sequential + First convolutional block containing a convolutional layer, ReLU activation, and max-pooling. + conv_block2 : nn.Sequential + Second convolutional block containing a convolutional layer and ReLU activation. + conv_block3 : nn.Sequential + Third convolutional block containing a convolutional layer and ReLU activation. + fc1 : nn.Linear + Fully connected layer that outputs the final classification scores. + """ + + def __init__(self, image_shape, num_classes): + super().__init__() + + C, *_ = image_shape + + # Define the first convolutional block (conv + relu + maxpool) + self.conv_block1 = nn.Sequential( + nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2) + ) + + # Define the second convolutional block (conv + relu) + self.conv_block2 = nn.Sequential( + nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1), + nn.ReLU() + ) + + # Define the third convolutional block (conv + relu) + self.conv_block3 = nn.Sequential( + nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1), + nn.ReLU() + ) + + self.fc1 = nn.Linear(100 * 8 * 8, num_classes) + + def forward(self, x): + x = self.conv_block1(x) + x = self.conv_block2(x) + x = self.conv_block3(x) + x = torch.flatten(x, 1) + + x = self.fc1(x) + x = nn.Softmax(x) + + return x + + +if __name__ == "__main__": + + x = torch.randn(1,3, 16, 16) + + model = SolveigModel(x.shape[1:], 3) + + y = model(x) + + print(y)