diff --git a/main.py b/main.py
index 9582e9d..8b4fd93 100644
--- a/main.py
+++ b/main.py
@@ -1,13 +1,11 @@
-from pathlib import Path
-
 import numpy as np
 import torch as th
 import torch.nn as nn
+import wandb
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from tqdm import tqdm
-import wandb
 
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model
 
@@ -32,42 +30,20 @@ def main():
     device = args.device
 
     if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
-        augmentations = transforms.Compose(
+        transform = transforms.Compose(
             [
                 transforms.Resize((16, 16)),
                 transforms.ToTensor(),
             ]
         )
     else:
-        augmentations = transforms.Compose([transforms.ToTensor()])
+        transform = transforms.Compose([transforms.ToTensor()])
 
-    # Dataset
-    assert (
-        args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0
-    ), "Validation split should be in interval (0,1)"
-    traindata = load_data(
-        args.dataset,
-        split="train",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    validata = load_data(
-        args.dataset,
-        split="validation",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    testdata = load_data(
+    traindata, validata, testdata = load_data(
         args.dataset,
-        split="test",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
+        data_dir=args.datafolder,
+        transform=transform,
+        val_size=args.val_size,
     )
 
     metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py
index 9f58ae4..32634d6 100644
--- a/tests/test_dataloaders.py
+++ b/tests/test_dataloaders.py
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():
     # Create a h5 file
     with h5py.File(tf, "w") as f:
+        targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        indices = np.arange(len(targets))
         # Populate the file with data
         f["train/data"] = np.random.rand(10, 16 * 16)
-        f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        f["train/target"] = targets
 
     trans = transforms.Compose(
         [
-            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.Resize((16, 16)),
            transforms.ToTensor(),
         ]
     )
-    dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
+    dataset = USPSDataset0_6(
+        data_path=tempdir,
+        sample_ids=indices,
+        train=True,
+        transform=trans,
+    )
 
     assert len(dataset) == 10
     data, target = dataset[0]
     assert data.shape == (1, 16, 16)
-    assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
+    assert target == 6
diff --git a/tests/test_models.py b/tests/test_models.py
index efc5412..4dd3fa8 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)
 
     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-
diff --git a/utils/arg_parser.py b/utils/arg_parser.py
index 5ced4d0..4d001d3 100644
--- a/utils/arg_parser.py
+++ b/utils/arg_parser.py
@@ -33,12 +33,6 @@ def get_args():
         help="Whether model should be saved or not.",
     )
 
-    parser.add_argument(
-        "--download-data",
-        action="store_true",
-        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
-    )
-
     # Data/Model specific values
     parser.add_argument(
         "--modelname",
@@ -55,7 +49,7 @@
         help="Which dataset to train the model on.",
     )
     parser.add_argument(
-        "--validation_split_percentage",
+        "--val_size",
         type=float,
         default=0.2,
         help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",
diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py
index 1f506e6..0d0bcb6 100644
--- a/utils/dataloaders/__init__.py
+++ b/utils/dataloaders/__init__.py
@@ -1,5 +1,11 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
+__all__ = [
+    "USPSDataset0_6",
+    "USPSH5_Digit_7_9_Dataset",
+    "MNISTDataset0_3",
+    "Downloader",
+]
 
+from .download import Downloader
 from .mnist_0_3 import MNISTDataset0_3
 from .usps_0_6 import USPSDataset0_6
 from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
diff --git a/utils/dataloaders/datasources.py b/utils/dataloaders/datasources.py
index f0d2e01..936d32e 100644
--- a/utils/dataloaders/datasources.py
+++ b/utils/dataloaders/datasources.py
@@ -17,3 +17,26 @@
         "8ea070ee2aca1ac39742fdd1ef5ed118",
     ],
 }
+
+MNIST_SOURCE = {
+    "train_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+        "train-images-idx3-ubyte",
+        None,
+    ],
+    "train_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+        "train-labels-idx1-ubyte",
+        None,
+    ],
+    "test_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+        "t10k-images-idx3-ubyte",
+        None,
+    ],
+    "test_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+        "t10k-labels-idx1-ubyte",
+        None,
+    ],
+}
diff --git a/utils/dataloaders/download.py b/utils/dataloaders/download.py
new file mode 100644
index 0000000..9f667a3
--- /dev/null
+++ b/utils/dataloaders/download.py
@@ -0,0 +1,183 @@
+import bz2
+import gzip
+import hashlib
+import os
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from urllib.request import urlretrieve
+
+import h5py as h5
+import numpy as np
+
+from .datasources import MNIST_SOURCE, USPS_SOURCE
+
+
+class Downloader:
+    """
+    Class to download the MNIST, SVHN and USPS datasets and load their targets.
+
+    Methods
+    -------
+    mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the MNIST dataset to `data_dir` and return the train and test labels.
+    svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the SVHN dataset to `data_dir` (not implemented yet).
+    usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the USPS dataset, save it as an HDF5 file in `data_dir`, and return the train and test targets.
+
+    Raises
+    ------
+    NotImplementedError
+        If the download method is not implemented for the dataset.
+
+    Examples
+    --------
+    >>> from pathlib import Path
+    >>> from utils import Downloader
+    >>> dir = Path('tmp')
+    >>> dir.mkdir(exist_ok=True)
+    >>> train, test = Downloader().usps(dir)
+    """
+
+    def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        def _check_is_downloaded(path: Path) -> bool:
+            path = path / "MNIST"
+            if path.exists():
+                required_files = [MNIST_SOURCE[key][1] for key in MNIST_SOURCE.keys()]
+                if all([(path / file).exists() for file in required_files]):
+                    print("MNIST Dataset already downloaded.")
+                    return True
+                else:
+                    return False
+            else:
+                path.mkdir(parents=True, exist_ok=True)
+                return False
+
+        def _download_data(path: Path) -> None:
+            urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}
+
+            for name, url in urls.items():
+                file_path = os.path.join(path, url.split("/")[-1])
+                if not os.path.exists(
+                    file_path.replace(".gz", "")
+                ):  # Avoid re-downloading
+                    urlretrieve(url, file_path)
+                    with gzip.open(file_path, "rb") as f_in:
+                        with open(file_path.replace(".gz", ""), "wb") as f_out:
+                            f_out.write(f_in.read())
+                    os.remove(file_path)  # Remove compressed file
+
+        def _get_labels(path: Path) -> np.ndarray:
+            with open(path, "rb") as f:
+                data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
+            return data
+
+        if not _check_is_downloaded(data_dir):
+            _download_data(data_dir / "MNIST")
+
+        train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
+        test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]
+
+        train_labels = _get_labels(train_labels_path)
+        test_labels = _get_labels(test_labels_path)
+
+        return train_labels, test_labels
+
+    def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        raise NotImplementedError("SVHN download not implemented yet")
+
+    def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Download the USPS dataset and save it as an HDF5 file to `data_dir/usps.h5`.
+ """ + + def already_downloaded(path): + if not path.exists() or not path.is_file(): + return False + + with h5.File(path, "r") as f: + return "train" in f and "test" in f + + filename = data_dir / "usps.h5" + + if already_downloaded(filename): + with h5.File(filename, "r") as f: + return f["train"]["target"][:], f["test"]["target"][:] + + url_train, _, train_md5 = USPS_SOURCE["train"] + url_test, _, test_md5 = USPS_SOURCE["test"] + + # Using temporary directory ensures temporary files are deleted after use + with TemporaryDirectory() as tmp_dir: + train_path = Path(tmp_dir) / "train" + test_path = Path(tmp_dir) / "test" + + # Download the dataset and report the progress + urlretrieve(url_train, train_path, reporthook=self.__reporthook) + self.__check_integrity(train_path, train_md5) + train_targets = self.__extract_usps(train_path, filename, "train") + + urlretrieve(url_test, test_path, reporthook=self.__reporthook) + self.__check_integrity(test_path, test_md5) + test_targets = self.__extract_usps(test_path, filename, "test") + + return train_targets, test_targets + + def __extract_usps(self, src: Path, dest: Path, mode: str): + # Load the dataset and save it as an HDF5 file + with bz2.open(src) as fp: + raw = [line.decode().split() for line in fp.readlines()] + + tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw] + + imgs = np.asarray(tmp, dtype=np.float32) + imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) + + targets = [int(d[0]) - 1 for d in raw] + + with h5.File(dest, "a") as f: + f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32) + f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32) + + return targets + + @staticmethod + def __reporthook(blocknum, blocksize, totalsize): + """ + Use this function to report download progress + for the urllib.request.urlretrieve function. + """ + + denom = 1024 * 1024 + readsofar = blocknum * blocksize + + if totalsize > 0: + percent = readsofar * 1e2 / totalsize + s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB" + print(s, end="", flush=True) + if readsofar >= totalsize: + print() + + @staticmethod + def __check_integrity(filepath, checksum): + """Check the integrity of the USPS dataset file. + + Args + ---- + filepath : pathlib.Path + Path to the USPS dataset file. + checksum : str + MD5 checksum of the dataset file. + + Returns + ------- + bool + True if the checksum of the file matches the expected checksum, False otherwise + """ + + file_hash = hashlib.md5(filepath.read_bytes()).hexdigest() + + if not checksum == file_hash: + raise ValueError( + f"File integrity check failed. Expected {checksum}, got {file_hash}" + ) diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py index 1f8124d..52a5a28 100644 --- a/utils/dataloaders/mnist_0_3.py +++ b/utils/dataloaders/mnist_0_3.py @@ -1,154 +1,77 @@ -import gzip -import os -import urllib.request from pathlib import Path import numpy as np -import torch -from torch.utils.data import Dataset, random_split +from torch.utils.data import Dataset + +from .datasources import MNIST_SOURCE class MNISTDataset0_3(Dataset): """ - A custom dataset class for loading MNIST data, specifically for digits 0 through 3. - + A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3. Parameters ---------- data_path : Path - The root directory where the MNIST data is stored or will be downloaded. + The root directory where the MNIST data is stored. 
+    sample_ids : list
+        A list of indices specifying which samples to load.
     train : bool, optional
-        If True, loads the training data, otherwise loads the test data. Default is False.
+        If True, load training data, otherwise load test data. Default is False.
     transform : callable, optional
-        A function/transform that takes in an image and returns a transformed version. Default is None.
-    download : bool, optional
-        If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
-
+        A function/transform to apply to the images. Default is None.
 
     Attributes
     ----------
     data_path : Path
         The root directory where the MNIST data is stored.
     mnist_path : Path
-        The directory where the MNIST data files are stored.
+        The directory where the MNIST dataset is located within the root directory.
+    idx : list
+        A list of indices specifying which samples to load.
     train : bool
-        Indicates whether the training data or test data is being used.
+        Indicates whether to load training data or test data.
     transform : callable
-        A function/transform that takes in an image and returns a transformed version.
-    download : bool
-        Indicates whether the dataset should be downloaded if not present.
+        A function/transform to apply to the images.
+    num_classes : int
+        The number of classes in the dataset (0 to 3).
    images_path : Path
-        The path to the image file (training or test) based on the `train` flag.
+        The path to the image file (train or test) based on the `train` flag.
     labels_path : Path
-        The path to the label file (training or test) based on the `train` flag.
-    idx : numpy.ndarray
-        Indices of the labels that are less than 4.
+        The path to the label file (train or test) based on the `train` flag.
     length : int
         The number of samples in the dataset.
-
     Methods
     -------
-    _parse_labels(train)
-        Parses the labels from the label file.
-    _chech_is_downloaded()
-        Checks if the dataset is already downloaded.
-    _download_data()
-        Downloads and extracts the MNIST dataset.
     __len__()
         Returns the number of samples in the dataset.
     __getitem__(index)
-        Returns the image and label at the specified index.
+        Retrieves the image and label at the specified index.
     """

     def __init__(
         self,
-        split: str,
-        split_percentage: float,
         data_path: Path,
-        download: bool = False,
+        sample_ids: list,
+        train: bool = False,
         transform=None,
     ):
         super().__init__()
 
         self.data_path = data_path
         self.mnist_path = self.data_path / "MNIST"
-        self.split = split
-        self.split_percentage = split_percentage
+        self.idx = sample_ids
+        self.train = train
         self.transform = transform
-        self.download = download
         self.num_classes = 4
 
-        if self.split == "test":
-            train = False  # used to decide whether to load training or test dataset
-        else:
-            train = True
-
-        if not self.download and not self._chech_is_downloaded():
-            raise ValueError(
-                "Data not found. Set --download-data=True to download the data."
-            )
-        if self.download and not self._chech_is_downloaded():
-            self._download_data()
-
         self.images_path = self.mnist_path / (
-            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
+            MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
         )
         self.labels_path = self.mnist_path / (
-            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
+            MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
         )
 
-        labels = self._parse_labels()
-
-        self.idx = np.where(labels < 4)[0]
-
-        if self.split != "test":
-            generator1 = torch.Generator().manual_seed(42)
-            tr, val = random_split(
-                self.idx,
-                [1 - self.split_percentage, self.split_percentage],
-                generator=generator1,
-            )
-            self.idx = tr if self.split == "train" else val
-
         self.length = len(self.idx)
 
-    def _parse_labels(self):
-        with open(self.labels_path, "rb") as f:
-            data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
-        return data
-
-    def _chech_is_downloaded(self):
-        if self.mnist_path.exists():
-            required_files = [
-                "train-images-idx3-ubyte",
-                "train-labels-idx1-ubyte",
-                "t10k-images-idx3-ubyte",
-                "t10k-labels-idx1-ubyte",
-            ]
-            if all([(self.mnist_path / file).exists() for file in required_files]):
-                print("MNIST Dataset already downloaded.")
-                return True
-            else:
-                return False
-        else:
-            self.mnist_path.mkdir(parents=True, exist_ok=True)
-            return False
-
-    def _download_data(self):
-        urls = {
-            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
-            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
-            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
-            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
-        }
-
-        for name, url in urls.items():
-            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
-            if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
-                urllib.request.urlretrieve(url, file_path)
-                with gzip.open(file_path, "rb") as f_in:
-                    with open(file_path.replace(".gz", ""), "wb") as f_out:
-                        f_out.write(f_in.read())
-                os.remove(file_path)  # Remove compressed file
-
     def __len__(self):
         return self.length
diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py
index 3673fa9..70286dc 100644
--- a/utils/dataloaders/usps_0_6.py
+++ b/utils/dataloaders/usps_0_6.py
@@ -4,19 +4,12 @@
 This module contains the Dataset class for the USPS dataset with labels 0-6.
""" -import bz2 -import hashlib from pathlib import Path -from tempfile import TemporaryDirectory -from urllib.request import urlretrieve import h5py as h5 import numpy as np from PIL import Image from torch.utils.data import Dataset -from torchvision import transforms - -from .datasources import USPS_SOURCE class USPSDataset0_6(Dataset): @@ -87,9 +80,9 @@ class USPSDataset0_6(Dataset): def __init__( self, data_path: Path, + sample_ids: list, train: bool = False, transform=None, - download: bool = False, ): super().__init__() @@ -97,168 +90,21 @@ def __init__( self.filepath = path / self.filename self.transform = transform self.mode = "train" if train else "test" + self.sample_ids = sample_ids - # Download the dataset if it does not exist in a temporary directory - # to automatically clean up the downloaded file - if download and not self._dataset_ok(): - url, _, checksum = USPS_SOURCE[self.mode] - - print(f"Downloading USPS dataset ({self.mode})...") - self.download(url, self.filepath, checksum, self.mode) - - self.idx = self._index() - - def _dataset_ok(self): - """Check if the dataset file exists and contains the required datasets.""" - - if not self.filepath.exists(): - print(f"Dataset file {self.filepath} does not exist.") - return False - - with h5.File(self.filepath, "r") as f: - for mode in ["train", "test"]: - if mode not in f: - print( - f"Dataset file {self.filepath} is missing the {mode} dataset." - ) - return False - - return True - - def download(self, url, filepath, checksum, mode): - """Download the USPS dataset, and save it as an HDF5 file. - - Args - ---- - url : str - URL to download the dataset from. - filepath : pathlib.Path - Path to save the downloaded dataset. - checksum : str - MD5 checksum of the downloaded file. - mode : str - Mode of the dataset, either train or test. - - Raises - ------ - ValueError - If the checksum of the downloaded file does not match the expected checksum. - """ - - def reporthook(blocknum, blocksize, totalsize): - """Report download progress.""" - denom = 1024 * 1024 - readsofar = blocknum * blocksize - if totalsize > 0: - percent = readsofar * 1e2 / totalsize - s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB" - print(s, end="", flush=True) - if readsofar >= totalsize: - print() - - # Download the dataset to a temporary file - with TemporaryDirectory() as tmpdir: - tmpdir = Path(tmpdir) - tmpfile = tmpdir / "usps.bz2" - urlretrieve( - url, - tmpfile, - reporthook=reporthook, - ) - - # For fun we can check the integrity of the downloaded file - if not self.check_integrity(tmpfile, checksum): - errmsg = ( - "The checksum of the downloaded file does " - "not match the expected checksum." - ) - raise ValueError(errmsg) - - # Load the dataset and save it as an HDF5 file - with bz2.open(tmpfile) as fp: - raw = [line.decode().split() for line in fp.readlines()] - - tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw] - - imgs = np.asarray(tmp, dtype=np.float32) - imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) - - targets = [int(d[0]) - 1 for d in raw] - - with h5.File(self.filepath, "a") as f: - f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32) - f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32) - - @staticmethod - def check_integrity(filepath, checksum): - """Check the integrity of the USPS dataset file. - - Args - ---- - filepath : pathlib.Path - Path to the USPS dataset file. - checksum : str - MD5 checksum of the dataset file. 
-
-        Returns
-        -------
-        bool
-            True if the checksum of the file matches the expected checksum, False otherwise
-        """
-
-        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
-
-        return checksum == file_hash
-
-    def _index(self):
-        with h5.File(self.filepath, "r") as f:
-            labels = f[self.mode]["target"][:]
-
-        # Get indices of samples with labels 0-6
-        mask = labels <= 6
-        idx = np.where(mask)[0]
+    def __len__(self):
+        return len(self.sample_ids)
 
-        return idx
+    def __getitem__(self, id):
+        index = self.sample_ids[id]
 
-    def _load_data(self, idx):
         with h5.File(self.filepath, "r") as f:
-            data = f[self.mode]["data"][idx].astype(np.uint8)
-            label = f[self.mode]["target"][idx]
+            data = f[self.mode]["data"][index].astype(np.uint8)
+            label = f[self.mode]["target"][index]
 
-        return data, label
-
-    def __len__(self):
-        return len(self.idx)
-
-    def __getitem__(self, idx):
-        data, target = self._load_data(self.idx[idx])
         data = Image.fromarray(data, mode="L")
 
-        # one hot encode the target
-        target = np.eye(self.num_classes, dtype=np.float32)[target]
-
         if self.transform:
             data = self.transform(data)
 
-        return data, target
-
-
-if __name__ == "__main__":
-    # Example usage:
-    transform = transforms.Compose(
-        [
-            transforms.Resize((16, 16)),
-            transforms.ToTensor(),
-        ]
-    )
-
-    dataset = USPSDataset0_6(
-        data_path="data",
-        train=True,
-        download=False,
-        transform=transform,
-    )
-    print(len(dataset))
-    data, target = dataset[0]
-    print(data.shape)
-    print(target)
+        return data, label
diff --git a/utils/load_data.py b/utils/load_data.py
index 9060013..1c3923d 100644
--- a/utils/load_data.py
+++ b/utils/load_data.py
@@ -1,11 +1,21 @@
-from torch.utils.data import Dataset
+import numpy as np
+from torch.utils.data import random_split
 
-from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
+from .dataloaders import (
+    Downloader,
+    MNISTDataset0_3,
+    USPSDataset0_6,
+    USPSH5_Digit_7_9_Dataset,
+)
 
 
-def load_data(dataset: str, *args, **kwargs) -> Dataset:
+def filter_labels(samples: list, wanted_labels: list) -> list:
+    return list(filter(lambda x: x in wanted_labels, samples))
+
+
+def load_data(dataset: str, *args, **kwargs) -> tuple:
     """
-    Load the dataset based on the dataset name.
+    Load the dataset based on the dataset name and return the train, validation and test splits.
 
     Args
     ----
@@ -18,8 +28,8 @@
 
     Returns
     -------
-    dataset : torch.utils.data.Dataset
-        Dataset object.
+    tuple
+        Tuple of train, validation and test datasets.
 
     Raises
     ------
@@ -28,17 +38,61 @@
     Examples
     --------
-    >>> from utils import load_data
-    >>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
-    >>> len(dataset)
-    5460
+    >>> from utils import load_data
+    >>> train, val, test = load_data("usps_0-6", data_dir="data", val_size=0.1)
+    >>> len(train), len(val), len(test)
+    (4914, 546, 1782)
     """
+    downloader = Downloader()
+    data_dir = kwargs.get("data_dir")
+    transform = kwargs.get("transform")
 
     match dataset.lower():
         case "usps_0-6":
-            return USPSDataset0_6(*args, **kwargs)
-        case "mnist_0-3":
-            return MNISTDataset0_3(*args, **kwargs)
+            dataset = USPSDataset0_6
+            train_labels, test_labels = downloader.usps(data_dir=data_dir)
+            labels = np.arange(7)
         case "usps_7-9":
-            return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
+            dataset = USPSH5_Digit_7_9_Dataset
+            train_labels, test_labels = downloader.usps(data_dir=data_dir)
+            labels = np.arange(7, 10)
+        case "mnist_0-3":
+            dataset = MNISTDataset0_3
+            train_labels, test_labels = downloader.mnist(data_dir=data_dir)
+            labels = np.arange(4)
         case _:
             raise NotImplementedError(f"Dataset: {dataset} not implemented.")
+
+    val_size = kwargs.get("val_size", 0.2)
+
+    # Get the indices of the samples
+    train_indices = np.arange(len(train_labels))
+    test_indices = np.arange(len(test_labels))
+
+    # Filter the labels to only get indices of the wanted labels
+    train_samples = train_indices[np.isin(train_labels, labels)]
+    test_samples = test_indices[np.isin(test_labels, labels)]
+
+    train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
+
+    train = dataset(
+        data_path=data_dir,
+        sample_ids=train_samples,
+        train=True,
+        transform=transform,
+    )
+
+    val = dataset(
+        data_path=data_dir,
+        sample_ids=val_samples,
+        train=True,
+        transform=transform,
+    )
+
+    test = dataset(
+        data_path=data_dir,
+        sample_ids=test_samples,
+        train=False,
+        transform=transform,
+    )
+
+    return train, val, test
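
Usage sketch (not part of the patch): a minimal example of how the refactored data API introduced in this diff is meant to be wired together, roughly mirroring main.py. load_data now drives the Downloader internally, filters the wanted labels, splits off a validation set, and returns the three ready-to-use datasets. The dataset name, data directory, batch size and val_size below are illustrative values, not taken from the repository.

from pathlib import Path

from torch.utils.data import DataLoader
from torchvision import transforms

from utils import load_data

# Illustrative values; adjust the dataset name, paths and split size as needed.
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# USPS images are resized to 16x16 before conversion to tensors, as in main.py.
transform = transforms.Compose(
    [
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
    ]
)

# Downloads the data on first use and returns train/validation/test datasets.
train, val, test = load_data(
    "usps_0-6",
    data_dir=data_dir,
    transform=transform,
    val_size=0.2,
)

train_loader = DataLoader(train, batch_size=64, shuffle=True)
val_loader = DataLoader(val, batch_size=64)
test_loader = DataLoader(test, batch_size=64)

print(len(train), len(val), len(test), train.num_classes)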