From a2606e129da49c3cdae329c0854d95e8b7282133 Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 19:45:53 +0100
Subject: [PATCH 01/12] Make separate downloader class that handles everything
 related to downloading

---
 utils/dataloaders/__init__.py |   8 +-
 utils/dataloaders/download.py | 141 ++++++++++++++++++++++++++++++++++
 2 files changed, 148 insertions(+), 1 deletion(-)
 create mode 100644 utils/dataloaders/download.py

diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py
index 1f506e6..0d0bcb6 100644
--- a/utils/dataloaders/__init__.py
+++ b/utils/dataloaders/__init__.py
@@ -1,5 +1,11 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
+__all__ = [
+    "USPSDataset0_6",
+    "USPSH5_Digit_7_9_Dataset",
+    "MNISTDataset0_3",
+    "Downloader",
+]
 
+from .download import Downloader
 from .mnist_0_3 import MNISTDataset0_3
 from .usps_0_6 import USPSDataset0_6
 from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
diff --git a/utils/dataloaders/download.py b/utils/dataloaders/download.py
new file mode 100644
index 0000000..7a7fa13
--- /dev/null
+++ b/utils/dataloaders/download.py
@@ -0,0 +1,141 @@
+import bz2
+import hashlib
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from urllib.request import urlretrieve
+
+import h5py as h5
+import numpy as np
+from PIL import Image
+
+from .datasources import USPS_SOURCE
+
+
+class Downloader:
+    """
+    Class to download and load the supported datasets (MNIST, SVHN and USPS).
+
+    Methods
+    -------
+    mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the MNIST dataset and save it as an HDF5 file to `data_dir`.
+    svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the SVHN dataset and save it as an HDF5 file to `data_dir`.
+    usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the USPS dataset and save it as an HDF5 file to `data_dir`.
+
+    Raises
+    ------
+    NotImplementedError
+        If the download method is not implemented for the dataset.
+
+    Examples
+    --------
+    >>> from pathlib import Path
+    >>> from utils import Downloader
+    >>> dir = Path('tmp')
+    >>> dir.mkdir(exist_ok=True)
+    >>> train, test = Downloader().usps(dir)
+    """
+
+    def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        raise NotImplementedError("MNIST download not implemented yet")
+
+    def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        raise NotImplementedError("SVHN download not implemented yet")
+
+    def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Download the USPS dataset and save it as an HDF5 file to `data_dir/usps.h5`.
+        """
+
+        def already_downloaded(path):
+            if not path.exists() or not path.is_file():
+                return False
+
+            with h5.File(path, "r") as f:
+                return "train" in f and "test" in f
+
+        filename = data_dir / "usps.h5"
+
+        if already_downloaded(filename):
+            with h5.File(filename, "r") as f:
+                return f["train"]["target"][:], f["test"]["target"][:]
+
+        url_train, _, train_md5 = USPS_SOURCE["train"]
+        url_test, _, test_md5 = USPS_SOURCE["test"]
+
+        # Using a temporary directory ensures temporary files are deleted after use
+        with TemporaryDirectory() as tmp_dir:
+            train_path = Path(tmp_dir) / "train"
+            test_path = Path(tmp_dir) / "test"
+
+            # Download the dataset and report the progress
+            urlretrieve(url_train, train_path, reporthook=self.__reporthook)
+            self.__check_integrity(train_path, train_md5)
+            train_targets = self.__extract_usps(train_path, filename, "train")
+
+            urlretrieve(url_test, test_path, reporthook=self.__reporthook)
+            self.__check_integrity(test_path, test_md5)
+            test_targets = self.__extract_usps(test_path, filename, "test")
+
+        return train_targets, test_targets
+
+    def __extract_usps(self, src: Path, dest: Path, mode: str):
+        # Load the dataset and save it as an HDF5 file
+        with bz2.open(src) as fp:
+            raw = [line.decode().split() for line in fp.readlines()]
+
+        tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]
+
+        imgs = np.asarray(tmp, dtype=np.float32)
+        imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+
+        targets = [int(d[0]) - 1 for d in raw]
+
+        with h5.File(dest, "a") as f:
+            f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
+            f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
+
+        return targets
+
+    @staticmethod
+    def __reporthook(blocknum, blocksize, totalsize):
+        """
+        Use this function to report download progress
+        for the urllib.request.urlretrieve function.
+        """
+
+        denom = 1024 * 1024
+        readsofar = blocknum * blocksize
+
+        if totalsize > 0:
+            percent = readsofar * 1e2 / totalsize
+            s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
+            print(s, end="", flush=True)
+            if readsofar >= totalsize:
+                print()
+
+    @staticmethod
+    def __check_integrity(filepath, checksum):
+        """Check the integrity of the USPS dataset file.
+
+        Args
+        ----
+        filepath : pathlib.Path
+            Path to the USPS dataset file.
+        checksum : str
+            MD5 checksum of the dataset file.
+
+        Raises
+        ------
+        ValueError
+            If the checksum of the file does not match the expected checksum.
+        """
+
+        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
+
+        if not checksum == file_hash:
+            raise ValueError(
+                f"File integrity check failed. Expected {checksum}, got {file_hash}"
+            )
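The HDF5 layout that `__extract_usps` writes (`train/data` and `train/target`, plus the same two datasets under `test`) can be read back directly with h5py. A minimal sketch, assuming `Downloader().usps(Path("tmp"))` from the docstring example has already produced `tmp/usps.h5`:

    from pathlib import Path

    import h5py as h5

    with h5.File(Path("tmp") / "usps.h5", "r") as f:
        train_imgs = f["train"]["data"][:]       # flattened 16x16 images, float32
        train_targets = f["train"]["target"][:]  # integer class labels, int32
        print(train_imgs.shape, train_targets.shape)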
From a58e495119de7e32778a61cbcb029804d0929411 Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 19:51:49 +0100
Subject: [PATCH 02/12] downloader handles whether to download data or not, so
 remove option

---
 utils/arg_parser.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/utils/arg_parser.py b/utils/arg_parser.py
index 5ced4d0..4d001d3 100644
--- a/utils/arg_parser.py
+++ b/utils/arg_parser.py
@@ -33,12 +33,6 @@ def get_args():
         help="Whether model should be saved or not.",
     )
 
-    parser.add_argument(
-        "--download-data",
-        action="store_true",
-        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
-    )
-
     # Data/Model specific values
     parser.add_argument(
         "--modelname",
@@ -55,7 +49,7 @@ def get_args():
         help="Which dataset to train the model on.",
    )
     parser.add_argument(
-        "--validation_split_percentage",
+        "--val_size",
         type=float,
         default=0.2,
         help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",

From a9e2cade27ee7d378e3cc31310829bc9e90b820b Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 19:53:01 +0100
Subject: [PATCH 03/12] Remove downloading logic from USPS dataset

---
 utils/dataloaders/usps_0_6.py | 171 ++--------------------------------
 1 file changed, 9 insertions(+), 162 deletions(-)

diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py
index 3673fa9..85b3114 100644
--- a/utils/dataloaders/usps_0_6.py
+++ b/utils/dataloaders/usps_0_6.py
@@ -4,11 +4,7 @@
 This module contains the Dataset class for the USPS dataset with labels 0-6.
 """
 
-import bz2
-import hashlib
 from pathlib import Path
-from tempfile import TemporaryDirectory
-from urllib.request import urlretrieve
 
 import h5py as h5
 import numpy as np
@@ -16,8 +12,6 @@
 from torch.utils.data import Dataset
 from torchvision import transforms
 
-from .datasources import USPS_SOURCE
-
 
 class USPSDataset0_6(Dataset):
     """
@@ -87,9 +81,9 @@ class USPSDataset0_6(Dataset):
     def __init__(
         self,
         data_path: Path,
+        sample_ids: list,
         train: bool = False,
         transform=None,
-        download: bool = False,
     ):
         super().__init__()
 
@@ -97,168 +91,21 @@ def __init__(
         path = Path(data_path)
         self.filepath = path / self.filename
         self.transform = transform
         self.mode = "train" if train else "test"
+        self.sample_ids = sample_ids
 
-        # Download the dataset if it does not exist in a temporary directory
-        # to automatically clean up the downloaded file
-        if download and not self._dataset_ok():
-            url, _, checksum = USPS_SOURCE[self.mode]
-
-            print(f"Downloading USPS dataset ({self.mode})...")
-            self.download(url, self.filepath, checksum, self.mode)
-
-        self.idx = self._index()
-
-    def _dataset_ok(self):
-        """Check if the dataset file exists and contains the required datasets."""
-
-        if not self.filepath.exists():
-            print(f"Dataset file {self.filepath} does not exist.")
-            return False
-
-        with h5.File(self.filepath, "r") as f:
-            for mode in ["train", "test"]:
-                if mode not in f:
-                    print(
-                        f"Dataset file {self.filepath} is missing the {mode} dataset."
-                    )
-                    return False
-
-        return True
-
-    def download(self, url, filepath, checksum, mode):
-        """Download the USPS dataset, and save it as an HDF5 file.
-
-        Args
-        ----
-        url : str
-            URL to download the dataset from.
-        filepath : pathlib.Path
-            Path to save the downloaded dataset.
-        checksum : str
-            MD5 checksum of the downloaded file.
-        mode : str
-            Mode of the dataset, either train or test.
-
-        Raises
-        ------
-        ValueError
-            If the checksum of the downloaded file does not match the expected checksum.
-        """
-
-        def reporthook(blocknum, blocksize, totalsize):
-            """Report download progress."""
-            denom = 1024 * 1024
-            readsofar = blocknum * blocksize
-            if totalsize > 0:
-                percent = readsofar * 1e2 / totalsize
-                s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
-                print(s, end="", flush=True)
-                if readsofar >= totalsize:
-                    print()
-
-        # Download the dataset to a temporary file
-        with TemporaryDirectory() as tmpdir:
-            tmpdir = Path(tmpdir)
-            tmpfile = tmpdir / "usps.bz2"
-            urlretrieve(
-                url,
-                tmpfile,
-                reporthook=reporthook,
-            )
-
-            # For fun we can check the integrity of the downloaded file
-            if not self.check_integrity(tmpfile, checksum):
-                errmsg = (
-                    "The checksum of the downloaded file does "
-                    "not match the expected checksum."
-                )
-                raise ValueError(errmsg)
-
-            # Load the dataset and save it as an HDF5 file
-            with bz2.open(tmpfile) as fp:
-                raw = [line.decode().split() for line in fp.readlines()]
-
-            tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]
-
-            imgs = np.asarray(tmp, dtype=np.float32)
-            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
-
-            targets = [int(d[0]) - 1 for d in raw]
-
-            with h5.File(self.filepath, "a") as f:
-                f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
-                f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
-
-    @staticmethod
-    def check_integrity(filepath, checksum):
-        """Check the integrity of the USPS dataset file.
-
-        Args
-        ----
-        filepath : pathlib.Path
-            Path to the USPS dataset file.
-        checksum : str
-            MD5 checksum of the dataset file.
-
-        Returns
-        -------
-        bool
-            True if the checksum of the file matches the expected checksum, False otherwise
-        """
-
-        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
-
-        return checksum == file_hash
-
-    def _index(self):
-        with h5.File(self.filepath, "r") as f:
-            labels = f[self.mode]["target"][:]
-
-        # Get indices of samples with labels 0-6
-        mask = labels <= 6
-        idx = np.where(mask)[0]
-
-        return idx
+    def __len__(self):
+        return len(self.sample_ids)
 
-    def _load_data(self, idx):
+    def __getitem__(self, id):
+        index = self.sample_ids[id]
+
         with h5.File(self.filepath, "r") as f:
-            data = f[self.mode]["data"][idx].astype(np.uint8)
-            label = f[self.mode]["target"][idx]
+            data = f[self.mode]["data"][index].astype(np.uint8)
+            label = f[self.mode]["target"][index]
 
-        return data, label
-
-    def __len__(self):
-        return len(self.idx)
-
-    def __getitem__(self, idx):
-        data, target = self._load_data(self.idx[idx])
         data = Image.fromarray(data, mode="L")
 
-        # one hot encode the target
-        target = np.eye(self.num_classes, dtype=np.float32)[target]
-
         if self.transform:
             data = self.transform(data)
 
-        return data, target
-
-
-if __name__ == "__main__":
-    # Example usage:
-    transform = transforms.Compose(
-        [
-            transforms.Resize((16, 16)),
-            transforms.ToTensor(),
-        ]
-    )
-
-    dataset = USPSDataset0_6(
-        data_path="data",
-        train=True,
-        download=False,
-        transform=transform,
-    )
-    print(len(dataset))
-    data, target = dataset[0]
-    print(data.shape)
-    print(target)
+        return data, label
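After this patch, `USPSDataset0_6` downloads nothing itself; it only reads the HDF5 rows named by `sample_ids`. A minimal construction sketch (the `data` path and index range are illustrative, and `data/usps.h5` must already exist, e.g. created by the Downloader):

    import numpy as np
    from torchvision import transforms

    from utils.dataloaders import USPSDataset0_6

    transform = transforms.Compose([transforms.Resize((16, 16)), transforms.ToTensor()])
    dataset = USPSDataset0_6(
        data_path="data",
        sample_ids=np.arange(10),  # first ten rows of the train split
        train=True,
        transform=transform,
    )
    data, label = dataset[0]  # tensor (via transform) and an integer label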
- """ - - def reporthook(blocknum, blocksize, totalsize): - """Report download progress.""" - denom = 1024 * 1024 - readsofar = blocknum * blocksize - if totalsize > 0: - percent = readsofar * 1e2 / totalsize - s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB" - print(s, end="", flush=True) - if readsofar >= totalsize: - print() - - # Download the dataset to a temporary file - with TemporaryDirectory() as tmpdir: - tmpdir = Path(tmpdir) - tmpfile = tmpdir / "usps.bz2" - urlretrieve( - url, - tmpfile, - reporthook=reporthook, - ) - - # For fun we can check the integrity of the downloaded file - if not self.check_integrity(tmpfile, checksum): - errmsg = ( - "The checksum of the downloaded file does " - "not match the expected checksum." - ) - raise ValueError(errmsg) - - # Load the dataset and save it as an HDF5 file - with bz2.open(tmpfile) as fp: - raw = [line.decode().split() for line in fp.readlines()] - - tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw] - - imgs = np.asarray(tmp, dtype=np.float32) - imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) - - targets = [int(d[0]) - 1 for d in raw] - - with h5.File(self.filepath, "a") as f: - f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32) - f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32) - - @staticmethod - def check_integrity(filepath, checksum): - """Check the integrity of the USPS dataset file. - - Args - ---- - filepath : pathlib.Path - Path to the USPS dataset file. - checksum : str - MD5 checksum of the dataset file. - - Returns - ------- - bool - True if the checksum of the file matches the expected checksum, False otherwise - """ - - file_hash = hashlib.md5(filepath.read_bytes()).hexdigest() - - return checksum == file_hash - - def _index(self): - with h5.File(self.filepath, "r") as f: - labels = f[self.mode]["target"][:] - - # Get indices of samples with labels 0-6 - mask = labels <= 6 - idx = np.where(mask)[0] + def __len__(self): + return len(self.sample_ids) - return idx + def __getitem__(self, id): + index = self.sample_ids[id] - def _load_data(self, idx): with h5.File(self.filepath, "r") as f: - data = f[self.mode]["data"][idx].astype(np.uint8) - label = f[self.mode]["target"][idx] + data = f[self.mode]["data"][index].astype(np.uint8) + label = f[self.mode]["target"][index] - return data, label - - def __len__(self): - return len(self.idx) - - def __getitem__(self, idx): - data, target = self._load_data(self.idx[idx]) data = Image.fromarray(data, mode="L") - # one hot encode the target - target = np.eye(self.num_classes, dtype=np.float32)[target] - if self.transform: data = self.transform(data) - return data, target - - -if __name__ == "__main__": - # Example usage: - transform = transforms.Compose( - [ - transforms.Resize((16, 16)), - transforms.ToTensor(), - ] - ) - - dataset = USPSDataset0_6( - data_path="data", - train=True, - download=False, - transform=transform, - ) - print(len(dataset)) - data, target = dataset[0] - print(data.shape) - print(target) + return data, label From 34539b313466ded471d3a55c2a010ed9441cba42 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Sat, 8 Feb 2025 19:53:38 +0100 Subject: [PATCH 04/12] `load_data` now splits the data, downloads data and returns all splits --- main.py | 32 ++++---------------- utils/load_data.py | 74 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 65 insertions(+), 41 deletions(-) diff --git a/main.py b/main.py index 9582e9d..05483bf 100644 --- a/main.py +++ b/main.py @@ -3,11 +3,11 
@@ import numpy as np import torch as th import torch.nn as nn +import wandb from torch.utils.data import DataLoader from torchvision import transforms from tqdm import tqdm -import wandb from utils import MetricWrapper, createfolders, get_args, load_data, load_model @@ -32,42 +32,20 @@ def main(): device = args.device if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]: - augmentations = transforms.Compose( + transform = transforms.Compose( [ transforms.Resize((16, 16)), transforms.ToTensor(), ] ) else: - augmentations = transforms.Compose([transforms.ToTensor()]) + transform = transforms.Compose([transforms.ToTensor()]) - # Dataset - assert ( - args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0 - ), "Validation split should be in interval (0,1)" - traindata = load_data( - args.dataset, - split="train", - split_percentage=args.validation_split_percentage, - data_path=args.datafolder, - download=args.download_data, - transform=augmentations, - ) - validata = load_data( - args.dataset, - split="validation", - split_percentage=args.validation_split_percentage, - data_path=args.datafolder, - download=args.download_data, - transform=augmentations, - ) - testdata = load_data( + traindata, validata, testdata = load_data( args.dataset, - split="test", - split_percentage=args.validation_split_percentage, data_path=args.datafolder, + transform=transform, download=args.download_data, - transform=augmentations, ) metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes) diff --git a/utils/load_data.py b/utils/load_data.py index 9060013..bf49ad6 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,11 +1,20 @@ -from torch.utils.data import Dataset +from torch.utils.data import Dataset, random_split -from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset +from .dataloaders import ( + Downloader, + MNISTDataset0_3, + USPSDataset0_6, + USPSH5_Digit_7_9_Dataset, +) -def load_data(dataset: str, *args, **kwargs) -> Dataset: +def filter_labels(samples: list, wanted_labels: list) -> list: + return list(filter(lambda x: x in wanted_labels, samples)) + + +def load_data(dataset: str, *args, **kwargs) -> tuple: """ - Load the dataset based on the dataset name. + load the dataset based on the dataset name. Args ---- @@ -18,8 +27,8 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset: Returns ------- - dataset : torch.utils.data.Dataset - Dataset object. + tuple + Tuple of train, validation and test datasets. 
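The validation split above leans on `torch.utils.data.random_split` accepting fractional lengths, which recent PyTorch (roughly 1.13+) supports. A standalone sketch of that behaviour on a plain index list:

    from torch.utils.data import random_split

    # random_split returns Subset objects whose elements come from the input
    # sequence; fractions must sum to 1.
    samples = list(range(10))
    train, val = random_split(samples, [0.8, 0.2])
    print(len(train), len(val))    # 8 2
    print(list(train), list(val))  # disjoint subsets of `samples`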
From 6c6f7b57bf0f56ee8f75222dea11c1f10ccaf81a Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 20:18:45 +0100
Subject: [PATCH 05/12] Made a whoopsie

---
 utils/dataloaders/download.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/utils/dataloaders/download.py b/utils/dataloaders/download.py
index 7a7fa13..c99f657 100644
--- a/utils/dataloaders/download.py
+++ b/utils/dataloaders/download.py
@@ -6,7 +6,6 @@
 
 import h5py as h5
 import numpy as np
-from PIL import Image
 
 from .datasources import USPS_SOURCE
 

From 20faa24e148d4dd76208f17337e3aa550e1eb3ff Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 20:18:54 +0100
Subject: [PATCH 06/12] Add the size thing

---
 main.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/main.py b/main.py
index 05483bf..ab78830 100644
--- a/main.py
+++ b/main.py
@@ -46,6 +46,7 @@ def main():
         data_path=args.datafolder,
         transform=transform,
         download=args.download_data,
+        val_size=args.val_size,
     )
 
     metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)

From d7526bf66243dda5f0325c3650c4c6504366c6a1 Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 20:19:27 +0100
Subject: [PATCH 07/12] Actually send the indices, not labels to datasets

---
 utils/load_data.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/utils/load_data.py b/utils/load_data.py
index bf49ad6..aa968f6 100644
--- a/utils/load_data.py
+++ b/utils/load_data.py
@@ -1,3 +1,4 @@
+import numpy as np
 from torch.utils.data import Dataset, random_split
 
 from .dataloaders import (
@@ -46,23 +47,28 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
     match dataset.lower():
         case "usps_0-6":
             dataset = USPSDataset0_6
-            train_samples, test_samples = Downloader.usps(*args, **kwargs)
-            labels = range(7)
+            train_labels, test_labels = Downloader.usps(*args, **kwargs)
+            labels = np.arange(7)
         case "usps_7-9":
             dataset = USPSH5_Digit_7_9_Dataset
-            train_samples, test_samples = Downloader.usps(*args, **kwargs)
-            labels = range(7, 10)
+            train_labels, test_labels = Downloader.usps(*args, **kwargs)
+            labels = np.arange(7, 10)
         case "mnist_0-3":
             dataset = MNISTDataset0_3
-            train_samples, test_samples = Downloader.mnist(*args, **kwargs)
-            labels = range(4)
+            train_labels, test_labels = Downloader.mnist(*args, **kwargs)
+            labels = np.arange(4)
         case _:
             raise NotImplementedError(f"Dataset: {dataset} not implemented.")
 
-    val_size = kwargs.get("val_size", 0.1)
+    val_size = kwargs.get("val_size", 0.2)
 
-    train_samples = filter_labels(train_samples, labels)
-    test_samples = filter_labels(test_samples, labels)
+    # Get the indices of the samples
+    train_indices = np.arange(len(train_labels))
+    test_indices = np.arange(len(test_labels))
+
+    # Filter the labels to only get indices of the wanted labels
+    train_samples = train_indices[np.isin(train_labels, labels)]
+    test_samples = test_indices[np.isin(test_labels, labels)]
 
     train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
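The `np.isin`-based filtering introduced here can be checked in isolation; a small sketch with made-up labels:

    import numpy as np

    train_labels = np.array([0, 5, 9, 2, 7, 3])  # hypothetical per-sample labels
    wanted = np.arange(4)                        # keep classes 0-3

    indices = np.arange(len(train_labels))
    kept = indices[np.isin(train_labels, wanted)]
    print(kept)  # [0 3 5] -> positions whose label is in {0, 1, 2, 3}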
From bd35ae639283e6f77eeb86d5dfc73b76a5c03f65 Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 20:24:57 +0100
Subject: [PATCH 08/12] Format

---
 tests/test_models.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index efc5412..4dd3fa8 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)
 
     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-

From 0f3206454123a33ed2ac25cb3935d06b7985014f Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 20:27:08 +0100
Subject: [PATCH 09/12] More formatting

---
 main.py                       | 2 --
 utils/dataloaders/usps_0_6.py | 1 -
 utils/load_data.py            | 2 +-
 3 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/main.py b/main.py
index ab78830..69383db 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,3 @@
-from pathlib import Path
-
 import numpy as np
 import torch as th
 import torch.nn as nn
diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py
index 85b3114..70286dc 100644
--- a/utils/dataloaders/usps_0_6.py
+++ b/utils/dataloaders/usps_0_6.py
@@ -10,7 +10,6 @@
 import numpy as np
 from PIL import Image
 from torch.utils.data import Dataset
-from torchvision import transforms
 
 
 class USPSDataset0_6(Dataset):
diff --git a/utils/load_data.py b/utils/load_data.py
index aa968f6..d2c4621 100644
--- a/utils/load_data.py
+++ b/utils/load_data.py
@@ -1,5 +1,5 @@
 import numpy as np
-from torch.utils.data import Dataset, random_split
+from torch.utils.data import random_split
 
 from .dataloaders import (
     Downloader,

From ad159404c94cf52ef8ff2e3108bf54e6f1dc90b9 Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Sat, 8 Feb 2025 20:34:40 +0100
Subject: [PATCH 10/12] Adjust test to comply with new functionality

---
 tests/test_dataloaders.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py
index 9f58ae4..32634d6 100644
--- a/tests/test_dataloaders.py
+++ b/tests/test_dataloaders.py
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():
 
     # Create a h5 file
     with h5py.File(tf, "w") as f:
+        targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        indices = np.arange(len(targets))
         # Populate the file with data
         f["train/data"] = np.random.rand(10, 16 * 16)
-        f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        f["train/target"] = targets
 
     trans = transforms.Compose(
         [
-            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.Resize((16, 16)),
             transforms.ToTensor(),
         ]
     )
 
-    dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
+    dataset = USPSDataset0_6(
+        data_path=tempdir,
+        sample_ids=indices,
+        train=True,
+        transform=trans,
+    )
     assert len(dataset) == 10
     data, target = dataset[0]
     assert data.shape == (1, 16, 16)
-    assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
+    assert target == 6
From 15c99ea46fb2622b1b4b9c9ef480b956fac280a3 Mon Sep 17 00:00:00 2001
From: Jan Zavadil
Date: Mon, 10 Feb 2025 13:44:05 +0100
Subject: [PATCH 11/12] added MNIST downloader, adjusted minor things for the
 code to run

---
 main.py                          |   3 +-
 utils/dataloaders/datasources.py |  19 +++++
 utils/dataloaders/download.py    |  46 ++++++++++-
 utils/dataloaders/mnist_0_3.py   | 132 ++++++-------------------------
 utils/load_data.py               |  22 +++---
 5 files changed, 101 insertions(+), 121 deletions(-)

diff --git a/main.py b/main.py
index 69383db..8b4fd93 100644
--- a/main.py
+++ b/main.py
@@ -41,9 +41,8 @@ def main():
 
     traindata, validata, testdata = load_data(
         args.dataset,
-        data_path=args.datafolder,
+        data_dir=args.datafolder,
         transform=transform,
-        download=args.download_data,
         val_size=args.val_size,
     )
 
diff --git a/utils/dataloaders/datasources.py b/utils/dataloaders/datasources.py
index f0d2e01..9fb8276 100644
--- a/utils/dataloaders/datasources.py
+++ b/utils/dataloaders/datasources.py
@@ -17,3 +17,22 @@
         "8ea070ee2aca1ac39742fdd1ef5ed118",
     ],
 }
+
+MNIST_SOURCE = {
+    "train_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+                     "train-images-idx3-ubyte",
+                     None
+                     ],
+    "train_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+                     "train-labels-idx1-ubyte",
+                     None
+                     ],
+    "test_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+                    "t10k-images-idx3-ubyte",
+                    None
+                    ],
+    "test_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+                    "t10k-labels-idx1-ubyte",
+                    None
+                    ],
+}
diff --git a/utils/dataloaders/download.py b/utils/dataloaders/download.py
index c99f657..7cbd5db 100644
--- a/utils/dataloaders/download.py
+++ b/utils/dataloaders/download.py
@@ -1,5 +1,7 @@
 import bz2
 import hashlib
+import os
+import gzip
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from urllib.request import urlretrieve
@@ -7,7 +9,7 @@
 import h5py as h5
 import numpy as np
 
-from .datasources import USPS_SOURCE
+from .datasources import USPS_SOURCE, MNIST_SOURCE
 
 
 class Downloader:
@@ -38,7 +40,47 @@ class Downloader:
     """
 
     def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
-        raise NotImplementedError("MNIST download not implemented yet")
+        def _check_is_downloaded(path: Path) -> bool:
+            path = path / "MNIST"
+            if path.exists():
+                required_files = [MNIST_SOURCE[key][1] for key in MNIST_SOURCE.keys()]
+                if all([(path / file).exists() for file in required_files]):
+                    print("MNIST Dataset already downloaded.")
+                    return True
+                else:
+                    return False
+            else:
+                path.mkdir(parents=True, exist_ok=True)
+                return False
+
+        def _download_data(path: Path) -> None:
+            urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}
+
+            for name, url in urls.items():
+                file_path = os.path.join(path, url.split("/")[-1])
+                if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
+                    urlretrieve(url, file_path)
+                    with gzip.open(file_path, "rb") as f_in:
+                        with open(file_path.replace(".gz", ""), "wb") as f_out:
+                            f_out.write(f_in.read())
+                    os.remove(file_path)  # Remove compressed file
+
+        def _get_labels(path: Path) -> np.ndarray:
+            with open(path, "rb") as f:
+                data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
+            return data
+
+        if not _check_is_downloaded(data_dir):
+            _download_data(data_dir / "MNIST")
+
+        train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
+        test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]
+
+        train_labels = _get_labels(train_labels_path)
+        test_labels = _get_labels(test_labels_path)
+
+        return train_labels, test_labels
+
 
     def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
         raise NotImplementedError("SVHN download not implemented yet")
diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py
index 1f8124d..fa96960 100644
--- a/utils/dataloaders/mnist_0_3.py
+++ b/utils/dataloaders/mnist_0_3.py
@@ -1,154 +1,72 @@
-import gzip
-import os
-import urllib.request
 from pathlib import Path
 
 import numpy as np
-import torch
-from torch.utils.data import Dataset, random_split
+from torch.utils.data import Dataset
+from .datasources import MNIST_SOURCE
 
 
 class MNISTDataset0_3(Dataset):
     """
-    A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
-
+    A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
+
     Parameters
     ----------
     data_path : Path
-        The root directory where the MNIST data is stored or will be downloaded.
+        The root directory where the MNIST data is stored.
+    sample_ids : list
+        A list of indices specifying which samples to load.
     train : bool, optional
-        If True, loads the training data, otherwise loads the test data. Default is False.
+        If True, load training data, otherwise load test data. Default is False.
     transform : callable, optional
-        A function/transform that takes in an image and returns a transformed version. Default is None.
-    download : bool, optional
-        If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
-
+        A function/transform to apply to the images. Default is None.
+
     Attributes
     ----------
     data_path : Path
         The root directory where the MNIST data is stored.
     mnist_path : Path
-        The directory where the MNIST data files are stored.
+        The directory where the MNIST dataset is located within the root directory.
+    idx : list
+        A list of indices specifying which samples to load.
     train : bool
-        Indicates whether the training data or test data is being used.
+        Indicates whether to load training data or test data.
     transform : callable
-        A function/transform that takes in an image and returns a transformed version.
-    download : bool
-        Indicates whether the dataset should be downloaded if not present.
+        A function/transform to apply to the images.
+    num_classes : int
+        The number of classes in the dataset (0 to 3).
     images_path : Path
-        The path to the image file (training or test) based on the `train` flag.
+        The path to the image file (train or test) based on the `train` flag.
     labels_path : Path
-        The path to the label file (training or test) based on the `train` flag.
-    idx : numpy.ndarray
-        Indices of the labels that are less than 4.
+        The path to the label file (train or test) based on the `train` flag.
     length : int
         The number of samples in the dataset.
-
+
     Methods
     -------
-    _parse_labels(train)
-        Parses the labels from the label file.
-    _chech_is_downloaded()
-        Checks if the dataset is already downloaded.
-    _download_data()
-        Downloads and extracts the MNIST dataset.
     __len__()
         Returns the number of samples in the dataset.
     __getitem__(index)
-        Returns the image and label at the specified index.
+        Retrieves the image and label at the specified index.
     """
 
     def __init__(
         self,
-        split: str,
-        split_percentage: float,
         data_path: Path,
-        download: bool = False,
+        sample_ids: list,
+        train: bool = False,
         transform=None,
     ):
         super().__init__()
 
         self.data_path = data_path
         self.mnist_path = self.data_path / "MNIST"
-        self.split = split
-        self.split_percentage = split_percentage
+        self.idx = sample_ids
+        self.train = train
         self.transform = transform
-        self.download = download
         self.num_classes = 4
 
-        if self.split == "test":
-            train = False  # used to decide whether to load training or test dataset
-        else:
-            train = True
-
-        if not self.download and not self._chech_is_downloaded():
-            raise ValueError(
-                "Data not found. Set --download-data=True to download the data."
-            )
-        if self.download and not self._chech_is_downloaded():
-            self._download_data()
-
-        self.images_path = self.mnist_path / (
-            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
-        )
-        self.labels_path = self.mnist_path / (
-            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
-        )
-
-        labels = self._parse_labels()
-
-        self.idx = np.where(labels < 4)[0]
-
-        if self.split != "test":
-            generator1 = torch.Generator().manual_seed(42)
-            tr, val = random_split(
-                self.idx,
-                [1 - self.split_percentage, self.split_percentage],
-                generator=generator1,
-            )
-            self.idx = tr if self.split == "train" else val
+        self.images_path = self.mnist_path / (MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1])
+        self.labels_path = self.mnist_path / (MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1])
 
         self.length = len(self.idx)
-
-    def _parse_labels(self):
-        with open(self.labels_path, "rb") as f:
-            data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
-        return data
-
-    def _chech_is_downloaded(self):
-        if self.mnist_path.exists():
-            required_files = [
-                "train-images-idx3-ubyte",
-                "train-labels-idx1-ubyte",
-                "t10k-images-idx3-ubyte",
-                "t10k-labels-idx1-ubyte",
-            ]
-            if all([(self.mnist_path / file).exists() for file in required_files]):
-                print("MNIST Dataset already downloaded.")
-                return True
-            else:
-                return False
-        else:
-            self.mnist_path.mkdir(parents=True, exist_ok=True)
-            return False
-
-    def _download_data(self):
-        urls = {
-            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
-            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
-            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
-            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
-        }
-
-        for name, url in urls.items():
-            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
-            if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
-                urllib.request.urlretrieve(url, file_path)
-                with gzip.open(file_path, "rb") as f_in:
-                    with open(file_path.replace(".gz", ""), "wb") as f_out:
-                        f_out.write(f_in.read())
-                os.remove(file_path)  # Remove compressed file
-
+
     def __len__(self):
         return self.length
diff --git a/utils/load_data.py b/utils/load_data.py
index d2c4621..1c3923d 100644
--- a/utils/load_data.py
+++ b/utils/load_data.py
@@ -43,19 +43,21 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
     >>> len(train), len(val), len(test)
     (4914, 546, 1782)
     """
-
+    downloader = Downloader()
+    data_dir = kwargs.get("data_dir")
+    transform = kwargs.get("transform")
     match dataset.lower():
         case "usps_0-6":
             dataset = USPSDataset0_6
-            train_labels, test_labels = Downloader.usps(*args, **kwargs)
+            train_labels, test_labels = downloader.usps(data_dir=data_dir)
             labels = np.arange(7)
         case "usps_7-9":
             dataset = USPSH5_Digit_7_9_Dataset
-            train_labels, test_labels = Downloader.usps(*args, **kwargs)
+            train_labels, test_labels = downloader.usps(data_dir=data_dir)
             labels = np.arange(7, 10)
         case "mnist_0-3":
             dataset = MNISTDataset0_3
-            train_labels, test_labels = Downloader.mnist(*args, **kwargs)
+            train_labels, test_labels = downloader.mnist(data_dir=data_dir)
             labels = np.arange(4)
         case _:
             raise NotImplementedError(f"Dataset: {dataset} not implemented.")
@@ -73,24 +75,24 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
     train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
 
     train = dataset(
-        *args,
+        data_path=data_dir,
         sample_ids=train_samples,
         train=True,
-        **kwargs,
+        transform=transform,
     )
 
     val = dataset(
-        *args,
+        data_path=data_dir,
         sample_ids=val_samples,
         train=True,
-        **kwargs,
+        transform=transform,
    )
 
     test = dataset(
-        *args,
+        data_path=data_dir,
         sample_ids=test_samples,
         train=False,
-        **kwargs,
+        transform=transform,
     )
 
     return train, val, test

From 601caca8e6fb1962fae31829c0509cadbfd91606 Mon Sep 17 00:00:00 2001
From: Jan Zavadil
Date: Mon, 10 Feb 2025 13:46:24 +0100
Subject: [PATCH 12/12] ruffed, isorted

---
 utils/dataloaders/datasources.py | 28 ++++++++++++++++------------
 utils/dataloaders/download.py    | 21 +++++++++++----------
 utils/dataloaders/mnist_0_3.py   | 11 ++++++++---
 3 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/utils/dataloaders/datasources.py b/utils/dataloaders/datasources.py
index 9fb8276..936d32e 100644
--- a/utils/dataloaders/datasources.py
+++ b/utils/dataloaders/datasources.py
@@ -19,20 +19,24 @@
 }
 
 MNIST_SOURCE = {
-    "train_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
-                     "train-images-idx3-ubyte",
-                     None
+    "train_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+        "train-images-idx3-ubyte",
+        None,
     ],
-    "train_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
-                     "train-labels-idx1-ubyte",
-                     None
+    "train_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+        "train-labels-idx1-ubyte",
+        None,
     ],
-    "test_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
-                    "t10k-images-idx3-ubyte",
-                    None
+    "test_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+        "t10k-images-idx3-ubyte",
+        None,
     ],
-    "test_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
-                    "t10k-labels-idx1-ubyte",
-                    None
+    "test_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+        "t10k-labels-idx1-ubyte",
+        None,
     ],
 }
diff --git a/utils/dataloaders/download.py b/utils/dataloaders/download.py
index 7cbd5db..9f667a3 100644
--- a/utils/dataloaders/download.py
+++ b/utils/dataloaders/download.py
@@ -1,7 +1,7 @@
 import bz2
+import gzip
 import hashlib
 import os
-import gzip
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from urllib.request import urlretrieve
@@ -9,7 +9,7 @@
 import h5py as h5
 import numpy as np
 
-from .datasources import USPS_SOURCE, MNIST_SOURCE
+from .datasources import MNIST_SOURCE, USPS_SOURCE
 
 
 class Downloader:
@@ -52,35 +52,36 @@ def _check_is_downloaded(path: Path) -> bool:
             else:
                 path.mkdir(parents=True, exist_ok=True)
                 return False
-
+
         def _download_data(path: Path) -> None:
             urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}
 
             for name, url in urls.items():
                 file_path = os.path.join(path, url.split("/")[-1])
-                if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
+                if not os.path.exists(
+                    file_path.replace(".gz", "")
+                ):  # Avoid re-downloading
                     urlretrieve(url, file_path)
                     with gzip.open(file_path, "rb") as f_in:
                         with open(file_path.replace(".gz", ""), "wb") as f_out:
                             f_out.write(f_in.read())
                     os.remove(file_path)  # Remove compressed file
-
+
         def _get_labels(path: Path) -> np.ndarray:
             with open(path, "rb") as f:
                 data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
             return data
-
+
         if not _check_is_downloaded(data_dir):
             _download_data(data_dir / "MNIST")
-
+
         train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
         test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]
-
+
         train_labels = _get_labels(train_labels_path)
         test_labels = _get_labels(test_labels_path)
-
+
         return train_labels, test_labels
-
 
     def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
         raise NotImplementedError("SVHN download not implemented yet")
diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py
index fa96960..52a5a28 100644
--- a/utils/dataloaders/mnist_0_3.py
+++ b/utils/dataloaders/mnist_0_3.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 from torch.utils.data import Dataset
+
 from .datasources import MNIST_SOURCE
 
 
@@ -62,11 +63,15 @@ def __init__(
         self.transform = transform
         self.num_classes = 4
 
-        self.images_path = self.mnist_path / (MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1])
-        self.labels_path = self.mnist_path / (MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1])
+        self.images_path = self.mnist_path / (
+            MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
+        )
+        self.labels_path = self.mnist_path / (
+            MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
+        )
 
         self.length = len(self.idx)
-
+
     def __len__(self):
         return self.length
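A closing note on `_get_labels`: the `offset=8` skips the IDX label-file header, which is a big-endian magic number and an item count of four bytes each. A minimal sketch of the same parsing, with a placeholder file path:

    import struct

    import numpy as np

    with open("data/MNIST/train-labels-idx1-ubyte", "rb") as f:
        magic, count = struct.unpack(">II", f.read(8))  # big-endian uint32 pair
        labels = np.frombuffer(f.read(), dtype=np.uint8)  # one uint8 label per sample

    assert magic == 2049 and len(labels) == count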