Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 52 additions & 16 deletions CollaborativeCoding/dataloaders/uspsh5_7_9.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,45 @@

class USPSH5_Digit_7_9_Dataset(Dataset):
"""
Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file.
It allows for applying transformations to the images and provides methods to retrieve images and their corresponding labels.

Parameters
----------
h5_path : str
Path to the USPS `.h5` file.
data_path : str or Path
Path to the directory containing the USPS `.h5` file. This file should contain the data in the "train" or "test" group.

sample_ids : list of int
A list of sample indices to be used from the dataset. This allows for filtering or selecting a subset of the full dataset.

train : bool, optional, default=False
If `True`, the dataset is loaded in training mode (using the "train" group). If `False`, the dataset is loaded in test mode (using the "test" group).

transform : callable, optional, default=None
A transform function to apply on images. If None, no transformation is applied.
A transformation function to apply to each image. If `None`, no transformation is applied. Typically used for data augmentation or normalization.

nr_channels : int, optional, default=1
The number of channels in the image. USPS images are typically grayscale, so this should generally be set to 1. This parameter allows for potential future flexibility.

Attributes
----------
images : numpy.ndarray
The filtered images corresponding to digits 7-9.
Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The images are loaded from the HDF5 file and filtered based on the labels.

labels : numpy.ndarray
The filtered labels corresponding to digits 7-9.
Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are retained, and they are mapped to 0, 1, and 2 for classification tasks.

transform : callable, optional
A transform function to apply to the images.
A transformation function to apply to the images. This is passed as an argument during initialization.

label_shift : function
A function to shift the labels for classification purposes. It maps the original labels (7, 8, 9) to (0, 1, 2).

label_restore : function
A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).

num_classes : int
The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).
"""

