diff --git a/.gitignore b/.gitignore index 58b934b..3afaf91 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ _build/ bin/* wandb/* wandb_api.py +doc/autoapi #Magnus specific job* diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index 32634d6..a0c1bf6 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -1,8 +1,41 @@ -from utils.dataloaders.usps_0_6 import USPSDataset0_6 +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from torchvision import transforms + +from utils.dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset +from utils.load_data import load_data + + +@pytest.mark.parametrize( + "data_name, expected", + [ + ("usps_0-6", USPSDataset0_6), + ("usps_7-9", USPSH5_Digit_7_9_Dataset), + ("mnist_0-3", MNISTDataset0_3), + # TODO: Add more datasets here + ], +) +def test_load_data(data_name, expected): + dataset = load_data( + data_name, + data_path=Path("data"), + download=True, + transform=transforms.ToTensor(), + ) + assert isinstance(dataset, expected) + assert len(dataset) > 0 + assert isinstance(dataset[0], tuple) + assert isinstance(dataset[0][0], torch.Tensor) + assert isinstance( + dataset[0][1], (int, torch.Tensor, np.ndarray) + ) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency. def test_uspsdataset0_6(): - from pathlib import Path from tempfile import TemporaryDirectory import h5py diff --git a/utils/dataloaders/uspsh5_7_9.py b/utils/dataloaders/uspsh5_7_9.py index 4d63255..35167a4 100644 --- a/utils/dataloaders/uspsh5_7_9.py +++ b/utils/dataloaders/uspsh5_7_9.py @@ -50,6 +50,8 @@ def __init__(self, data_path, train=False, transform=None): self.filepath = path / self.filename self.transform = transform self.mode = "train" if train else "test" + self.h5_path = data_path / self.filename + # Load the dataset from the HDF5 file with h5py.File(self.filepath, "r") as hf: images = hf[self.mode]["data"][:]