diff --git a/CollaborativeCoding/dataloaders/download.py b/CollaborativeCoding/dataloaders/download.py index 4baefe2..08e7b27 100644 --- a/CollaborativeCoding/dataloaders/download.py +++ b/CollaborativeCoding/dataloaders/download.py @@ -130,7 +130,7 @@ def already_downloaded(path): url_test, _, test_md5 = USPS_SOURCE["test"] # Using temporary directory ensures temporary files are deleted after use - with TemporaryDirectory() as tmp_dir: + with TemporaryDirectory(dir=data_dir) as tmp_dir: train_path = Path(tmp_dir) / "train" test_path = Path(tmp_dir) / "test" diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index b911524..a8dbce2 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -25,10 +25,13 @@ ], ) def test_load_data(data_name, expected): + data_dir = Path("Data") + data_dir.mkdir(exist_ok=True) + dataset, _, _ = load_data( data_name, train=False, - data_dir=Path("Data"), + data_dir=data_dir, transform=transforms.ToTensor(), )