From e7b6454f0f9c542f1fb460dcdde97e7a8e80beae Mon Sep 17 00:00:00 2001 From: salomaestro Date: Sun, 23 Feb 2025 18:02:10 +0100 Subject: [PATCH 1/2] Fix bug where usps test weren't able to download usps data --- CollaborativeCoding/dataloaders/download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CollaborativeCoding/dataloaders/download.py b/CollaborativeCoding/dataloaders/download.py index 4baefe2..08e7b27 100644 --- a/CollaborativeCoding/dataloaders/download.py +++ b/CollaborativeCoding/dataloaders/download.py @@ -130,7 +130,7 @@ def already_downloaded(path): url_test, _, test_md5 = USPS_SOURCE["test"] # Using temporary directory ensures temporary files are deleted after use - with TemporaryDirectory() as tmp_dir: + with TemporaryDirectory(dir=data_dir) as tmp_dir: train_path = Path(tmp_dir) / "train" test_path = Path(tmp_dir) / "test" From e3ba385809a6cf7a51bcff7cdaa1cf187ca96cf1 Mon Sep 17 00:00:00 2001 From: salomaestro Date: Sun, 23 Feb 2025 18:19:55 +0100 Subject: [PATCH 2/2] Ensure "Data" exists before attempting to download datasets --- tests/test_dataloaders.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index b911524..a8dbce2 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -25,10 +25,13 @@ ], ) def test_load_data(data_name, expected): + data_dir = Path("Data") + data_dir.mkdir(exist_ok=True) + dataset, _, _ = load_data( data_name, train=False, - data_dir=Path("Data"), + data_dir=data_dir, transform=transforms.ToTensor(), )