Merged
38 changes: 7 additions & 31 deletions main.py
@@ -1,13 +1,11 @@
from pathlib import Path

import numpy as np
import torch as th
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import wandb
from utils import MetricWrapper, createfolders, get_args, load_data, load_model


@@ -32,42 +30,20 @@ def main():
device = args.device

if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
augmentations = transforms.Compose(
transform = transforms.Compose(
[
transforms.Resize((16, 16)),
transforms.ToTensor(),
]
)
else:
augmentations = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([transforms.ToTensor()])

# Dataset
assert (
args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0
), "Validation split should be in interval (0,1)"
traindata = load_data(
args.dataset,
split="train",
split_percentage=args.validation_split_percentage,
data_path=args.datafolder,
download=args.download_data,
transform=augmentations,
)
validata = load_data(
args.dataset,
split="validation",
split_percentage=args.validation_split_percentage,
data_path=args.datafolder,
download=args.download_data,
transform=augmentations,
)
testdata = load_data(
traindata, validata, testdata = load_data(
args.dataset,
split="test",
split_percentage=args.validation_split_percentage,
data_path=args.datafolder,
download=args.download_data,
transform=augmentations,
data_dir=args.datafolder,
transform=transform,
val_size=args.val_size,
)

metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
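For context, the three per-split load_data calls collapse into a single call that returns all splits, with the validation set carved out of the training data via val_size. A minimal sketch of that index partitioning, assuming a hypothetical split_indices helper (the real load_data implementation is not part of this diff):

import numpy as np

def split_indices(n_samples: int, val_size: float, seed: int = 0):
    # Partition [0, n_samples) into disjoint train/validation index arrays.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_val = int(n_samples * val_size)
    return perm[n_val:], perm[:n_val]  # (train_ids, val_ids)

train_ids, val_ids = split_indices(100, val_size=0.2)
assert len(val_ids) == 20 and not set(train_ids) & set(val_ids)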
15 changes: 11 additions & 4 deletions tests/test_dataloaders.py
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():

# Create an h5 file
with h5py.File(tf, "w") as f:
targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
indices = np.arange(len(targets))
# Populate the file with data
f["train/data"] = np.random.rand(10, 16 * 16)
f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
f["train/target"] = targets

trans = transforms.Compose(
[
transforms.Resize((16, 16)), # At least for USPS
transforms.Resize((16, 16)),
transforms.ToTensor(),
]
)
dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
dataset = USPSDataset0_6(
data_path=tempdir,
sample_ids=indices,
train=True,
transform=trans,
)
assert len(dataset) == 10
data, target = dataset[0]
assert data.shape == (1, 16, 16)
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
assert target == 6
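The dataset now yields integer class labels instead of one-hot vectors, matching the assertion above. Where downstream code still expects one-hot targets, a minimal conversion sketch (the conversion itself is not part of this diff; num_classes=7 covers digits 0-6):

import torch
import torch.nn.functional as F

target = torch.tensor(6)
one_hot = F.one_hot(target, num_classes=7)  # tensor([0, 0, 0, 0, 0, 0, 1])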
1 change: 0 additions & 1 deletion tests/test_models.py
@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"

8 changes: 1 addition & 7 deletions utils/arg_parser.py
@@ -33,12 +33,6 @@ def get_args():
help="Whether model should be saved or not.",
)

parser.add_argument(
"--download-data",
action="store_true",
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
)

# Data/Model specific values
parser.add_argument(
"--modelname",
@@ -55,7 +49,7 @@
help="Which dataset to train the model on.",
)
parser.add_argument(
"--validation_split_percentage",
"--val_size",
type=float,
default=0.2,
help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",
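A self-contained sketch of how the renamed flag parses (standalone parser for illustration only; the project's real entry point is get_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--val_size", type=float, default=0.2)
args = parser.parse_args(["--val_size", "0.1"])
assert 0 < args.val_size < 1  # the help text requires a value within (0,1)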
8 changes: 7 additions & 1 deletion utils/dataloaders/__init__.py
@@ -1,5 +1,11 @@
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
__all__ = [
"USPSDataset0_6",
"USPSH5_Digit_7_9_Dataset",
"MNISTDataset0_3",
"Downloader",
]

from .download import Downloader
from .mnist_0_3 import MNISTDataset0_3
from .usps_0_6 import USPSDataset0_6
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
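With Downloader added to __all__, the package's public import surface becomes (usage sketch):

from utils.dataloaders import (
    Downloader,
    MNISTDataset0_3,
    USPSDataset0_6,
    USPSH5_Digit_7_9_Dataset,
)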
23 changes: 23 additions & 0 deletions utils/dataloaders/datasources.py
@@ -17,3 +17,26 @@
"8ea070ee2aca1ac39742fdd1ef5ed118",
],
}

MNIST_SOURCE = {
"train_images": [
"https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
"train-images-idx3-ubyte",
None,
],
"train_labels": [
"https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
"train-labels-idx1-ubyte",
None,
],
"test_images": [
"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
"t10k-images-idx3-ubyte",
None,
],
"test_labels": [
"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
"t10k-labels-idx1-ubyte",
None,
],
}
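Each MNIST_SOURCE entry mirrors the USPS_SOURCE layout: (url, decompressed filename, md5 checksum or None); a None checksum appears to mean no integrity check is performed for that file. Usage sketch:

from utils.dataloaders.datasources import MNIST_SOURCE

url, fname, md5 = MNIST_SOURCE["train_images"]
# url   -> download location of the .gz archive
# fname -> name of the decompressed file on disk
# md5   -> expected checksum, or None to skip verification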
183 changes: 183 additions & 0 deletions utils/dataloaders/download.py
@@ -0,0 +1,183 @@
import bz2
import gzip
import hashlib
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from urllib.request import urlretrieve

import h5py as h5
import numpy as np

from .datasources import MNIST_SOURCE, USPS_SOURCE


class Downloader:
"""
Class to download and load supported datasets (MNIST, SVHN, USPS).

Methods
-------
mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Download the MNIST dataset and save it as an HDF5 file to `data_dir`.
svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Download the SVHN dataset and save it as an HDF5 file to `data_dir`.
usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Download the USPS dataset and save it as an HDF5 file to `data_dir`.

Raises
------
NotImplementedError
If the download method is not implemented for the dataset.

Examples
--------
>>> from pathlib import Path
>>> from utils import Downloader
>>> dir = Path('tmp')
>>> dir.mkdir(exist_ok=True)
>>> train, test = Downloader().usps(dir)
"""

def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
def _check_is_downloaded(path: Path) -> bool:
path = path / "MNIST"
if path.exists():
required_files = [MNIST_SOURCE[key][1] for key in MNIST_SOURCE.keys()]
if all([(path / file).exists() for file in required_files]):
print("MNIST Dataset already downloaded.")
return True
else:
return False
else:
path.mkdir(parents=True, exist_ok=True)
return False

def _download_data(path: Path) -> None:
urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}

for name, url in urls.items():
file_path = os.path.join(path, url.split("/")[-1])
if not os.path.exists(
file_path.replace(".gz", "")
): # Avoid re-downloading
urlretrieve(url, file_path)
with gzip.open(file_path, "rb") as f_in:
with open(file_path.replace(".gz", ""), "wb") as f_out:
f_out.write(f_in.read())
os.remove(file_path) # Remove compressed file

def _get_labels(path: Path) -> np.ndarray:
with open(path, "rb") as f:
data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
return data

if not _check_is_downloaded(data_dir):
_download_data(data_dir / "MNIST")  # download into the MNIST/ directory the label paths below expect

train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]

train_labels = _get_labels(train_labels_path)
test_labels = _get_labels(test_labels_path)

return train_labels, test_labels

def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
raise NotImplementedError("SVHN download not implemented yet")

def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
"""
Download the USPS dataset and save it as an HDF5 file to `data_dir/usps.h5`.
"""

def already_downloaded(path):
if not path.exists() or not path.is_file():
return False

with h5.File(path, "r") as f:
return "train" in f and "test" in f

filename = data_dir / "usps.h5"

if already_downloaded(filename):
with h5.File(filename, "r") as f:
return f["train"]["target"][:], f["test"]["target"][:]

url_train, _, train_md5 = USPS_SOURCE["train"]
url_test, _, test_md5 = USPS_SOURCE["test"]

# Using temporary directory ensures temporary files are deleted after use
with TemporaryDirectory() as tmp_dir:
train_path = Path(tmp_dir) / "train"
test_path = Path(tmp_dir) / "test"

# Download the dataset and report the progress
urlretrieve(url_train, train_path, reporthook=self.__reporthook)
self.__check_integrity(train_path, train_md5)
train_targets = self.__extract_usps(train_path, filename, "train")

urlretrieve(url_test, test_path, reporthook=self.__reporthook)
self.__check_integrity(test_path, test_md5)
test_targets = self.__extract_usps(test_path, filename, "test")

return train_targets, test_targets

def __extract_usps(self, src: Path, dest: Path, mode: str):
# Load the dataset and save it as an HDF5 file
with bz2.open(src) as fp:
raw = [line.decode().split() for line in fp.readlines()]

tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]

imgs = np.asarray(tmp, dtype=np.float32)
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)

targets = [int(d[0]) - 1 for d in raw]

with h5.File(dest, "a") as f:
f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)

return np.asarray(targets, dtype=np.int32)

@staticmethod
def __reporthook(blocknum, blocksize, totalsize):
"""
Use this function to report download progress
for the urllib.request.urlretrieve function.
"""

denom = 1024 * 1024
readsofar = blocknum * blocksize

if totalsize > 0:
percent = readsofar * 1e2 / totalsize
s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
print(s, end="", flush=True)
if readsofar >= totalsize:
print()

@staticmethod
def __check_integrity(filepath, checksum):
"""Check the integrity of the USPS dataset file.

Args
----
filepath : pathlib.Path
Path to the USPS dataset file.
checksum : str
MD5 checksum of the dataset file.

Raises
------
ValueError
If the checksum of the file does not match the expected checksum.
"""

file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()

if checksum != file_hash:
raise ValueError(
f"File integrity check failed. Expected {checksum}, got {file_hash}"
)
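Mirroring the docstring's USPS example, a sketch of the MNIST path (assuming utils re-exports Downloader as that example does; MNIST ships 60,000 training labels):

from pathlib import Path
from utils import Downloader

data_dir = Path("tmp")
data_dir.mkdir(exist_ok=True)
train_labels, test_labels = Downloader().mnist(data_dir)
print(train_labels.shape)  # expected: (60000,)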