Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ wandb/*
wandb_api.py
doc/autoapi

*.DS_Store


#Magnus specific
job*
env2/*
Expand Down
1 change: 1 addition & 0 deletions CollaborativeCoding/dataloaders/mnist_0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
sample_ids: list,
train: bool = False,
transform=None,
nr_channels: int = 1,
):
super().__init__()

Expand Down
57 changes: 26 additions & 31 deletions CollaborativeCoding/metrics/accuracy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch
from torch import nn

Expand All @@ -7,6 +8,8 @@ def __init__(self, num_classes, macro_averaging=False):
super().__init__()
self.num_classes = num_classes
self.macro_averaging = macro_averaging
self.y_true = []
self.y_pred = []

def forward(self, y_true, y_pred):
"""
Expand All @@ -26,12 +29,10 @@ def forward(self, y_true, y_pred):
"""
if y_pred.dim() > 1:
y_pred = y_pred.argmax(dim=1)
if self.macro_averaging:
return self._macro_acc(y_true, y_pred)
else:
return self._micro_acc(y_true, y_pred)
self.y_true.append(y_true)
self.y_pred.append(y_pred)

def _macro_acc(self, y_true, y_pred):
def _macro_acc(self):
"""
Compute the macro-average accuracy.

Expand All @@ -47,7 +48,7 @@ def _macro_acc(self, y_true, y_pred):
float
Macro-average accuracy score.
"""
y_true, y_pred = y_true.flatten(), y_pred.flatten() # Ensure 1D shape
y_true, y_pred = self.y_true.flatten(), self.y_pred.flatten() # Ensure 1D shape

classes = torch.unique(y_true) # Find unique class labels
acc_per_class = []
Expand All @@ -60,7 +61,7 @@ def _macro_acc(self, y_true, y_pred):
macro_acc = torch.stack(acc_per_class).mean().item() # Average across classes
return macro_acc

def _micro_acc(self, y_true, y_pred):
def _micro_acc(self):
"""
Compute the micro-average accuracy.

Expand All @@ -76,27 +77,21 @@ def _micro_acc(self, y_true, y_pred):
float
Micro-average accuracy score.
"""
return (y_true == y_pred).float().mean().item()


if __name__ == "__main__":
accuracy = Accuracy(5)
macro_accuracy = Accuracy(5, macro_averaging=True)

y_true = torch.tensor([0, 3, 2, 3, 4])
y_pred = torch.tensor([0, 1, 2, 3, 4])
print(accuracy(y_true, y_pred))
print(macro_accuracy(y_true, y_pred))

y_true = torch.tensor([0, 3, 2, 3, 4])
y_onehot_pred = torch.tensor(
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
]
)
print(accuracy(y_true, y_onehot_pred))
print(macro_accuracy(y_true, y_onehot_pred))
return (self.y_true == self.y_pred).float().mean().item()

def __returnmetric__(self):
if self.y_true == [] or self.y_pred == []:
return np.nan
if isinstance(self.y_true, list):
if len(self.y_true) == 1:
self.y_true = self.y_true[0]
self.y_pred = self.y_pred[0]
else:
self.y_true = torch.cat(self.y_true)
self.y_pred = torch.cat(self.y_pred)
return self._micro_acc() if not self.macro_averaging else self._macro_acc()

def __reset__(self):
self.y_true = []
self.y_pred = []
return None
7 changes: 5 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
import torch as th
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import wandb
from CollaborativeCoding import (
MetricWrapper,
createfolders,
Expand All @@ -14,6 +14,9 @@
load_model,
)

# from wandb_api import WANDB_API



def main():
"""
Expand Down Expand Up @@ -126,7 +129,7 @@ def main():
print("Dry run completed successfully.")
exit()

# wandb.login(key=WANDB_API)
# wandb.login(key=WANDB_API)
wandb.init(
entity="ColabCode",
project=args.run_name,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"torch>=2.6.0",
"torchvision>=0.21.0",
"tqdm>=4.67.1",
"wandb>=0.19.6",
]
[tool.isort]
profile = "black"
Expand Down
3 changes: 1 addition & 2 deletions tests/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
def test_load_data(data_name, expected):
dataset = load_data(
data_name,
data_path=Path("data"),
download=True,
data_dir=Path("data"),
transform=transforms.ToTensor(),
)
assert isinstance(dataset, expected)
Expand Down
36 changes: 27 additions & 9 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,36 @@ def test_precision():


def test_accuracy():
import numpy as np
import torch

accuracy = Accuracy(num_classes=5)

y_true = torch.tensor([0, 3, 2, 3, 4])
y_pred = torch.tensor([0, 1, 2, 3, 4])

accuracy_score = accuracy(y_true, y_pred)

assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
f"Accuracy Score: {accuracy_score.item()}"
# Test the accuracy metric
y_true = torch.tensor([0, 1, 2, 3, 4, 5])
y_pred = torch.tensor([0, 1, 2, 3, 4, 5])
accuracy = Accuracy(num_classes=6, macro_averaging=False)
accuracy(y_true, y_pred)
assert accuracy.__returnmetric__() == 1.0, "Expected accuracy to be 1.0"
accuracy.__reset__()
assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0"
y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
accuracy(y_true, y_pred)
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
"Expected accuracy to be 0.8333333134651184"
)
accuracy.__reset__()
accuracy.macro_averaging = True
accuracy(y_true, y_pred)
y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
accuracy(y_true_1, y_pred_1)
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
"Expected accuracy to be 0.8333333134651186"
)
accuracy.macro_averaging = False
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
"Expected accuracy to be 0.8333333134651184"
)
accuracy.__reset__()


def test_entropypred():
Expand Down
Loading
Loading