diff --git a/CollaborativeCoding/dataloaders/uspsh5_7_9.py b/CollaborativeCoding/dataloaders/uspsh5_7_9.py
index 35167a4..6808ad3 100644
--- a/CollaborativeCoding/dataloaders/uspsh5_7_9.py
+++ b/CollaborativeCoding/dataloaders/uspsh5_7_9.py
@@ -32,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
         A transform function to apply to the images.
     """
 
-    def __init__(self, data_path, train=False, transform=None):
+    def __init__(self, data_path, sample_ids, train=False, transform=None, nr_channels=1):
         super().__init__()
         """
         Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -51,6 +51,8 @@ def __init__(self, data_path, train=False, transform=None):
         self.transform = transform
         self.mode = "train" if train else "test"
         self.h5_path = data_path / self.filename
+        self.sample_ids = sample_ids
+        self.nr_channels = nr_channels
 
         # Load the dataset from the HDF5 file
         with h5py.File(self.filepath, "r") as hf:
@@ -107,10 +109,10 @@ def main():
             transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
         ]
     )
-
+    indices = np.array([7, 8, 9])
     # Load the dataset
     dataset = USPSH5_Digit_7_9_Dataset(
-        data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
+        data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git", sample_ids=indices,
         train=False,
         transform=transform,
     )
diff --git a/CollaborativeCoding/models/solveig_model.py b/CollaborativeCoding/models/solveig_model.py
index 21dec4e..442ab0e 100644
--- a/CollaborativeCoding/models/solveig_model.py
+++ b/CollaborativeCoding/models/solveig_model.py
@@ -2,6 +2,37 @@
 import torch.nn as nn
 
 
+def find_fc_input_shape(image_shape, model):
+    """
+    Find the shape of the input to the fully connected layer after passing through the convolutional layers.
+
+    Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
+
+    Args
+    ----
+    image_shape : tuple(int, int, int)
+        Shape of the input image (C, H, W), where C is the number of channels,
+        H is the height, and W is the width of the image.
+    model : nn.Module
+        The CNN model containing the convolutional layers, whose output size is used to
+        determine the number of input features for the fully connected layer.
+
+    Returns
+    -------
+    int
+        The number of elements in the input to the fully connected layer.
+    """
+
+    dummy_img = torch.randn(1, *image_shape)
+    with torch.no_grad():
+        x = model.conv_block1(dummy_img)
+        x = model.conv_block2(x)
+        x = model.conv_block3(x)
+        x = torch.flatten(x, 1)
+
+    return x.size(1)
+
+
 class SolveigModel(nn.Module):
     """
     A Convolutional Neural Network model for classification.
@@ -49,9 +80,19 @@ def __init__(self, image_shape, num_classes):
             nn.ReLU(),
         )
 
-        self.fc1 = nn.Linear(100 * 8 * 8, num_classes)
+        fc_input_size = find_fc_input_shape(image_shape, self)
+
+        self.fc1 = nn.Linear(fc_input_size, num_classes)
 
     def forward(self, x):
+        """
+        Defines the forward pass.
+        Args:
+            x (torch.Tensor): A four-dimensional tensor with shape
+                (Batch Size, Channels, Image Height, Image Width).
+        Returns:
+            torch.Tensor: The output tensor containing class logits for each input sample.
+        """
         x = self.conv_block1(x)
         x = self.conv_block2(x)
         x = self.conv_block3(x)
@@ -63,7 +104,7 @@
 
 
 if __name__ == "__main__":
-    x = torch.randn(1, 3, 16, 16)
+    x = torch.randn(1, 3, 28, 28)
     model = SolveigModel(x.shape[1:], 3)
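For reference, the layer-sizing trick that `find_fc_input_shape` relies on (a dummy tensor pushed through the convolutional blocks under `torch.no_grad()`) can be reproduced in isolation. The sketch below is a minimal, hypothetical example rather than code from this PR: the `TinyConvNet` class and its layer names are invented solely to show how a construction-time dummy forward pass sizes `nn.Linear` for any input resolution, which is what lets the `__main__` block switch from 16x16 to 28x28 inputs without editing the layer definition.

```python
import torch
import torch.nn as nn


class TinyConvNet(nn.Module):
    """Hypothetical two-block CNN, used only to illustrate the dummy-forward sizing trick."""

    def __init__(self, image_shape, num_classes):
        super().__init__()
        channels = image_shape[0]
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )

        # Same idea as find_fc_input_shape: run a dummy tensor through the
        # convolutional blocks and read off the flattened feature count.
        with torch.no_grad():
            dummy = torch.randn(1, *image_shape)
            n_features = self.conv_block2(self.conv_block1(dummy)).flatten(1).size(1)

        self.fc1 = nn.Linear(n_features, num_classes)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = torch.flatten(x, 1)
        return self.fc1(x)


if __name__ == "__main__":
    # Works for either resolution without hand-computing the flattened size.
    for side in (16, 28):
        model = TinyConvNet((3, side, side), num_classes=3)
        logits = model(torch.randn(2, 3, side, side))
        print(side, logits.shape)  # torch.Size([2, 3]) in both cases
```

The same construction-time sizing is what the PR applies inside `SolveigModel.__init__` via `find_fc_input_shape(image_shape, self)`, replacing the hard-coded `100 * 8 * 8` input width.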