diff --git a/environment.yml b/environment.yml
index c003c73..9a977a3 100644
--- a/environment.yml
+++ b/environment.yml
@@ -9,7 +9,8 @@ dependencies:
   - sphinx-autobuild
   - sphinx-rtd-theme
   - pip
-  - h5py
+  - h5py==3.12.1
+  - hdf5==1.14.4
   - black
   - isort
   - jupyterlab
@@ -20,6 +21,7 @@ dependencies:
   - scalene
   - tqdm
   - scipy
+  - wandb
   - pip:
     - torch
     - torchvision
diff --git a/main.py b/main.py
index d9f233f..65ecc86 100644
--- a/main.py
+++ b/main.py
@@ -7,6 +7,7 @@
 import wandb
 
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model
+from wandb_api import WANDB_API
 
 
 def main():
@@ -29,33 +30,38 @@ def main():
 
     device = args.device
 
-    if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
-        augmentations = transforms.Compose(
+    if args.dataset.lower() in ["usps_0-6", "usps_7-9"]:
+        transform = transforms.Compose(
             [
                 transforms.Resize((16, 16)),
                 transforms.ToTensor(),
             ]
         )
     else:
-        augmentations = transforms.Compose([transforms.ToTensor()])
+        transform = transforms.Compose([transforms.ToTensor()])
 
-    # Dataset
-    traindata = load_data(
+    traindata, validata, testdata = load_data(
         args.dataset,
-        train=True,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    validata = load_data(
-        args.dataset,
-        train=False,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
+        data_dir=args.datafolder,
+        transform=transform,
+        val_size=args.val_size,
     )
 
-    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
+    train_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
+    val_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
+    test_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
 
     # Find the shape of the data, if is 2D, add a channel dimension
     data_shape = traindata[0][0].shape
@@ -80,6 +86,9 @@ def main():
     valiloader = DataLoader(
         validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
     )
+    testloader = DataLoader(
+        testdata, batch_size=args.batchsize, shuffle=False, pin_memory=True
+    )
 
     criterion = nn.CrossEntropyLoss()
     optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -104,22 +113,22 @@ def main():
                 optimizer.step()
                 optimizer.zero_grad(set_to_none=True)
 
-                metrics(y, logits)
+                train_metrics(y, logits)
             break
 
-        print(metrics.accumulate())
+        print(train_metrics.accumulate())
         print("Dry run completed successfully.")
         exit()
 
     # wandb.login(key=WANDB_API)
     wandb.init(
-        entity="ColabCode-org",
-        # entity="FYS-8805 Exam",
-        project="Test",
-        tags=[args.modelname, args.dataset]
-    )
+        entity="ColabCode",
+        # entity="FYS-8805 Exam",
+        project="Jan",
+        tags=[args.modelname, args.dataset],
+    )
     wandb.watch(model)
-    exit()
+
     for epoch in range(args.epoch):
         # Training loop start
         trainingloss = []
@@ -135,33 +144,49 @@ def main():
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())
 
-            metrics(y, logits)
-
-        wandb.log(metrics.accumulate(str_prefix="Train "))
-        metrics.reset()
+            train_metrics(y, logits)
 
-        evalloss = []
-        # Eval loop start
+        valloss = []
+        # Validation loop start
         model.eval()
         with th.no_grad():
             for x, y in tqdm(valiloader, desc="Validation"):
                 x, y = x.to(device), y.to(device)
                 logits = model.forward(x)
                 loss = criterion(logits, y)
-                evalloss.append(loss.item())
-
-                metrics(y, logits)
+                valloss.append(loss.item())
 
-        wandb.log(metrics.accumulate(str_prefix="Evaluation "))
-        metrics.reset()
+                val_metrics(y, logits)
 
         wandb.log(
             {
                 "Epoch": epoch,
                 "Train loss": np.mean(trainingloss),
-                "Evaluation Loss": np.mean(evalloss),
+                "Validation loss": np.mean(valloss),
             }
+            | train_metrics.accumulate(str_prefix="Train ")
+            | val_metrics.accumulate(str_prefix="Validation ")
        )
+        train_metrics.reset()
+        val_metrics.reset()
+
+    testloss = []
+    model.eval()
+    with th.no_grad():
+        for x, y in tqdm(testloader, desc="Testing"):
+            x, y = x.to(device), y.to(device)
+            logits = model.forward(x)
+            loss = criterion(logits, y)
+            testloss.append(loss.item())
+
+            preds = th.argmax(logits, dim=1)
+            test_metrics(y, preds)
+
+    wandb.log(
+        {"Epoch": 1, "Test loss": np.mean(testloss)}
+        | test_metrics.accumulate(str_prefix="Test ")
+    )
+    test_metrics.reset()
 
 
 if __name__ == "__main__":
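The epoch-level wandb.log call above merges the loss dict with the metric dicts via the dict union operator, which requires Python 3.9+. A minimal sketch of the pattern with stand-in values:

# Python 3.9+ dict union: right-hand operands win on key clashes
base = {"Epoch": 0, "Train loss": 0.41}
train = {"Train f1": 0.87}
val = {"Validation f1": 0.83}

payload = base | train | val
# {'Epoch': 0, 'Train loss': 0.41, 'Train f1': 0.87, 'Validation f1': 0.83}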
diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py
index 9f58ae4..32634d6 100644
--- a/tests/test_dataloaders.py
+++ b/tests/test_dataloaders.py
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():
     # Create a h5 file
     with h5py.File(tf, "w") as f:
+        targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        indices = np.arange(len(targets))
         # Populate the file with data
         f["train/data"] = np.random.rand(10, 16 * 16)
-        f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        f["train/target"] = targets
 
     trans = transforms.Compose(
         [
-            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.Resize((16, 16)),
            transforms.ToTensor(),
         ]
     )
 
-    dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
+    dataset = USPSDataset0_6(
+        data_path=tempdir,
+        sample_ids=indices,
+        train=True,
+        transform=trans,
+    )
     assert len(dataset) == 10
     data, target = dataset[0]
     assert data.shape == (1, 16, 16)
-    assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
+    assert target == 6
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index d6da0ab..97d651a 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -26,7 +26,7 @@ def test_f1score():
     target = torch.tensor([0, 1, 0, 2])
 
-    f1_metric.update(preds, target)
+    f1_metric(preds, target)
     assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
     assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
     assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
diff --git a/tests/test_models.py b/tests/test_models.py
index efc5412..4dd3fa8 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)
 
     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-
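The updated dataset test drives USPSDataset0_6 through explicit sample_ids, mirroring how load_data now slices the full split. A companion sketch for the split logic itself, assuming the load_data signature introduced in this PR (the expected counts come from its docstring, and the test downloads USPS, so it is slow):

from utils import load_data


def test_load_data_split_sizes(tmp_path):
    train, val, test = load_data(
        "usps_0-6", data_dir=tmp_path, transform=None, val_size=0.1
    )

    # Train and validation partition the label-filtered training split
    assert len(train) + len(val) == 5460
    assert len(test) == 1782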
diff --git a/utils/arg_parser.py b/utils/arg_parser.py
index 079eae4..618e8d2 100644
--- a/utils/arg_parser.py
+++ b/utils/arg_parser.py
@@ -33,13 +33,6 @@ def get_args():
         help="Whether model should be saved or not.",
     )
 
-    parser.add_argument(
-        "--download-data",
-        type=bool,
-        default=False,
-        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
-    )
-
     # Data/Model specific values
     parser.add_argument(
         "--modelname",
@@ -61,7 +54,12 @@ def get_args():
         choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
         help="Which dataset to train the model on.",
     )
-
+    parser.add_argument(
+        "--val_size",
+        type=float,
+        default=0.2,
+        help="Fraction of the training dataset to be used for validation - must be within (0, 1).",
+    )
     parser.add_argument(
         "--metric",
         type=str,
@@ -70,20 +68,10 @@ def get_args():
         default="entropy",
         nargs="+",
         help="Which metric to use for evaluation",
     )
-
-    parser.add_argument(
-        '--imagesize',
-        type=int,
-        default=28,
-        help='Imagesize'
-    )
-
     parser.add_argument(
-        '--nr_channels',
-        type=int,
-        default=1,
-        choices=[1,3],
-        help='Number of image channels'
+        "--macro_averaging",
+        action="store_true",
+        help="If the flag is included, the metrics will be calculated using macro averaging.",
     )
 
     # Training specific values
@@ -115,7 +103,7 @@ def get_args():
     parser.add_argument(
         "--dry_run",
         action="store_true",
-        help="If true, the code will not run the training loop.",
+        help="If the flag is included, the code will not run the training loop.",
     )
 
     args = parser.parse_args()
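Both new options surface through get_args() as plain argparse attributes; a quick sketch of the resulting namespace (argument values are illustrative):

import sys

from utils import get_args

# Simulate a command line: 10% validation split, macro-averaged f1 and recall
sys.argv = [
    "main.py",
    "--dataset", "usps_0-6",
    "--metric", "f1", "recall",
    "--val_size", "0.1",
    "--macro_averaging",
]

args = get_args()
print(args.val_size, args.macro_averaging)  # 0.1 True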
diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py
index a5688ba..5f14335 100644
--- a/utils/dataloaders/__init__.py
+++ b/utils/dataloaders/__init__.py
@@ -1,6 +1,13 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]
+__all__ = [
+    "USPSDataset0_6",
+    "USPSH5_Digit_7_9_Dataset",
+    "MNISTDataset0_3",
+    "Downloader",
+    "SVHNDataset",
+]
 
+from .download import Downloader
 from .mnist_0_3 import MNISTDataset0_3
+from .svhn import SVHNDataset
 from .usps_0_6 import USPSDataset0_6
 from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
-from .svhn import SVHNDataset
diff --git a/utils/dataloaders/datasources.py b/utils/dataloaders/datasources.py
index f0d2e01..936d32e 100644
--- a/utils/dataloaders/datasources.py
+++ b/utils/dataloaders/datasources.py
@@ -17,3 +17,26 @@
         "8ea070ee2aca1ac39742fdd1ef5ed118",
     ],
 }
+
+MNIST_SOURCE = {
+    "train_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+        "train-images-idx3-ubyte",
+        None,
+    ],
+    "train_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+        "train-labels-idx1-ubyte",
+        None,
+    ],
+    "test_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+        "t10k-images-idx3-ubyte",
+        None,
+    ],
+    "test_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+        "t10k-labels-idx1-ubyte",
+        None,
+    ],
+}
diff --git a/utils/dataloaders/download.py b/utils/dataloaders/download.py
new file mode 100644
index 0000000..d99eff2
--- /dev/null
+++ b/utils/dataloaders/download.py
@@ -0,0 +1,183 @@
+import bz2
+import gzip
+import hashlib
+import os
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from urllib.request import urlretrieve
+
+import h5py as h5
+import numpy as np
+
+from .datasources import MNIST_SOURCE, USPS_SOURCE
+
+
+class Downloader:
+    """
+    Class to download the MNIST, SVHN, and USPS datasets.
+
+    Methods
+    -------
+    mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the MNIST dataset to `data_dir/MNIST` and return the train and test labels.
+    svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the SVHN dataset and save it as an HDF5 file to `data_dir`.
+    usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
+        Download the USPS dataset and save it as an HDF5 file to `data_dir`.
+
+    Raises
+    ------
+    NotImplementedError
+        If the download method is not implemented for the dataset.
+
+    Examples
+    --------
+    >>> from pathlib import Path
+    >>> from utils import Downloader
+    >>> dir = Path('tmp')
+    >>> dir.mkdir(exist_ok=True)
+    >>> train, test = Downloader().usps(dir)
+    """
+
+    def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        def _check_is_downloaded(path: Path) -> bool:
+            path = path / "MNIST"
+            if path.exists():
+                required_files = [MNIST_SOURCE[key][1] for key in MNIST_SOURCE.keys()]
+                if all([(path / file).exists() for file in required_files]):
+                    print("MNIST Dataset already downloaded.")
+                    return True
+                else:
+                    return False
+            else:
+                path.mkdir(parents=True, exist_ok=True)
+                return False
+
+        def _download_data(path: Path) -> None:
+            urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}
+
+            for name, url in urls.items():
+                file_path = os.path.join(path, url.split("/")[-1])
+                if not os.path.exists(
+                    file_path.replace(".gz", "")
+                ):  # Avoid re-downloading
+                    urlretrieve(url, file_path, reporthook=self.__reporthook)
+                    with gzip.open(file_path, "rb") as f_in:
+                        with open(file_path.replace(".gz", ""), "wb") as f_out:
+                            f_out.write(f_in.read())
+                    os.remove(file_path)  # Remove compressed file
+
+        def _get_labels(path: Path) -> np.ndarray:
+            with open(path, "rb") as f:
+                data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
+            return data
+
+        if not _check_is_downloaded(data_dir):
+            _download_data(data_dir / "MNIST")  # Extract into the MNIST subfolder the label paths below expect
+
+        train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
+        test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]
+
+        train_labels = _get_labels(train_labels_path)
+        test_labels = _get_labels(test_labels_path)
+
+        return train_labels, test_labels
+
+    def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        raise NotImplementedError("SVHN download not implemented yet")
+
+    def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Download the USPS dataset and save it as an HDF5 file to `data_dir/usps.h5`.
+        """
+
+        def already_downloaded(path):
+            if not path.exists() or not path.is_file():
+                return False
+
+            with h5.File(path, "r") as f:
+                return "train" in f and "test" in f
+
+        filename = data_dir / "usps.h5"
+
+        if already_downloaded(filename):
+            with h5.File(filename, "r") as f:
+                return f["train"]["target"][:], f["test"]["target"][:]
+
+        url_train, _, train_md5 = USPS_SOURCE["train"]
+        url_test, _, test_md5 = USPS_SOURCE["test"]
+
+        # Using temporary directory ensures temporary files are deleted after use
+        with TemporaryDirectory() as tmp_dir:
+            train_path = Path(tmp_dir) / "train"
+            test_path = Path(tmp_dir) / "test"
+
+            # Download the dataset and report the progress
+            urlretrieve(url_train, train_path, reporthook=self.__reporthook)
+            self.__check_integrity(train_path, train_md5)
+            train_targets = self.__extract_usps(train_path, filename, "train")
+
+            urlretrieve(url_test, test_path, reporthook=self.__reporthook)
+            self.__check_integrity(test_path, test_md5)
+            test_targets = self.__extract_usps(test_path, filename, "test")
+
+        return train_targets, test_targets
+
+    def __extract_usps(self, src: Path, dest: Path, mode: str):
+        # Load the dataset and save it as an HDF5 file
+        with bz2.open(src) as fp:
+            raw = [line.decode().split() for line in fp.readlines()]
+
+        tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]
+
+        imgs = np.asarray(tmp, dtype=np.float32)
+        imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+
+        targets = [int(d[0]) - 1 for d in raw]
+
+        with h5.File(dest, "a") as f:
+            f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
+            f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
+
+        return targets
+
+    @staticmethod
+    def __reporthook(blocknum, blocksize, totalsize):
+        """
+        Use this function to report download progress
+        for the urllib.request.urlretrieve function.
+        """
+
+        denom = 1024 * 1024
+        readsofar = blocknum * blocksize
+
+        if totalsize > 0:
+            percent = readsofar * 1e2 / totalsize
+            s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
+            print(s, end="", flush=True)
+            if readsofar >= totalsize:
+                print()
+
+    @staticmethod
+    def __check_integrity(filepath, checksum):
+        """Check the integrity of the USPS dataset file.
+
+        Args
+        ----
+        filepath : pathlib.Path
+            Path to the USPS dataset file.
+        checksum : str
+            MD5 checksum of the dataset file.
+
+        Raises
+        ------
+        ValueError
+            If the checksum of the file does not match the expected checksum.
+        """
+
+        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
+
+        if not checksum == file_hash:
+            raise ValueError(
+                f"File integrity check failed. Expected {checksum}, got {file_hash}"
+            )
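Downloader.usps writes both splits into a single usps.h5, so later calls short-circuit via already_downloaded. A short sketch of inspecting the result (paths are illustrative):

from pathlib import Path

import h5py as h5

from utils import Downloader

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
train_targets, test_targets = Downloader().usps(data_dir)

# Layout mirrors the {mode}/data and {mode}/target datasets created above
with h5.File(data_dir / "usps.h5", "r") as f:
    print(f["train/data"].shape, f["train/target"].shape)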
diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py
index bad3bd9..52a5a28 100644
--- a/utils/dataloaders/mnist_0_3.py
+++ b/utils/dataloaders/mnist_0_3.py
@@ -1,137 +1,77 @@
-import gzip
-import os
-import urllib.request
 from pathlib import Path
 
 import numpy as np
 from torch.utils.data import Dataset
 
+from .datasources import MNIST_SOURCE
+
 
 class MNISTDataset0_3(Dataset):
     """
-    A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
-
+    A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
 
     Parameters
     ----------
     data_path : Path
-        The root directory where the MNIST data is stored or will be downloaded.
+        The root directory where the MNIST data is stored.
+    sample_ids : list
+        A list of indices specifying which samples to load.
     train : bool, optional
-        If True, loads the training data, otherwise loads the test data. Default is False.
+        If True, load training data, otherwise load test data. Default is False.
     transform : callable, optional
-        A function/transform that takes in an image and returns a transformed version. Default is None.
-    download : bool, optional
-        If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
-
+        A function/transform to apply to the images. Default is None.
 
     Attributes
     ----------
     data_path : Path
         The root directory where the MNIST data is stored.
     mnist_path : Path
-        The directory where the MNIST data files are stored.
+        The directory where the MNIST dataset is located within the root directory.
+    idx : list
+        A list of indices specifying which samples to load.
     train : bool
-        Indicates whether the training data or test data is being used.
+        Indicates whether to load training data or test data.
     transform : callable
-        A function/transform that takes in an image and returns a transformed version.
-    download : bool
-        Indicates whether the dataset should be downloaded if not present.
+        A function/transform to apply to the images.
+    num_classes : int
+        The number of classes in the dataset (0 to 3).
     images_path : Path
-        The path to the image file (training or test) based on the `train` flag.
+        The path to the image file (train or test) based on the `train` flag.
     labels_path : Path
-        The path to the label file (training or test) based on the `train` flag.
-    idx : numpy.ndarray
-        Indices of the labels that are less than 4.
+        The path to the label file (train or test) based on the `train` flag.
     length : int
         The number of samples in the dataset.
-
     Methods
     -------
-    _parse_labels(train)
-        Parses the labels from the label file.
-    _chech_is_downloaded()
-        Checks if the dataset is already downloaded.
-    _download_data()
-        Downloads and extracts the MNIST dataset.
     __len__()
         Returns the number of samples in the dataset.
     __getitem__(index)
-        Returns the image and label at the specified index.
+        Retrieves the image and label at the specified index.
     """
 
     def __init__(
         self,
         data_path: Path,
+        sample_ids: list,
         train: bool = False,
         transform=None,
-        download: bool = False,
     ):
         super().__init__()
 
         self.data_path = data_path
         self.mnist_path = self.data_path / "MNIST"
+        self.idx = sample_ids
         self.train = train
         self.transform = transform
-        self.download = download
 
         self.num_classes = 4
 
-        if not self.download and not self._chech_is_downloaded():
-            raise ValueError(
-                "Data not found. Set --download-data=True to download the data."
-            )
-        if self.download and not self._chech_is_downloaded():
-            self._download_data()
-
         self.images_path = self.mnist_path / (
-            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
+            MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
        )
         self.labels_path = self.mnist_path / (
-            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
+            MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
         )
 
-        labels = self._parse_labels(train=self.train)
-
-        self.idx = np.where(labels < 4)[0]
-
         self.length = len(self.idx)
 
-    def _parse_labels(self, train):
-        with open(self.labels_path, "rb") as f:
-            data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
-        return data
-
-    def _chech_is_downloaded(self):
-        if self.mnist_path.exists():
-            required_files = [
-                "train-images-idx3-ubyte",
-                "train-labels-idx1-ubyte",
-                "t10k-images-idx3-ubyte",
-                "t10k-labels-idx1-ubyte",
-            ]
-            if all([(self.mnist_path / file).exists() for file in required_files]):
-                print("MNIST Dataset already downloaded.")
-                return True
-            else:
-                return False
-        else:
-            self.mnist_path.mkdir(parents=True, exist_ok=True)
-            return False
-
-    def _download_data(self):
-        urls = {
-            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
-            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
-            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
-            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
-        }
-
-        for name, url in urls.items():
-            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
-            if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
-                urllib.request.urlretrieve(url, file_path)
-                with gzip.open(file_path, "rb") as f_in:
-                    with open(file_path.replace(".gz", ""), "wb") as f_out:
-                        f_out.write(f_in.read())
-                os.remove(file_path)  # Remove compressed file
-
     def __len__(self):
         return self.length
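MNISTDataset0_3 no longer downloads or filters anything itself; callers pass the indices they want. A minimal sketch wiring it to the Downloader (the label filter mirrors what load_data does):

from pathlib import Path

import numpy as np

from utils import Downloader
from utils.dataloaders import MNISTDataset0_3

data_dir = Path("data")
train_labels, _ = Downloader().mnist(data_dir=data_dir)

# Keep only digits 0-3, as load_data does
sample_ids = np.where(train_labels < 4)[0]
dataset = MNISTDataset0_3(data_path=data_dir, sample_ids=sample_ids, train=True)
print(len(dataset))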
diff --git a/utils/dataloaders/svhn.py b/utils/dataloaders/svhn.py
index e71de73..f0bd18c 100644
--- a/utils/dataloaders/svhn.py
+++ b/utils/dataloaders/svhn.py
@@ -1,4 +1,5 @@
 import os
+
 import numpy as np
 from scipy.io import loadmat
 from torch.utils.data import Dataset
@@ -7,13 +8,13 @@
 
 class SVHNDataset(Dataset):
     def __init__(
-        self, 
-        data_path: str, 
+        self,
+        data_path: str,
         train: bool,
-        transform=None, 
-        download:bool=True, 
-        nr_channels=3
-    ):
+        transform=None,
+        download: bool = True,
+        nr_channels=3,
+    ):
         """
         Initializes the SVHNDataset object.
         Args:
@@ -26,8 +27,8 @@ def __init__(
         """
         super().__init__()
         # assert split == "train" or split == "test"
-        self.split = 'train' if train else 'test'
-        
+        self.split = "train" if train else "test"
+
         if download:
             self._download_data(data_path)
@@ -37,7 +38,7 @@ def __init__(
         self.images = data["X"].transpose(3, 1, 0, 2)
         self.labels = data["y"].flatten()
         self.labels[self.labels == 10] = 0
-        
+
         self.nr_channels = nr_channels
         self.transforms = transform
@@ -49,7 +50,7 @@ def _download_data(self, path: str):
             split (str): The dataset split to download, either 'train' or 'test'.
         """
         print(f"Downloading SVHN data into {path}")
-        
+
         SVHN(path, split=self.split, download=True)
 
     def __len__(self):
@@ -72,7 +73,7 @@ def __getitem__(self, index):
 
         if self.nr_channels == 1:
             img = np.mean(img, axis=2, keepdims=True)
-        
+
         if self.transforms is not None:
             img = self.transforms(img)
""" print(f"Downloading SVHN data into {path}") - + SVHN(path, split=self.split, download=True) def __len__(self): @@ -72,7 +73,7 @@ def __getitem__(self, index): if self.nr_channels == 1: img = np.mean(img, axis=2, keepdims=True) - + if self.transforms is not None: img = self.transforms(img) diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py index 3673fa9..70286dc 100644 --- a/utils/dataloaders/usps_0_6.py +++ b/utils/dataloaders/usps_0_6.py @@ -4,19 +4,12 @@ This module contains the Dataset class for the USPS dataset with labels 0-6. """ -import bz2 -import hashlib from pathlib import Path -from tempfile import TemporaryDirectory -from urllib.request import urlretrieve import h5py as h5 import numpy as np from PIL import Image from torch.utils.data import Dataset -from torchvision import transforms - -from .datasources import USPS_SOURCE class USPSDataset0_6(Dataset): @@ -87,9 +80,9 @@ class USPSDataset0_6(Dataset): def __init__( self, data_path: Path, + sample_ids: list, train: bool = False, transform=None, - download: bool = False, ): super().__init__() @@ -97,168 +90,21 @@ def __init__( self.filepath = path / self.filename self.transform = transform self.mode = "train" if train else "test" + self.sample_ids = sample_ids - # Download the dataset if it does not exist in a temporary directory - # to automatically clean up the downloaded file - if download and not self._dataset_ok(): - url, _, checksum = USPS_SOURCE[self.mode] - - print(f"Downloading USPS dataset ({self.mode})...") - self.download(url, self.filepath, checksum, self.mode) - - self.idx = self._index() - - def _dataset_ok(self): - """Check if the dataset file exists and contains the required datasets.""" - - if not self.filepath.exists(): - print(f"Dataset file {self.filepath} does not exist.") - return False - - with h5.File(self.filepath, "r") as f: - for mode in ["train", "test"]: - if mode not in f: - print( - f"Dataset file {self.filepath} is missing the {mode} dataset." - ) - return False - - return True - - def download(self, url, filepath, checksum, mode): - """Download the USPS dataset, and save it as an HDF5 file. - - Args - ---- - url : str - URL to download the dataset from. - filepath : pathlib.Path - Path to save the downloaded dataset. - checksum : str - MD5 checksum of the downloaded file. - mode : str - Mode of the dataset, either train or test. - - Raises - ------ - ValueError - If the checksum of the downloaded file does not match the expected checksum. - """ - - def reporthook(blocknum, blocksize, totalsize): - """Report download progress.""" - denom = 1024 * 1024 - readsofar = blocknum * blocksize - if totalsize > 0: - percent = readsofar * 1e2 / totalsize - s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB" - print(s, end="", flush=True) - if readsofar >= totalsize: - print() - - # Download the dataset to a temporary file - with TemporaryDirectory() as tmpdir: - tmpdir = Path(tmpdir) - tmpfile = tmpdir / "usps.bz2" - urlretrieve( - url, - tmpfile, - reporthook=reporthook, - ) - - # For fun we can check the integrity of the downloaded file - if not self.check_integrity(tmpfile, checksum): - errmsg = ( - "The checksum of the downloaded file does " - "not match the expected checksum." 
diff --git a/utils/load_data.py b/utils/load_data.py
index 8ebbec5..a3bfe42 100644
--- a/utils/load_data.py
+++ b/utils/load_data.py
@@ -1,11 +1,22 @@
-from torch.utils.data import Dataset
+import numpy as np
+from torch.utils.data import random_split
 
-from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset, SVHNDataset
+from .dataloaders import (
+    Downloader,
+    MNISTDataset0_3,
+    SVHNDataset,
+    USPSDataset0_6,
+    USPSH5_Digit_7_9_Dataset,
+)
 
 
-def load_data(dataset: str, *args, **kwargs) -> Dataset:
+def filter_labels(samples: list, wanted_labels: list) -> list:
+    return list(filter(lambda x: x in wanted_labels, samples))
+
+
+def load_data(dataset: str, *args, **kwargs) -> tuple:
     """
-    Load the dataset based on the dataset name.
+    Load the train, validation, and test datasets based on the dataset name.
 
     Args
     ----
@@ -18,8 +29,8 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
 
     Returns
     -------
-    dataset : torch.utils.data.Dataset
-        Dataset object.
+    tuple
+        Tuple of train, validation and test datasets.
 
     Raises
     ------
@@ -28,21 +39,67 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
 
     Examples
     --------
     >>> from utils import load_data
-    >>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
-    >>> len(dataset)
-    5460
+    >>> train, val, test = load_data("usps_0-6", data_dir="data", val_size=0.1)
+    >>> len(train), len(val), len(test)
+    (4914, 546, 1782)
     """
+    downloader = Downloader()
+    data_dir = kwargs.get("data_dir")
+    transform = kwargs.get("transform")
     match dataset.lower():
         case "usps_0-6":
-            return USPSDataset0_6(*args, **kwargs)
-        case "mnist_0-3":
-            return MNISTDataset0_3(*args, **kwargs)
+            dataset = USPSDataset0_6
+            train_labels, test_labels = downloader.usps(data_dir=data_dir)
+            labels = np.arange(7)
         case "usps_7-9":
-            return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
+            dataset = USPSH5_Digit_7_9_Dataset
+            train_labels, test_labels = downloader.usps(data_dir=data_dir)
+            labels = np.arange(7, 10)
+        case "mnist_0-3":
+            dataset = MNISTDataset0_3
+            train_labels, test_labels = downloader.mnist(data_dir=data_dir)
+            labels = np.arange(4)
         case "svhn":
-            return SVHNDataset(*args, **kwargs)
+            dataset = SVHNDataset
+            train_labels, test_labels = downloader.svhn(data_dir=data_dir)
+            labels = np.arange(10)
         case "mnist_4-9":
             raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
         case _:
             raise NotImplementedError(f"Dataset: {dataset} not implemented.")
+
+    val_size = kwargs.get("val_size", 0.2)
+
+    # Get the indices of the samples
+    train_indices = np.arange(len(train_labels))
+    test_indices = np.arange(len(test_labels))
+
+    # Filter the labels to only get indices of the wanted labels
+    train_samples = train_indices[np.isin(train_labels, labels)]
+    test_samples = test_indices[np.isin(test_labels, labels)]
+
+    train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
+
+    train = dataset(
+        data_path=data_dir,
+        sample_ids=train_samples,
+        train=True,
+        transform=transform,
+    )
+
+    val = dataset(
+        data_path=data_dir,
+        sample_ids=val_samples,
+        train=True,
+        transform=transform,
+    )
+
+    test = dataset(
+        data_path=data_dir,
+        sample_ids=test_samples,
+        train=False,
+        transform=transform,
+    )
+
+    return train, val, test
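random_split draws a fresh permutation on every run, so train/validation membership changes between invocations. If reproducible splits matter, torch's generator argument is the usual fix; a drop-in sketch for the split line above (seed value illustrative):

import torch as th
from torch.utils.data import random_split

# Fixed seed keeps train/val membership stable across runs
# (the 0.8/0.2 ratio mirrors the default val_size)
train_samples, val_samples = random_split(
    train_samples, [0.8, 0.2], generator=th.Generator().manual_seed(42)
)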
diff --git a/utils/load_metric.py b/utils/load_metric.py
index a321845..1698b66 100644
--- a/utils/load_metric.py
+++ b/utils/load_metric.py
@@ -45,10 +45,11 @@ class MetricWrapper(nn.Module):
     {'entropy': [], 'f1': [], 'precision': []}
     """
 
-    def __init__(self, *metrics, num_classes):
+    def __init__(self, *metrics, num_classes, macro_averaging=False):
         super().__init__()
         self.metrics = {}
         self.num_classes = num_classes
+        self.macro_averaging = macro_averaging
 
         for metric in metrics:
             self.metrics[metric] = self._get_metric(metric)
@@ -74,13 +75,13 @@ def _get_metric(self, key):
             case "entropy":
                 return EntropyPrediction(num_classes=self.num_classes)
             case "f1":
-                return F1Score(num_classes=self.num_classes)
+                return F1Score(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "recall":
-                return Recall(num_classes=self.num_classes)
+                return Recall(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "precision":
-                return Precision(num_classes=self.num_classes)
+                return Precision(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case "accuracy":
-                return Accuracy(num_classes=self.num_classes)
+                return Accuracy(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
             case _:
                 raise ValueError(f"Metric {key} not supported")
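With the new keyword, one wrapper instance applies the same averaging mode to every requested metric. A short usage sketch, mirroring how main.py calls the wrapper (values are illustrative):

import torch

from utils import MetricWrapper

metrics = MetricWrapper("f1", "accuracy", num_classes=4, macro_averaging=True)

logits = torch.randn(8, 4)            # fake batch of logits
targets = torch.randint(0, 4, (8,))

metrics(targets, logits)                      # accumulate, as main.py does
print(metrics.accumulate(str_prefix="Val "))  # e.g. {'Val f1': ..., 'Val accuracy': ...}
metrics.reset()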
not supported") diff --git a/utils/metrics/F1.py b/utils/metrics/F1.py index 1e0e795..0c7a5e2 100644 --- a/utils/metrics/F1.py +++ b/utils/metrics/F1.py @@ -4,29 +4,39 @@ class F1Score(nn.Module): """ - F1 Score implementation with direct averaging inside the compute method. + F1 Score implementation with support for both macro and micro averaging. + + This class computes the F1 score during training using either macro or micro averaging. + The F1 score is calculated based on the true positives (TP), false positives (FP), + and false negatives (FN) for each class. Parameters ---------- num_classes : int - Number of classes. + The number of classes in the classification task. + + macro_averaging : bool, optional, default=False + If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score. Attributes ---------- num_classes : int - The number of classes. + The number of classes in the classification task. tp : torch.Tensor - Tensor for True Positives (TP) for each class. + Tensor storing the count of True Positives (TP) for each class. fp : torch.Tensor - Tensor for False Positives (FP) for each class. + Tensor storing the count of False Positives (FP) for each class. fn : torch.Tensor - Tensor for False Negatives (FN) for each class. + Tensor storing the count of False Negatives (FN) for each class. + + macro_averaging : bool + A flag indicating whether to compute the macro-averaged F1 score or not. """ - def __init__(self, num_classes): + def __init__(self, num_classes, macro_averaging=False): """ Initializes the F1Score object, setting up the necessary state variables. @@ -35,28 +45,81 @@ def __init__(self, num_classes): num_classes : int The number of classes in the classification task. + macro_averaging : bool, optional, default=False + If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score. """ - super().__init__() self.num_classes = num_classes + self.macro_averaging = macro_averaging - # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN) + # Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN) self.tp = torch.zeros(num_classes) self.fp = torch.zeros(num_classes) self.fn = torch.zeros(num_classes) - def update(self, preds, target): + def _micro_F1(self): + """ + Compute the Micro F1 score by aggregating TP, FP, and FN across all classes. + + Micro F1 score is calculated globally by considering all predictions together, regardless of class. + + Returns + ------- + torch.Tensor + The micro-averaged F1 score. """ - Update the variables with predictions and true labels. + tp = torch.sum(self.tp) + fp = torch.sum(self.fp) + fn = torch.sum(self.fn) + + precision = tp / (tp + fp + 1e-8) # Avoid division by zero + recall = tp / (tp + fn + 1e-8) # Avoid division by zero + + f1 = 2 * precision * recall / (precision + recall + 1e-8) # Avoid division by zero + return f1 + + def _macro_F1(self): + """ + Compute the Macro F1 score by calculating the F1 score per class and averaging. + + Macro F1 score is calculated as the average of per-class F1 scores. This approach treats all classes equally, + regardless of their frequency. + + Returns + ------- + torch.Tensor + The macro-averaged F1 score. 
diff --git a/utils/metrics/accuracy.py b/utils/metrics/accuracy.py
index 4d1cdd1..22a1283 100644
--- a/utils/metrics/accuracy.py
+++ b/utils/metrics/accuracy.py
@@ -3,10 +3,11 @@
 
 class Accuracy(nn.Module):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, macro_averaging=False):
         super().__init__()
         self.num_classes = num_classes
-
+        self.macro_averaging = macro_averaging
+
     def forward(self, y_true, y_pred):
         """
         Compute the accuracy of the model.
@@ -23,12 +24,71 @@ def forward(self, y_true, y_pred):
         float
             Accuracy score.
         """
+        if y_pred.dim() > 1:
+            y_pred = y_pred.argmax(dim=1)
+        if self.macro_averaging:
+            return self._macro_acc(y_true, y_pred)
+        else:
+            return self._micro_acc(y_true, y_pred)
+
+    def _macro_acc(self, y_true, y_pred):
+        """
+        Compute the macro-average accuracy.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Macro-average accuracy score.
+        """
+        y_true, y_pred = y_true.flatten(), y_pred.flatten()  # Ensure 1D shape
+
+        classes = torch.unique(y_true)  # Find unique class labels
+        acc_per_class = []
+
+        for c in classes:
+            mask = (y_true == c)  # Mask for class c
+            acc = (y_pred[mask] == y_true[mask]).float().mean()  # Accuracy for class c
+            acc_per_class.append(acc)
+
+        macro_acc = torch.stack(acc_per_class).mean().item()  # Average across classes
+        return macro_acc
+
+    def _micro_acc(self, y_true, y_pred):
+        """
+        Compute the micro-average accuracy.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Micro-average accuracy score.
+        """
         return (y_true == y_pred).float().mean().item()
 
 
 if __name__ == "__main__":
+    accuracy = Accuracy(5)
+    macro_accuracy = Accuracy(5, macro_averaging=True)
+
     y_true = torch.tensor([0, 3, 2, 3, 4])
     y_pred = torch.tensor([0, 1, 2, 3, 4])
-
-    accuracy = Accuracy()
     print(accuracy(y_true, y_pred))
+    print(macro_accuracy(y_true, y_pred))
+
+    y_true = torch.tensor([0, 3, 2, 3, 4])
+    y_onehot_pred = torch.tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]])
+    print(accuracy(y_true, y_onehot_pred))
+    print(macro_accuracy(y_true, y_onehot_pred))
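Macro accuracy here is the unweighted mean of per-class recalls (balanced accuracy), so the two modes already disagree on the __main__ example. A tiny check of the arithmetic:

import torch

from utils.metrics.accuracy import Accuracy

y_true = torch.tensor([0, 3, 2, 3, 4])
y_pred = torch.tensor([0, 1, 2, 3, 4])

# Micro: 4 of 5 predictions correct -> 0.8
print(Accuracy(5)(y_true, y_pred))                         # 0.8
# Macro: classes 0, 2, 4 score 1.0, class 3 scores 0.5 -> mean 0.875
print(Accuracy(5, macro_averaging=True)(y_true, y_pred))   # 0.875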
diff --git a/utils/models/magnus_model.py b/utils/models/magnus_model.py
index c80fae0..48386ce 100644
--- a/utils/models/magnus_model.py
+++ b/utils/models/magnus_model.py
@@ -22,18 +22,18 @@ def __init__(self, image_shape: int, num_classes: int, imagechannels: int):
         self.image_shape = image_shape
         self.imagechannels = imagechannels
 
-        self.layer1 = nn.Sequential(*([
-            nn.Linear(self.imagechannels * self.imagesize * self.imagesize, 133),
-            nn.ReLU(),
-        ]))
-        self.layer2 = nn.Sequential(*([
-            nn.Linear(133, 133),
-            nn.ReLU()
-        ]))
-        self.layer3 = nn.Sequential(*([
-            nn.Linear(133, num_classes),
-            nn.ReLU()
-        ]))
+        self.layer1 = nn.Sequential(
+            *(
+                [
+                    nn.Linear(
+                        self.imagechannels * self.imagesize * self.imagesize, 133
+                    ),
+                    nn.ReLU(),
+                ]
+            )
+        )
+        self.layer2 = nn.Sequential(*([nn.Linear(133, 133), nn.ReLU()]))
+        self.layer3 = nn.Sequential(*([nn.Linear(133, num_classes), nn.ReLU()]))
 
     def forward(self, x):
         """
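The reformatted layer3 still ends in nn.ReLU, so the model can only emit non-negative logits into the nn.CrossEntropyLoss used in main.py. A common alternative, offered here as a suggestion rather than part of the diff, is a bare linear head:

import torch.nn as nn


def make_head(hidden: int, num_classes: int) -> nn.Module:
    # nn.CrossEntropyLoss applies log-softmax itself, so the final
    # layer needs no activation; raw logits train more stably.
    return nn.Linear(hidden, num_classes)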