From 744699fbb6212bba5d80344cd1f39ea69e48d594 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 6 Feb 2025 16:11:26 +0100 Subject: [PATCH 1/4] Do not track .rst files generated by sphinx-autoapi --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b564848..4bc8724 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ _build/ bin/ wandb/ wandb_api.py +doc/autoapi # Byte-compiled / optimized / DLL files __pycache__/ From a4214d26fc99daaefc072c7031f6597cde06d55b Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 6 Feb 2025 16:12:02 +0100 Subject: [PATCH 2/4] Modify USPS 7-9 dataloader to fit the overall format --- utils/dataloaders/uspsh5_7_9.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/utils/dataloaders/uspsh5_7_9.py b/utils/dataloaders/uspsh5_7_9.py index 98cbd03..adeabc6 100644 --- a/utils/dataloaders/uspsh5_7_9.py +++ b/utils/dataloaders/uspsh5_7_9.py @@ -30,7 +30,9 @@ class USPSH5_Digit_7_9_Dataset(Dataset): A transform function to apply to the images. """ - def __init__(self, h5_path, mode, transform=None): + filename = "usps.h5" + + def __init__(self, data_path, train=False, transform=None, download=False): super().__init__() """ Initializes the USPS dataset by loading images and labels from the given `.h5` file. 
@@ -45,8 +47,8 @@ def __init__(self, h5_path, mode, transform=None): """ self.transform = transform - self.mode = mode - self.h5_path = h5_path + self.mode = "train" if train else "test" + self.h5_path = data_path / self.filename # Load the dataset from the HDF5 file with h5py.File(self.h5_path, "r") as hf: images = hf[self.mode]["data"][:] From 1add669d260b58a4a025aa3a4c3030b6ea7db5fa Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 6 Feb 2025 16:12:20 +0100 Subject: [PATCH 3/4] Create new test that verifies basic functionality of all datasets --- tests/test_dataloaders.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index 9f58ae4..527890c 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -1,8 +1,41 @@ -from utils.dataloaders.usps_0_6 import USPSDataset0_6 +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from torchvision import transforms + +from utils.dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset +from utils.load_data import load_data + + +@pytest.mark.parametrize( + "data_name, expected", + [ + ("usps_0-6", USPSDataset0_6), + ("usps_7-9", USPSH5_Digit_7_9_Dataset), + ("mnist_0-3", MNISTDataset0_3), + # TODO: Add more datasets here + ], +) +def test_load_data(data_name, expected): + dataset = load_data( + data_name, + data_path=Path("data"), + download=True, + transform=transforms.ToTensor(), + ) + assert isinstance(dataset, expected) + assert len(dataset) > 0 + assert isinstance(dataset[0], tuple) + assert isinstance(dataset[0][0], torch.Tensor) + assert isinstance( + dataset[0][1], (int, torch.Tensor, np.ndarray) + ) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency. 
def test_uspsdataset0_6(): - from pathlib import Path from tempfile import TemporaryDirectory import h5py From 6be010a51b632392e729f643b48faf98b288cde1 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Thu, 13 Feb 2025 13:13:50 +0100 Subject: [PATCH 4/4] Remove duplicate function definition --- utils/dataloaders/uspsh5_7_9.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/utils/dataloaders/uspsh5_7_9.py b/utils/dataloaders/uspsh5_7_9.py index 1f745a5..35167a4 100644 --- a/utils/dataloaders/uspsh5_7_9.py +++ b/utils/dataloaders/uspsh5_7_9.py @@ -32,10 +32,6 @@ class USPSH5_Digit_7_9_Dataset(Dataset): A transform function to apply to the images. """ - - - def __init__(self, data_path, train=False, transform=None, download=False): - def __init__(self, data_path, train=False, transform=None): super().__init__() """