From 18083b3f74fb6445fbae14fd394464e57971a4f2 Mon Sep 17 00:00:00 2001 From: Solveig Date: Sun, 23 Feb 2025 14:29:11 +0100 Subject: [PATCH] fixed remapping of the labels for usps 7-9 --- CollaborativeCoding/dataloaders/uspsh5_7_9.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CollaborativeCoding/dataloaders/uspsh5_7_9.py b/CollaborativeCoding/dataloaders/uspsh5_7_9.py index 3a933db..5dce83e 100644 --- a/CollaborativeCoding/dataloaders/uspsh5_7_9.py +++ b/CollaborativeCoding/dataloaders/uspsh5_7_9.py @@ -65,6 +65,9 @@ def __init__( mask = np.isin(labels, [7, 8, 9]) self.images = images[mask] self.labels = labels[mask] + # map labels from (7,9) to (0,2) for CE loss + self.label_shift = lambda x: x - 7 + self.label_restore = lambda x: x + 7 def __len__(self): """ @@ -95,7 +98,7 @@ def __getitem__(self, id): # Convert to PIL Image (USPS images are typically grayscale 16x16) image = Image.fromarray(self.images[id].astype(np.uint8), mode="L") label = int(self.labels[id]) # Convert label to integer - + label = self.label_shift(label) if self.transform: image = self.transform(image)