diff --git a/CollaborativeCoding/dataloaders/uspsh5_7_9.py b/CollaborativeCoding/dataloaders/uspsh5_7_9.py index 3a933db..5dce83e 100644 --- a/CollaborativeCoding/dataloaders/uspsh5_7_9.py +++ b/CollaborativeCoding/dataloaders/uspsh5_7_9.py @@ -65,6 +65,9 @@ def __init__( mask = np.isin(labels, [7, 8, 9]) self.images = images[mask] self.labels = labels[mask] + # map labels from (7,9) to (0,2) for CE loss + self.label_shift = lambda x: x - 7 + self.label_restore = lambda x: x + 7 def __len__(self): """ @@ -95,7 +98,7 @@ def __getitem__(self, id): # Convert to PIL Image (USPS images are typically grayscale 16x16) image = Image.fromarray(self.images[id].astype(np.uint8), mode="L") label = int(self.labels[id]) # Convert label to integer - + label = self.label_shift(label) if self.transform: image = self.transform(image)