From 8b22c30a22f29caffb1919447ed19f59c06e34bf Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Thu, 13 Feb 2025 17:49:04 +0100 Subject: [PATCH 1/7] adjusted accuracy --- CollaborativeCoding/metrics/accuracy.py | 85 +++++++++++++++++-------- 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/CollaborativeCoding/metrics/accuracy.py b/CollaborativeCoding/metrics/accuracy.py index 5123d36..066e73d 100644 --- a/CollaborativeCoding/metrics/accuracy.py +++ b/CollaborativeCoding/metrics/accuracy.py @@ -7,6 +7,8 @@ def __init__(self, num_classes, macro_averaging=False): super().__init__() self.num_classes = num_classes self.macro_averaging = macro_averaging + self.y_true = [] + self.y_pred = [] def forward(self, y_true, y_pred): """ @@ -26,12 +28,10 @@ def forward(self, y_true, y_pred): """ if y_pred.dim() > 1: y_pred = y_pred.argmax(dim=1) - if self.macro_averaging: - return self._macro_acc(y_true, y_pred) - else: - return self._micro_acc(y_true, y_pred) + self.y_true.append(y_true) + self.y_pred.append(y_pred) - def _macro_acc(self, y_true, y_pred): + def _macro_acc(self): """ Compute the macro-average accuracy. @@ -47,7 +47,7 @@ def _macro_acc(self, y_true, y_pred): float Macro-average accuracy score. """ - y_true, y_pred = y_true.flatten(), y_pred.flatten() # Ensure 1D shape + y_true, y_pred = self.y_true.flatten(), self.y_pred.flatten() # Ensure 1D shape classes = torch.unique(y_true) # Find unique class labels acc_per_class = [] @@ -60,7 +60,7 @@ def _macro_acc(self, y_true, y_pred): macro_acc = torch.stack(acc_per_class).mean().item() # Average across classes return macro_acc - def _micro_acc(self, y_true, y_pred): + def _micro_acc(self): """ Compute the micro-average accuracy. @@ -76,27 +76,56 @@ def _micro_acc(self, y_true, y_pred): float Micro-average accuracy score. """ - return (y_true == y_pred).float().mean().item() + print(self.y_true, self.y_pred) + return (self.y_true == self.y_pred).float().mean().item() + + def __returnmetric__(self): + print(self.y_true, self.y_pred) + print(self.y_true == [], self.y_pred == []) + print(len(self.y_true), len(self.y_pred)) + print(type(self.y_true), type(self.y_pred)) + if self.y_true == [] or self.y_pred == []: + return 0.0 + if isinstance(self.y_true,list): + if len(self.y_true) == 1: + self.y_true = self.y_true[0] + self.y_pred = self.y_pred[0] + else: + self.y_true = torch.cat(self.y_true) + self.y_pred = torch.cat(self.y_pred) + return self._micro_acc() if not self.macro_averaging else self._macro_acc() + + def __resetmetric__(self): + self.y_true = [] + self.y_pred = [] + return None if __name__ == "__main__": - accuracy = Accuracy(5) - macro_accuracy = Accuracy(5, macro_averaging=True) - - y_true = torch.tensor([0, 3, 2, 3, 4]) - y_pred = torch.tensor([0, 1, 2, 3, 4]) - print(accuracy(y_true, y_pred)) - print(macro_accuracy(y_true, y_pred)) - - y_true = torch.tensor([0, 3, 2, 3, 4]) - y_onehot_pred = torch.tensor( - [ - [1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 1], - ] - ) - print(accuracy(y_true, y_onehot_pred)) - print(macro_accuracy(y_true, y_onehot_pred)) + # Test the accuracy metric + y_true = torch.tensor([0, 1, 2, 3, 4, 5]) + y_pred = torch.tensor([0, 1, 2, 3, 4, 5]) + accuracy = Accuracy(num_classes=6, macro_averaging=False) + accuracy(y_true, y_pred) + print(accuracy.__returnmetric__()) # 1.0 + accuracy.__resetmetric__() + print(accuracy.__returnmetric__()) # 0.0 + y_pred = torch.tensor([0, 1, 2, 3, 4, 4]) + accuracy(y_true, y_pred) + print(accuracy.__returnmetric__()) # 0.8333333134651184 + accuracy.__resetmetric__() + print(accuracy.__returnmetric__()) # 0.0 + accuracy.macro_averaging = True + accuracy(y_true, y_pred) + y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5]) + y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4]) + accuracy(y_true_1, y_pred_1) + print(accuracy.__returnmetric__()) # 0.9166666865348816 + #accuracy.__resetmetric__() + #accuracy(y_true, y_pred) + #accuracy(y_true_1, y_pred_1) + accuracy.macro_averaging = False + print(accuracy.__returnmetric__()) # 0.8333333134651184 + accuracy.__resetmetric__() + print(accuracy.__returnmetric__()) # 0.0 + print(accuracy.__resetmetric__()) # None From 57b8ebdcd83a169e46e3ce49278e7f703b2e3be5 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Thu, 13 Feb 2025 18:49:34 +0100 Subject: [PATCH 2/7] added nr_channels - issue #78 --- .gitignore | 3 +++ CollaborativeCoding/dataloaders/mnist_0_3.py | 1 + 2 files changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 3afaf91..c4425b0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ wandb/* wandb_api.py doc/autoapi +*.DS_Store + + #Magnus specific job* env2/* diff --git a/CollaborativeCoding/dataloaders/mnist_0_3.py b/CollaborativeCoding/dataloaders/mnist_0_3.py index 52a5a28..57cfe35 100644 --- a/CollaborativeCoding/dataloaders/mnist_0_3.py +++ b/CollaborativeCoding/dataloaders/mnist_0_3.py @@ -53,6 +53,7 @@ def __init__( sample_ids: list, train: bool = False, transform=None, + nr_channels: int = 1, ): super().__init__() From 787a3420284530975f89c69f86a2a7f569d600b8 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Thu, 13 Feb 2025 19:00:22 +0100 Subject: [PATCH 3/7] fixed wrong naming in accuracy --- CollaborativeCoding/metrics/accuracy.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/CollaborativeCoding/metrics/accuracy.py b/CollaborativeCoding/metrics/accuracy.py index 066e73d..8ef3f94 100644 --- a/CollaborativeCoding/metrics/accuracy.py +++ b/CollaborativeCoding/metrics/accuracy.py @@ -76,14 +76,9 @@ def _micro_acc(self): float Micro-average accuracy score. """ - print(self.y_true, self.y_pred) return (self.y_true == self.y_pred).float().mean().item() def __returnmetric__(self): - print(self.y_true, self.y_pred) - print(self.y_true == [], self.y_pred == []) - print(len(self.y_true), len(self.y_pred)) - print(type(self.y_true), type(self.y_pred)) if self.y_true == [] or self.y_pred == []: return 0.0 if isinstance(self.y_true,list): @@ -95,7 +90,7 @@ def __returnmetric__(self): self.y_pred = torch.cat(self.y_pred) return self._micro_acc() if not self.macro_averaging else self._macro_acc() - def __resetmetric__(self): + def __reset__(self): self.y_true = [] self.y_pred = [] return None @@ -121,9 +116,6 @@ def __resetmetric__(self): y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4]) accuracy(y_true_1, y_pred_1) print(accuracy.__returnmetric__()) # 0.9166666865348816 - #accuracy.__resetmetric__() - #accuracy(y_true, y_pred) - #accuracy(y_true_1, y_pred_1) accuracy.macro_averaging = False print(accuracy.__returnmetric__()) # 0.8333333134651184 accuracy.__resetmetric__() From 24f920df885ff4ed5abf8534e1de4934eb5ed512 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Thu, 13 Feb 2025 20:23:21 +0100 Subject: [PATCH 4/7] adjusted accuracy test and test_metric_wrapper to work with new method names --- CollaborativeCoding/metrics/accuracy.py | 28 ++---------------- tests/test_metrics.py | 39 ++++++++++++++++--------- 2 files changed, 28 insertions(+), 39 deletions(-) diff --git a/CollaborativeCoding/metrics/accuracy.py b/CollaborativeCoding/metrics/accuracy.py index 8ef3f94..ed51399 100644 --- a/CollaborativeCoding/metrics/accuracy.py +++ b/CollaborativeCoding/metrics/accuracy.py @@ -1,5 +1,6 @@ import torch from torch import nn +import numpy as np class Accuracy(nn.Module): @@ -80,7 +81,7 @@ def _micro_acc(self): def __returnmetric__(self): if self.y_true == [] or self.y_pred == []: - return 0.0 + return np.nan if isinstance(self.y_true,list): if len(self.y_true) == 1: self.y_true = self.y_true[0] @@ -96,28 +97,3 @@ def __reset__(self): return None -if __name__ == "__main__": - # Test the accuracy metric - y_true = torch.tensor([0, 1, 2, 3, 4, 5]) - y_pred = torch.tensor([0, 1, 2, 3, 4, 5]) - accuracy = Accuracy(num_classes=6, macro_averaging=False) - accuracy(y_true, y_pred) - print(accuracy.__returnmetric__()) # 1.0 - accuracy.__resetmetric__() - print(accuracy.__returnmetric__()) # 0.0 - y_pred = torch.tensor([0, 1, 2, 3, 4, 4]) - accuracy(y_true, y_pred) - print(accuracy.__returnmetric__()) # 0.8333333134651184 - accuracy.__resetmetric__() - print(accuracy.__returnmetric__()) # 0.0 - accuracy.macro_averaging = True - accuracy(y_true, y_pred) - y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5]) - y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4]) - accuracy(y_true_1, y_pred_1) - print(accuracy.__returnmetric__()) # 0.9166666865348816 - accuracy.macro_averaging = False - print(accuracy.__returnmetric__()) # 0.8333333134651184 - accuracy.__resetmetric__() - print(accuracy.__returnmetric__()) # 0.0 - print(accuracy.__resetmetric__()) # None diff --git a/tests/test_metrics.py b/tests/test_metrics.py index b747a1c..f5d01f3 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -3,7 +3,7 @@ import pytest from CollaborativeCoding.load_metric import MetricWrapper -from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall +from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall, EntropyPrediction @pytest.mark.parametrize( @@ -34,9 +34,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging): ) metrics(y_true, logits) - score = metrics.accumulate() - metrics.reset() - empty_score = metrics.accumulate() + score = metrics.__getmetrics__() + metrics.__resetmetrics__() + empty_score = metrics.__getmetrics__() assert isinstance(score, dict), "Expected a dictionary output." assert metric in score, f"Expected {metric} metric in the output." @@ -129,17 +129,30 @@ def test_precision(): def test_accuracy(): import torch + import numpy as np - accuracy = Accuracy(num_classes=5) - - y_true = torch.tensor([0, 3, 2, 3, 4]) - y_pred = torch.tensor([0, 1, 2, 3, 4]) - - accuracy_score = accuracy(y_true, y_pred) + # Test the accuracy metric + y_true = torch.tensor([0, 1, 2, 3, 4, 5]) + y_pred = torch.tensor([0, 1, 2, 3, 4, 5]) + accuracy = Accuracy(num_classes=6, macro_averaging=False) + accuracy(y_true, y_pred) + assert accuracy.__returnmetric__() == 1.0, "Expected accuracy to be 1.0" + accuracy.__reset__() + assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0" + y_pred = torch.tensor([0, 1, 2, 3, 4, 4]) + accuracy(y_true, y_pred) + assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184" + accuracy.__reset__() + accuracy.macro_averaging = True + accuracy(y_true, y_pred) + y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5]) + y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4]) + accuracy(y_true_1, y_pred_1) + assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651186" + accuracy.macro_averaging = False + assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184" + accuracy.__reset__() - assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, ( - f"Accuracy Score: {accuracy_score.item()}" - ) def test_entropypred(): From 4e8c4e6ce190be2a03bfbfebf0f3a5ca9ff72a18 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Thu, 13 Feb 2025 20:34:40 +0100 Subject: [PATCH 5/7] teeny adjustments --- main.py | 5 +- pyproject.toml | 1 + tests/test_models.py | 2 +- uv.lock | 160 ++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 156 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index ba37479..52e1dc5 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ from torch.utils.data import DataLoader from torchvision import transforms from tqdm import tqdm +#from wandb_api import WANDB_API from CollaborativeCoding import ( MetricWrapper, @@ -121,11 +122,11 @@ def main(): train_metrics(y, logits) break - print(train_metrics.accumulate()) + print(train_metrics.__getmetrics__()) print("Dry run completed successfully.") exit() - # wandb.login(key=WANDB_API) +# wandb.login(key=WANDB_API) wandb.init( entity="ColabCode", project=args.run_name, diff --git a/pyproject.toml b/pyproject.toml index 93ccc07..1cb1da6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "torch>=2.6.0", "torchvision>=0.21.0", "tqdm>=4.67.1", + "wandb>=0.19.6", ] [tool.isort] profile = "black" diff --git a/tests/test_models.py b/tests/test_models.py index 0af2717..f61315e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,7 @@ import pytest import torch -from CollaborativeCoding.models import ChristianModel, JanModel, MagnusModel +from CollaborativeCoding.models import ChristianModel, JanModel, MagnusModel, SolveigModel @pytest.mark.parametrize( diff --git a/uv.lock b/uv.lock index 6a20243..4d42378 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,9 @@ version = 1 requires-python = ">=3.12" +resolution-markers = [ + "sys_platform == 'linux'", + "sys_platform != 'linux'", +] [[package]] name = "alabaster" @@ -308,6 +312,7 @@ dependencies = [ { name = "torch" }, { name = "torchvision" }, { name = "tqdm" }, + { name = "wandb" }, ] [package.metadata] @@ -330,6 +335,7 @@ requires-dist = [ { name = "torch", specifier = ">=2.6.0" }, { name = "torchvision", specifier = ">=0.21.0" }, { name = "tqdm", specifier = ">=4.67.1" }, + { name = "wandb", specifier = ">=0.19.6" }, ] [[package]] @@ -388,6 +394,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 }, ] +[[package]] +name = "docker-pycreds" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/e6/d1f6c00b7221e2d7c4b470132c931325c8b22c51ca62417e300f5ce16009/docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4", size = 8754 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49", size = 8982 }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -442,6 +460,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/94/758680531a00d06e471ef649e4ec2ed6bf185356a7f9fbfbb7368a40bd49/fsspec-2025.2.0-py3-none-any.whl", hash = "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b", size = 184484 }, ] +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794 }, +] + +[[package]] +name = "gitpython" +version = "3.1.44" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/89/37df0b71473153574a5cdef8f242de422a0f5d26d7a9e231e6f169b4ad14/gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269", size = 214196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 }, +] + [[package]] name = "h11" version = "0.14.0" @@ -760,7 +802,7 @@ dependencies = [ { name = "overrides" }, { name = "packaging" }, { name = "prometheus-client" }, - { name = "pywinpty", marker = "os_name == 'nt'" }, + { name = "pywinpty", marker = "os_name == 'nt' and sys_platform != 'linux'" }, { name = "pyzmq" }, { name = "send2trash" }, { name = "terminado" }, @@ -778,7 +820,7 @@ name = "jupyter-server-terminals" version = "0.5.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywinpty", marker = "os_name == 'nt'" }, + { name = "pywinpty", marker = "os_name == 'nt' and sys_platform != 'linux'" }, { name = "terminado" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fc/d5/562469734f476159e99a55426d697cbf8e7eb5efe89fb0e0b4f83a3d3459/jupyter_server_terminals-0.5.3.tar.gz", hash = "sha256:5ae0295167220e9ace0edcfdb212afd2b01ee8d179fe6f23c899590e9b8a5269", size = 31430 } @@ -1095,7 +1137,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -1106,7 +1148,7 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, @@ -1125,9 +1167,9 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, @@ -1138,7 +1180,7 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, @@ -1362,6 +1404,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", size = 387816 }, ] +[[package]] +name = "protobuf" +version = "5.29.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f7/d1/e0a911544ca9993e0f17ce6d3cc0932752356c1b0a834397f28e63479344/protobuf-5.29.3.tar.gz", hash = "sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620", size = 424945 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/7a/1e38f3cafa022f477ca0f57a1f49962f21ad25850c3ca0acd3b9d0091518/protobuf-5.29.3-cp310-abi3-win32.whl", hash = "sha256:3ea51771449e1035f26069c4c7fd51fba990d07bc55ba80701c78f886bf9c888", size = 422708 }, + { url = "https://files.pythonhosted.org/packages/61/fa/aae8e10512b83de633f2646506a6d835b151edf4b30d18d73afd01447253/protobuf-5.29.3-cp310-abi3-win_amd64.whl", hash = "sha256:a4fa6f80816a9a0678429e84973f2f98cbc218cca434abe8db2ad0bffc98503a", size = 434508 }, + { url = "https://files.pythonhosted.org/packages/dd/04/3eaedc2ba17a088961d0e3bd396eac764450f431621b58a04ce898acd126/protobuf-5.29.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8434404bbf139aa9e1300dbf989667a83d42ddda9153d8ab76e0d5dcaca484e", size = 417825 }, + { url = "https://files.pythonhosted.org/packages/4f/06/7c467744d23c3979ce250397e26d8ad8eeb2bea7b18ca12ad58313c1b8d5/protobuf-5.29.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:daaf63f70f25e8689c072cfad4334ca0ac1d1e05a92fc15c54eb9cf23c3efd84", size = 319573 }, + { url = "https://files.pythonhosted.org/packages/a8/45/2ebbde52ad2be18d3675b6bee50e68cd73c9e0654de77d595540b5129df8/protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:c027e08a08be10b67c06bf2370b99c811c466398c357e615ca88c91c07f0910f", size = 319672 }, + { url = "https://files.pythonhosted.org/packages/fd/b2/ab07b09e0f6d143dfb839693aa05765257bceaa13d03bf1a696b78323e7a/protobuf-5.29.3-py3-none-any.whl", hash = "sha256:0a18ed4a24198528f2333802eb075e59dea9d679ab7a6c5efb017a59004d849f", size = 172550 }, +] + [[package]] name = "psutil" version = "6.1.1" @@ -1837,6 +1893,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/b0/4562db6223154aa4e22f939003cb92514c79f3d4dccca3444253fd17f902/Send2Trash-1.8.3-py3-none-any.whl", hash = "sha256:0c31227e0bd08961c7665474a3d1ef7193929fedda4233843689baa056be46c9", size = 18072 }, ] +[[package]] +name = "sentry-sdk" +version = "2.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/63/3f0e88709cf4af992e2813c27d8ba628a891db0805e3fcc6dc834e142c5b/sentry_sdk-2.21.0.tar.gz", hash = "sha256:a6d38e0fb35edda191acf80b188ec713c863aaa5ad8d5798decb8671d02077b6", size = 301965 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/18/7587660cb5e4d07134913d8e74137efcd4903fda873bf612c30eb34c7ab4/sentry_sdk-2.21.0-py2.py3-none-any.whl", hash = "sha256:7623cfa9e2c8150948a81ca253b8e2bfe4ce0b96ab12f8cd78e3ac9c490fd92f", size = 324096 }, +] + +[[package]] +name = "setproctitle" +version = "1.3.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/4e/b09341b19b9ceb8b4c67298ab4a08ef7a4abdd3016c7bb152e9b6379031d/setproctitle-1.3.4.tar.gz", hash = "sha256:3b40d32a3e1f04e94231ed6dfee0da9e43b4f9c6b5450d53e6dd7754c34e0c50", size = 26456 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/1f/02fb3c6038c819d86765316d2a911281fc56c7dd3a9355dceb3f26a5bf7b/setproctitle-1.3.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d06990dcfcd41bb3543c18dd25c8476fbfe1f236757f42fef560f6aa03ac8dfc", size = 16842 }, + { url = "https://files.pythonhosted.org/packages/b8/0c/d69e1f91c8f3d3aa74394e9e6ebb818f7d323e2d138ce1127e9462d09ebc/setproctitle-1.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:317218c9d8b17a010ab2d2f0851e8ef584077a38b1ba2b7c55c9e44e79a61e73", size = 11614 }, + { url = "https://files.pythonhosted.org/packages/86/ed/8031871d275302054b2f1b94b7cf5e850212cc412fe968f0979e64c1b838/setproctitle-1.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb5fefb53b9d9f334a5d9ec518a36b92a10b936011ac8a6b6dffd60135f16459", size = 31840 }, + { url = "https://files.pythonhosted.org/packages/45/b7/04f5d221cbdcff35d6cdf74e2a852e69dc8d8e746eb1b314be6b57b79c41/setproctitle-1.3.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0855006261635e8669646c7c304b494b6df0a194d2626683520103153ad63cc9", size = 33271 }, + { url = "https://files.pythonhosted.org/packages/25/b2/8dff0d2a72076e5535f117f33458d520538b5a0900b90a9f59a278f0d3f6/setproctitle-1.3.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a88e466fcaee659679c1d64dcb2eddbcb4bfadffeb68ba834d9c173a25b6184", size = 30509 }, + { url = "https://files.pythonhosted.org/packages/4b/cf/4f19cdc7fdff3eaeb3064ce6eeb27c63081dba3123fbf904ac6bf0de440c/setproctitle-1.3.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f963b6ed8ba33eda374a98d979e8a0eaf21f891b6e334701693a2c9510613c4c", size = 31543 }, + { url = "https://files.pythonhosted.org/packages/9b/a7/5f9c3c70dc5573f660f978fb3bb4847cd26ede95a5fc294d3f1cf6779800/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:122c2e05697fa91f5d23f00bbe98a9da1bd457b32529192e934095fadb0853f1", size = 31268 }, + { url = "https://files.pythonhosted.org/packages/26/ab/bbde90ea0ed6a062ef94fe1c609b68077f7eb586133a62fa62d0c8dd9f8c/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:1bba0a866f5895d5b769d8c36b161271c7fd407e5065862ab80ff91c29fbe554", size = 30232 }, + { url = "https://files.pythonhosted.org/packages/36/0e/817be9934eda4cf63c96c694c3383cb0d2e5d019a2871af7dbd2202f7a58/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:97f1f861998e326e640708488c442519ad69046374b2c3fe9bcc9869b387f23c", size = 32739 }, + { url = "https://files.pythonhosted.org/packages/b0/76/9b4877850c9c5f41c4bacae441285dead7c192bebf4fcbf3b3eb0e8033cc/setproctitle-1.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:726aee40357d4bdb70115442cb85ccc8e8bc554fc0bbbaa3a57cbe81df42287d", size = 30778 }, + { url = "https://files.pythonhosted.org/packages/b2/fa/bbc7ab32f253b9700ac20d78ba0d5fbdc4ea5789d33e1adb236cdf20b23a/setproctitle-1.3.4-cp312-cp312-win32.whl", hash = "sha256:04d6ba8b816dbb0bfd62000b0c3e583160893e6e8c4233e1dca1a9ae4d95d924", size = 11355 }, + { url = "https://files.pythonhosted.org/packages/44/5c/6e6665b5fd800206a9e537ab0d2630d7b9b31b4697d931ed468837cc9cf5/setproctitle-1.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:9c76e43cb351ba8887371240b599925cdf3ecececc5dfb7125c71678e7722c55", size = 12069 }, + { url = "https://files.pythonhosted.org/packages/d4/01/51d07ab1dbec8885ebad419d254c06b9e28f4363c163b737a89995a52b75/setproctitle-1.3.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d6e3b177e634aa6bbbfbf66d097b6d1cdb80fc60e912c7d8bace2e45699c07dd", size = 16831 }, + { url = "https://files.pythonhosted.org/packages/30/03/deff7089b525c0d8ec047e06661d2be67c87685a99be6a6aed2890b81c8f/setproctitle-1.3.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6b17655a5f245b416e127e02087ea6347a48821cc4626bc0fd57101bfcd88afc", size = 11607 }, + { url = "https://files.pythonhosted.org/packages/ea/be/cb2950b3f6ba460f530bda2c713828236c75d982d0aa0f62b33429a9b4d0/setproctitle-1.3.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa5057a86df920faab8ee83960b724bace01a3231eb8e3f2c93d78283504d598", size = 31881 }, + { url = "https://files.pythonhosted.org/packages/5c/b4/1f0dba7525a2fbefd08d4086e7e998d9c7581153807fb6b3083d06e0b8e2/setproctitle-1.3.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149fdfb8a26a555780c4ce53c92e6d3c990ef7b30f90a675eca02e83c6d5f76d", size = 33290 }, + { url = "https://files.pythonhosted.org/packages/2d/a8/07a160f9dcd1a7b1cad39ce6cbaf4425837502b0592a400c38cb21f0f247/setproctitle-1.3.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ded03546938a987f463c68ab98d683af87a83db7ac8093bbc179e77680be5ba2", size = 30489 }, + { url = "https://files.pythonhosted.org/packages/83/0c/3d972d9ea4165961a9764df5324d42bf2d059cb8a6ef516c67f068ed4d92/setproctitle-1.3.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ab9f5b7f2bbc1754bc6292d9a7312071058e5a891b0391e6d13b226133f36aa", size = 31576 }, + { url = "https://files.pythonhosted.org/packages/7a/c0/c12bdc2c91009defdd1b207ff156ccd691f5b9a6a0aae1ed9126d4ff9a0c/setproctitle-1.3.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0b19813c852566fa031902124336fa1f080c51e262fc90266a8c3d65ca47b74c", size = 31273 }, + { url = "https://files.pythonhosted.org/packages/4f/83/8d704bee57990b27537adf7c97540f32226ffa3922fb26bdd459da8a4470/setproctitle-1.3.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:db78b645dc63c0ccffca367a498f3b13492fb106a2243a1e998303ba79c996e2", size = 30236 }, + { url = "https://files.pythonhosted.org/packages/d8/42/94e31d1f515f831e1ae43f2405794257eb940a7972b2fbb6283790db2958/setproctitle-1.3.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b669aaac70bd9f03c070270b953f78d9ee56c4af6f0ff9f9cd3e6d1878c10b40", size = 32766 }, + { url = "https://files.pythonhosted.org/packages/83/53/01746ed8fb75239a001ee89d5eb8ad5a3022df510572d1cf60dd04567e13/setproctitle-1.3.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6dc3d656702791565994e64035a208be56b065675a5bc87b644c657d6d9e2232", size = 30812 }, + { url = "https://files.pythonhosted.org/packages/5f/ea/3ce61e70a6b898e95c0a1e393964c829103dc4ad4b0732cd70c8fc13e54c/setproctitle-1.3.4-cp313-cp313-win32.whl", hash = "sha256:091f682809a4d12291cf0205517619d2e7014986b7b00ebecfde3d76f8ae5a8f", size = 11349 }, + { url = "https://files.pythonhosted.org/packages/e7/1a/8149da1c19db6bd57164d62b1d91c188e7d77e695947cf1ac327c8aea513/setproctitle-1.3.4-cp313-cp313-win_amd64.whl", hash = "sha256:adcd6ba863a315702184d92d3d3bbff290514f24a14695d310f02ae5e28bd1f7", size = 12062 }, +] + [[package]] name = "setuptools" version = "75.8.0" @@ -1855,6 +1956,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, ] +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303 }, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -2065,7 +2175,7 @@ version = "0.18.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ptyprocess", marker = "os_name != 'nt'" }, - { name = "pywinpty", marker = "os_name == 'nt'" }, + { name = "pywinpty", marker = "os_name == 'nt' and sys_platform != 'linux'" }, { name = "tornado" }, ] sdist = { url = "https://files.pythonhosted.org/packages/8a/11/965c6fd8e5cc254f1fe142d547387da17a8ebfd75a3455f637c663fb38a0/terminado-0.18.1.tar.gz", hash = "sha256:de09f2c4b85de4765f7714688fff57d3e75bad1f909b589fde880460c753fd2e", size = 32701 } @@ -2258,6 +2368,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 }, ] +[[package]] +name = "wandb" +version = "0.19.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "docker-pycreds" }, + { name = "gitpython" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "psutil" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "setproctitle" }, + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/a2/63fbebc6ed670a7d834ca76552b8c6382211874b23ee8a718ba26a342a4a/wandb-0.19.6.tar.gz", hash = "sha256:4661856ee070fe8a123caece5b372d495d3cf9f58176a8f981bd716830eefc49", size = 39203528 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/4f/5b77e20f10e643404df871557610a6618383e036de65e9c34b3a8354f2ac/wandb-0.19.6-py3-none-any.whl", hash = "sha256:0b174b5f190999a8238961c63c134622bf2173147a1301ea298a9ec58abbd7d4", size = 6387720 }, + { url = "https://files.pythonhosted.org/packages/25/aa/824a171586f3fa1549f9f946d32187362c8d06ff67540d9f1be694ee9094/wandb-0.19.6-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:ad2887dd916207ead5a9f36e4aebc1b6624265f29033e4e883bb6fbd5b674080", size = 20776552 }, + { url = "https://files.pythonhosted.org/packages/ad/3b/222e2a27ee3df3a973d8f165fa47f3e3bb25dc6d9ac1d3ec79b083c5ee09/wandb-0.19.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ca90dd5519de1a48963536f02d6e14c150475807173b7af1d8ebe3e2f9e3afba", size = 19933524 }, + { url = "https://files.pythonhosted.org/packages/65/76/1d69145ac3c9c6b63545e684c39b95711c3632c34d452626fd831227089d/wandb-0.19.6-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:3cb10bd1e1c0b568464a017c88eb95e0c8c3e9c1283d9ad4ee717c8977d491c1", size = 20791479 }, + { url = "https://files.pythonhosted.org/packages/88/96/4411c4aa29cfb0bc8e310480181d79779b423231420bbcf5e61ff8c44ff7/wandb-0.19.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fe6e7bedd396b2b5f92c7fab3d364f7e0e8cb9f645d0f0c27ba7be94e720931", size = 19539263 }, + { url = "https://files.pythonhosted.org/packages/bc/89/2e414951d35e55caf6d8ac5758a82c61c1b8330f77852fbc733c833196eb/wandb-0.19.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd9ae9a7f08e4d3972ba341c42af787e951689e0d1a76c111aa66d09bcdadafd", size = 20861187 }, + { url = "https://files.pythonhosted.org/packages/3a/5e/7517c9fa9aa0075160c04e467f6d0e5d1b9bb6b91c4ffd6dd6fa23dd3dd0/wandb-0.19.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ff0973ca26cd06bc5451ae7ba469ad98f74024f5678dfa0d6dc78ca36eb950b6", size = 19549095 }, + { url = "https://files.pythonhosted.org/packages/bd/be/ef3c78ab14a631558f639ab3a8379efee6f7d529e3bbf9efb0e17472495b/wandb-0.19.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2e8dc997eb3ae5f22f5a1c3d4f3b30c28398dda45b9dbada9ff20b8d3984d3e2", size = 20938943 }, + { url = "https://files.pythonhosted.org/packages/b6/43/2f9c71a1fe77a97e9d32b4828f1dd685ac545442f8dfbf703eac8128056f/wandb-0.19.6-py3-none-win32.whl", hash = "sha256:c0127d99e98202dc2471d44b920129c2c9242fb3a6b52a7aa8bbf9ffa35173e7", size = 20230403 }, + { url = "https://files.pythonhosted.org/packages/fd/b2/a9ffa91c43dbe2a6687467f3aa196947b7532592879738665be5c0db17c3/wandb-0.19.6-py3-none-win_amd64.whl", hash = "sha256:8688a4f724d37a90075312e8dccffd948adbe8b6bcb82f9d2b38b764b53269fb", size = 20230407 }, +] + [[package]] name = "watchfiles" version = "1.0.4" From 87753d53c82b903ebec6ae0b8b39d85fe5c1c228 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Thu, 13 Feb 2025 20:39:12 +0100 Subject: [PATCH 6/7] ruffed --- CollaborativeCoding/metrics/accuracy.py | 10 ++++------ main.py | 8 +++++--- tests/test_metrics.py | 23 +++++++++++++++++------ tests/test_models.py | 7 ++++++- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/CollaborativeCoding/metrics/accuracy.py b/CollaborativeCoding/metrics/accuracy.py index ed51399..2b23dff 100644 --- a/CollaborativeCoding/metrics/accuracy.py +++ b/CollaborativeCoding/metrics/accuracy.py @@ -1,6 +1,6 @@ +import numpy as np import torch from torch import nn -import numpy as np class Accuracy(nn.Module): @@ -78,11 +78,11 @@ def _micro_acc(self): Micro-average accuracy score. """ return (self.y_true == self.y_pred).float().mean().item() - + def __returnmetric__(self): if self.y_true == [] or self.y_pred == []: return np.nan - if isinstance(self.y_true,list): + if isinstance(self.y_true, list): if len(self.y_true) == 1: self.y_true = self.y_true[0] self.y_pred = self.y_pred[0] @@ -90,10 +90,8 @@ def __returnmetric__(self): self.y_true = torch.cat(self.y_true) self.y_pred = torch.cat(self.y_pred) return self._micro_acc() if not self.macro_averaging else self._macro_acc() - + def __reset__(self): self.y_true = [] self.y_pred = [] return None - - diff --git a/main.py b/main.py index 52e1dc5..190baaf 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,11 @@ import numpy as np import torch as th import torch.nn as nn -import wandb from torch.utils.data import DataLoader from torchvision import transforms from tqdm import tqdm -#from wandb_api import WANDB_API +import wandb from CollaborativeCoding import ( MetricWrapper, createfolders, @@ -15,6 +14,9 @@ load_model, ) +# from wandb_api import WANDB_API + + def main(): """ @@ -126,7 +128,7 @@ def main(): print("Dry run completed successfully.") exit() -# wandb.login(key=WANDB_API) + # wandb.login(key=WANDB_API) wandb.init( entity="ColabCode", project=args.run_name, diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f5d01f3..bb580c3 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -3,7 +3,13 @@ import pytest from CollaborativeCoding.load_metric import MetricWrapper -from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall, EntropyPrediction +from CollaborativeCoding.metrics import ( + Accuracy, + EntropyPrediction, + F1Score, + Precision, + Recall, +) @pytest.mark.parametrize( @@ -128,8 +134,8 @@ def test_precision(): def test_accuracy(): - import torch import numpy as np + import torch # Test the accuracy metric y_true = torch.tensor([0, 1, 2, 3, 4, 5]) @@ -141,20 +147,25 @@ def test_accuracy(): assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0" y_pred = torch.tensor([0, 1, 2, 3, 4, 4]) accuracy(y_true, y_pred) - assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184" + assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, ( + "Expected accuracy to be 0.8333333134651184" + ) accuracy.__reset__() accuracy.macro_averaging = True accuracy(y_true, y_pred) y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5]) y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4]) accuracy(y_true_1, y_pred_1) - assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651186" + assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, ( + "Expected accuracy to be 0.8333333134651186" + ) accuracy.macro_averaging = False - assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184" + assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, ( + "Expected accuracy to be 0.8333333134651184" + ) accuracy.__reset__() - def test_entropypred(): import torch as th diff --git a/tests/test_models.py b/tests/test_models.py index f61315e..6a1837d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,12 @@ import pytest import torch -from CollaborativeCoding.models import ChristianModel, JanModel, MagnusModel, SolveigModel +from CollaborativeCoding.models import ( + ChristianModel, + JanModel, + MagnusModel, + SolveigModel, +) @pytest.mark.parametrize( From 3a46ddcc55d6331a477c69ebaed94a4bcdb2546e Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Fri, 14 Feb 2025 20:00:42 +0100 Subject: [PATCH 7/7] test_dataloader wasn't passing the data_path correctly --- tests/test_dataloaders.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index b0fa78d..dbdb14a 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -26,8 +26,7 @@ def test_load_data(data_name, expected): dataset = load_data( data_name, - data_path=Path("data"), - download=True, + data_dir=Path("data"), transform=transforms.ToTensor(), ) assert isinstance(dataset, expected)