diff --git a/CollaborativeCoding/dataloaders/mnist_0_3.py b/CollaborativeCoding/dataloaders/mnist_0_3.py index 401ea0d..7fefa28 100644 --- a/CollaborativeCoding/dataloaders/mnist_0_3.py +++ b/CollaborativeCoding/dataloaders/mnist_0_3.py @@ -93,5 +93,4 @@ def __getitem__(self, index): if self.transform: image = self.transform(image) - return image, label diff --git a/CollaborativeCoding/dataloaders/svhn.py b/CollaborativeCoding/dataloaders/svhn.py index 4d039ac..5db829b 100644 --- a/CollaborativeCoding/dataloaders/svhn.py +++ b/CollaborativeCoding/dataloaders/svhn.py @@ -38,6 +38,11 @@ def __init__( self.nr_channels = nr_channels self.transforms = transform + if not os.path.exists( + os.path.join(self.data_path, f"svhn_{self.split}data.h5") + ): + self._download_data(self.data_path) + assert os.path.exists( os.path.join(self.data_path, f"svhn_{self.split}data.h5") ), f"File svhn_{self.split}data.h5 does not exists. Run download=True" @@ -97,4 +102,4 @@ def __getitem__(self, index): if self.transforms is not None: img = self.transforms(img) - return img, lab + return img, int(lab) diff --git a/CollaborativeCoding/load_data.py b/CollaborativeCoding/load_data.py index 6cafc3c..80fe43c 100644 --- a/CollaborativeCoding/load_data.py +++ b/CollaborativeCoding/load_data.py @@ -79,8 +79,8 @@ def load_data(dataset: str, *args, **kwargs) -> tuple: test_indices = np.arange(len(test_labels)) # Filter the labels to only get indices of the wanted labels - train_samples = train_indices[np.isin(train_labels, labels)] - test_samples = test_indices[np.isin(test_labels, labels)] + train_samples = train_indices[np.isin(train_labels, labels).flatten()] + test_samples = test_indices[np.isin(test_labels, labels).flatten()] train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])