def __init__(
Expand All @@ -36,14 +55,25 @@ def __init__(
super().__init__()
"""
Initializes the USPS dataset by loading images and labels from the given `.h5` file.


The dataset is filtered to only include images of digits 7, 8, and 9, which are mapped to labels 0, 1, and 2 respectively for classification purposes.

Parameters
----------
h5_path : str
Path to the USPS `.h5` file.

data_path : str or Path
Path to the directory containing the USPS `.h5` file.

sample_ids : list of int
List of sample indices to load from the dataset.

train : bool, optional, default=False
If `True`, loads the training data from the HDF5 file. If `False`, loads the test data.

transform : callable, optional, default=None
A transform function to apply on images.
A function to apply transformations to the images. If None, no transformation is applied.

nr_channels : int, optional, default=1
The number of channels in the image. Defaults to 1 for grayscale images.
"""
self.filename = "usps.h5"
path = data_path if isinstance(data_path, Path) else Path(data_path)
Expand Down Expand Up @@ -72,27 +102,33 @@ def __len__(self):
"""
Returns the total number of samples in the dataset.

This method is required for PyTorch's Dataset class, as it allows PyTorch to determine the size of the dataset.

Returns
-------
int
The number of images in the dataset.
The number of images in the dataset (after filtering for digits 7, 8, and 9).
"""

return len(self.images)

def __getitem__(self, id):
"""
Returns a sample from the dataset given an index.

This method is required for PyTorch's Dataset class, as it allows indexing into the dataset to retrieve specific samples.

Parameters
----------
idx : int
The index of the sample to retrieve.
The index of the sample to retrieve from the dataset.

Returns
-------
tuple
- image (PIL Image): The image at the specified index.
- label (int): The label corresponding to the image.
A tuple containing:
- image (PIL Image): The image at the specified index.
- label (int): The label corresponding to the image, shifted to be in the range [0, 2] for classification.
"""
# Convert to PIL Image (USPS images are typically grayscale 16x16)
image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
Expand Down
76 changes: 63 additions & 13 deletions CollaborativeCoding/metrics/F1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,67 @@

class F1Score(nn.Module):
"""
F1 Score implementation with support for both macro and micro averaging.
This class computes the F1 score during training using either macro or micro averaging.
Computes the F1 score for classification tasks with support for both macro and micro averaging.

This class allows you to compute the F1 score during training or evaluation. You can select between two methods of averaging:
- **Micro Averaging**: Computes the F1 score globally, treating each individual prediction as equally important.
- **Macro Averaging**: Computes the F1 score for each class individually and then averages the scores.

Parameters
----------
num_classes : int
The number of classes in the classification task.

macro_averaging : bool, default=False
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
macro_averaging : bool, optional, default=False
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score. Default is micro averaging.

Attributes
----------
num_classes : int
The number of classes in the classification task.

macro_averaging : bool
A flag to determine whether to compute the macro-averaged or micro-averaged F1 score.

y_true : list
A list to store true labels for the current batch.

y_pred : list
A list to store predicted labels for the current batch.

Methods
-------
forward(target, preds)
Stores predictions and true labels for computing the F1 score during training or evaluation.

compute_f1()
Computes and returns the F1 score based on the stored predictions and true labels.

_micro_F1(target, preds)
Computes the micro-averaged F1 score based on the global true positive, false positive, and false negative counts.

_macro_F1(target, preds)
Computes the macro-averaged F1 score by calculating the F1 score per class and then averaging across all classes.

__returnmetric__()
Computes and returns the F1 score (Micro or Macro) as specified.

__reset__()
Resets the stored predictions and true labels, preparing for the next batch or epoch.
"""

def __init__(self, num_classes, macro_averaging=False):
"""
Initializes the F1Score object with the number of classes and averaging mode.

Parameters
----------
num_classes : int
The number of classes in the classification task.

macro_averaging : bool, optional, default=False
If True, compute the macro-averaged F1 score. If False, compute the micro-averaged F1 score.
"""
super().__init__()
self.num_classes = num_classes
self.macro_averaging = macro_averaging
Expand All @@ -25,14 +74,15 @@ def __init__(self, num_classes, macro_averaging=False):

def forward(self, target, preds):
"""
Stores predictions and targets for computing the F1 score.
Stores the true labels and predictions to compute the F1 score.

Parameters
----------
preds : torch.Tensor
Predicted logits (shape: [batch_size, num_classes]).
target : torch.Tensor
True labels (shape: [batch_size]).

preds : torch.Tensor
Predicted logits (shape: [batch_size, num_classes]).
"""
preds = torch.argmax(preds, dim=-1) # Convert logits to class indices
self.y_true.append(target.detach())
Expand All @@ -47,7 +97,7 @@ def compute_f1(self):
Returns
-------
torch.Tensor
The computed F1 score.
The computed F1 score. Returns NaN if no predictions or targets are available.
"""
if not self.y_true or not self.y_pred: # Check if empty
return torch.tensor(np.nan)
Expand All @@ -63,7 +113,7 @@ def compute_f1(self):
)

def _micro_F1(self, target, preds):
"""Computes Micro F1 Score (global TP, FP, FN)."""
"""Computes the Micro-averaged F1 score (global TP, FP, FN)."""
tp = torch.sum(preds == target).float()
fp = torch.sum(preds != target).float()
fn = fp # Since all errors are either FP or FN
Expand All @@ -75,7 +125,7 @@ def _micro_F1(self, target, preds):
return f1

def _macro_F1(self, target, preds):
"""Computes Macro F1 Score in a vectorized way (no loops)."""
"""Computes the Macro-averaged F1 score."""
num_classes = self.num_classes
target = target.long() # Ensure target is a LongTensor
preds = preds.long()
Expand All @@ -100,12 +150,12 @@ def _macro_F1(self, target, preds):

def __returnmetric__(self):
"""
Computes and returns the F1 score (Micro or Macro).
Computes and returns the F1 score (Micro or Macro) based on the stored predictions and targets.

Returns
-------
torch.Tensor
The computed F1 score.
The computed F1 score. Returns NaN if no predictions or targets are available.
"""
if not self.y_true or not self.y_pred: # Check if empty
return torch.tensor(np.nan)
Expand All @@ -121,6 +171,6 @@ def __returnmetric__(self):
)

def __reset__(self):
"""Resets stored predictions and targets."""
"""Resets the stored predictions and targets for the next batch or epoch."""
self.y_true = []
self.y_pred = []
Loading