3 changes: 3 additions & 0 deletions .gitignore
@@ -15,6 +15,9 @@ env2/*
ruffian.sh
localtest.sh

+# Johanthings
+formatting.x
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
12 changes: 11 additions & 1 deletion CollaborativeCoding/dataloaders/mnist_4_9.py
@@ -20,7 +20,14 @@ class MNISTDataset4_9(Dataset):
Whether to train the model or not, by default False
"""

-    def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False):
+    def __init__(
+        self,
+        data_path: Path,
+        sample_ids: np.ndarray,
+        train: bool = False,
+        transform=None,
+        nr_channels: int = 1,
+    ):
        super().__init__()
self.data_path = data_path
self.mnist_path = self.data_path / "MNIST"
@@ -52,4 +59,7 @@ def __getitem__(self, idx):

image = np.expand_dims(image, axis=0) # Channel

+        if self.transform:
+            image = self.transform(image)
+
return image, label
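As a usage note, a minimal sketch of how the extended constructor might be called, assuming MNIST has already been downloaded under data/MNIST; the sample indices and the torch.from_numpy transform are illustrative assumptions, not part of this diff.

```python
from pathlib import Path

import numpy as np
import torch

from CollaborativeCoding.dataloaders import MNISTDataset4_9

# Hypothetical setup: assumes data/MNIST already exists on disk
dataset = MNISTDataset4_9(
    data_path=Path("data"),
    sample_ids=np.arange(100),     # indices into the 4-9 subset
    train=True,
    transform=torch.from_numpy,    # any callable; applied inside __getitem__
    nr_channels=1,
)
image, label = dataset[0]          # image carries a leading channel axis
```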
5 changes: 4 additions & 1 deletion CollaborativeCoding/load_data.py
@@ -4,6 +4,7 @@
from .dataloaders import (
Downloader,
MNISTDataset0_3,
+    MNISTDataset4_9,
SVHNDataset,
USPSDataset0_6,
USPSH5_Digit_7_9_Dataset,
@@ -65,7 +66,9 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
labels = np.arange(10)
case "mnist_4-9":
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
+            dataset = MNISTDataset4_9
+            train_labels, test_labels = downloader.mnist(data_dir=data_dir)
+            labels = np.arange(4, 10)
case _:
raise NotImplementedError(f"Dataset: {dataset} not implemented.")

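Worth noting for the new case arm: np.arange(4, 10) enumerates the six digit classes this loader covers, since the upper bound is exclusive. A quick check:

```python
import numpy as np

labels = np.arange(4, 10)
print(labels)  # [4 5 6 7 8 9] -- digits 4 through 9; 10 is excluded
```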
10 changes: 5 additions & 5 deletions CollaborativeCoding/load_metric.py
@@ -25,9 +25,9 @@ class MetricWrapper(nn.Module):
-------
__call__(y_true, y_pred)
Computes the specified metrics on the provided true and predicted labels.
-    __getmetrics__(str_prefix: str = None)
+    getmetrics(str_prefix: str = None)
Retrieves the computed metrics, optionally prefixed with a string.
-    reset()
+    resetmetric()
Resets the state of all metric computations.
Examples
--------
@@ -36,10 +36,10 @@ class MetricWrapper(nn.Module):
>>> y_true = [0, 1, 0, 1]
>>> y_pred = [0, 1, 1, 0]
>>> metrics(y_true, y_pred)
-    >>> metrics.__getmetrics__()
+    >>> metrics.getmetrics()
{'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
-    >>> metrics.reset()
-    >>> metrics.__getmetrics__()
+    >>> metrics.resetmetric()
+    >>> metrics.getmetrics()
{'entropy': [], 'f1': [], 'precision': []}
"""

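A sketch of the renamed API in use. The constructor arguments follow the docstring's own example, and the str_prefix behaviour of getmetrics is an assumption based on its one-line description, not verified against the full file:

```python
from CollaborativeCoding.load_metric import MetricWrapper

metrics = MetricWrapper("entropy", "f1", "precision")  # assumed signature, per the docstring
metrics([0, 1, 0, 1], [0, 1, 1, 0])
print(metrics.getmetrics(str_prefix="val_"))  # presumably {'val_entropy': ..., 'val_f1': ..., ...}
metrics.resetmetric()  # clear accumulated state, e.g. between epochs
```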
38 changes: 32 additions & 6 deletions CollaborativeCoding/metrics/precision.py
@@ -1,3 +1,4 @@
+import numpy as np
import torch
import torch.nn as nn

@@ -18,6 +19,8 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):

self.num_classes = num_classes
self.macro_averaging = macro_averaging
+        self.y_true = []
+        self.y_pred = []

def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
"""Compute precision of model
@@ -35,11 +38,10 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
Precision score
"""
y_pred = logits.argmax(dim=-1)
-        return (
-            self._macro_avg_precision(y_true, y_pred)
-            if self.macro_averaging
-            else self._micro_avg_precision(y_true, y_pred)
-        )

+        # Append to the class-global values
+        self.y_true.append(y_true)
+        self.y_pred.append(y_pred)

def _micro_avg_precision(
self, y_true: torch.tensor, y_pred: torch.tensor
@@ -58,7 +60,6 @@ def _micro_avg_precision(
torch.tensor
Micro-averaged precision
"""
-        print(y_true.shape)
true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
1, y_true.unsqueeze(1), 1
)
@@ -98,6 +99,31 @@ def _macro_avg_precision(

return torch.nanmean(tp / (tp + fp))

+    def __returnmetric__(self):
+        if self.y_true == [] and self.y_pred == []:
+            return np.nan
+        elif self.y_true == [] or self.y_pred == []:
+            raise ValueError("y_true or y_pred is empty.")
+        # Concatenate into locals so the running lists stay appendable
+        y_true = torch.cat(self.y_true)
+        y_pred = torch.cat(self.y_pred)
+
+        return (
+            self._macro_avg_precision(y_true, y_pred)
+            if self.macro_averaging
+            else self._micro_avg_precision(y_true, y_pred)
+        )

+    def __reset__(self):
+        """Reset the accumulated true and predicted values to empty lists.
+
+        Returns
+        -------
+        None
+        """
+        # Bind two distinct lists; a chained assignment would alias them
+        self.y_true = []
+        self.y_pred = []


if __name__ == "__main__":
print(
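Taken together, the new methods give the metric an accumulate/compute/reset lifecycle. A minimal sketch under assumptions: the import path is guessed from the repository layout, the batches are made up, and Precision is assumed to be an nn.Module so that calling it invokes forward:

```python
import torch

from CollaborativeCoding.metrics import Precision  # assumed import path

metric = Precision(num_classes=3, macro_averaging=True)

# forward() now only accumulates; it no longer returns a per-batch score
for y_true, logits in [
    (torch.tensor([0, 1, 2]), torch.randn(3, 3)),
    (torch.tensor([2, 0, 1]), torch.randn(3, 3)),
]:
    metric(y_true, logits)

score = metric.__returnmetric__()  # concatenates every batch, then computes precision
metric.__reset__()                 # back to empty lists for the next epoch
print(float(score))
```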
18 changes: 10 additions & 8 deletions tests/test_metrics.py
@@ -85,8 +85,6 @@ def test_f1score():


def test_precision():
-    from random import randint
-
import numpy as np
import torch
from sklearn.metrics import precision_score
@@ -100,9 +98,13 @@ def test_precision():
precision_micro = Precision(num_classes=C)
precision_macro = Precision(num_classes=C, macro_averaging=True)

-    # find scores
-    micro_precision_score = precision_micro(y_true, logits)
-    macro_precision_score = precision_macro(y_true, logits)
+    # run metric object
+    precision_micro(y_true, logits)
+    precision_macro(y_true, logits)
+
+    # get metric scores
+    micro_precision_score = precision_micro.__returnmetric__()
+    macro_precision_score = precision_macro.__returnmetric__()

# check output to be tensor
assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
@@ -113,12 +115,12 @@ def test_precision():
assert macro_precision_score.item() >= 0, "Expected non-negative value"

# find predictions
-    y_pred = logits.argmax(dim=-1, keepdims=True)
+    y_pred = logits.argmax(dim=-1)

# check dimension
-    assert y_true.shape == torch.Size([N, 1]) or torch.Size([N])
+    assert y_true.shape == torch.Size([N])
assert logits.shape == torch.Size([N, C])
-    assert y_pred.shape == torch.Size([N, 1]) or torch.Size([N])
+    assert y_pred.shape == torch.Size([N])

# find true values with scikit learn
scikit_macro_precision = precision_score(y_true, y_pred, average="macro")
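And a condensed, self-contained version of the scikit-learn comparison this test builds toward; the import path and zero_division choice are assumptions, and the two implementations can legitimately disagree when some class is never predicted (torch.nanmean skips such classes, while scikit-learn counts them as zero):

```python
import torch
from sklearn.metrics import precision_score

from CollaborativeCoding.metrics import Precision  # assumed import path

N, C = 64, 5
y_true = torch.randint(0, C, (N,))
logits = torch.randn(N, C)

metric = Precision(num_classes=C, macro_averaging=True)
metric(y_true, logits)
ours = metric.__returnmetric__()
theirs = precision_score(y_true, logits.argmax(dim=-1), average="macro", zero_division=0)
print(float(ours), theirs)  # expected to agree whenever every class is predicted at least once
```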