diff --git a/.gitignore b/.gitignore
index dce22ec..d68c5ec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,8 @@ local*
 
 # Johanthings
 formatting.x
+testrun.x
+storage/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/CollaborativeCoding/dataloaders/mnist_4_9.py b/CollaborativeCoding/dataloaders/mnist_4_9.py
index cc6ff0e..55d2453 100644
--- a/CollaborativeCoding/dataloaders/mnist_4_9.py
+++ b/CollaborativeCoding/dataloaders/mnist_4_9.py
@@ -43,7 +43,12 @@ def __init__(
         self.labels_path = self.mnist_path / (
             MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
         )
-
+
+        # Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
+        self.label_shift = lambda x: x-4
+        self.label_restore = lambda x: x+4
+
+
     def __len__(self):
         return len(self.samples)
 
@@ -66,4 +71,4 @@ def __getitem__(self, idx):
         if self.transform:
             image = self.transform(image)
 
-        return image, label
+        return image, self.label_shift(label)
diff --git a/main.py b/main.py
index 06f7277..7225388 100644
--- a/main.py
+++ b/main.py
@@ -139,12 +139,12 @@ def main():
 
     for epoch in range(args.epoch):
         # Training loop start
+        print(f"Epoch: {epoch+1}/{args.epoch}")
         trainingloss = []
         model.train()
         for x, y in tqdm(trainloader, desc="Training"):
             x, y = x.to(device), y.to(device)
             logits = model.forward(x)
-
             loss = criterion(logits, y)
             loss.backward()
 
@@ -172,8 +172,8 @@ def main():
                 "Train loss": np.mean(trainingloss),
                 "Validation loss": np.mean(valloss),
             }
-            | train_metrics.getmetric(str_prefix="Train ")
-            | val_metrics.getmetric(str_prefix="Validation ")
+            | train_metrics.getmetrics(str_prefix="Train ")
+            | val_metrics.getmetrics(str_prefix="Validation ")
         )
         train_metrics.resetmetric()
         val_metrics.resetmetric()
@@ -187,12 +187,11 @@ def main():
            loss = criterion(logits, y)
            testloss.append(loss.item())
 
-           preds = th.argmax(logits, dim=1)
-           test_metrics(y, preds)
+           test_metrics(y, logits)
 
        wandb.log(
            {"Epoch": 1, "Test loss": np.mean(testloss)}
-           | test_metrics.getmetric(str_prefix="Test ")
+           | test_metrics.getmetrics(str_prefix="Test ")
        )
        test_metrics.resetmetric()
 
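For context on the label-remapping change in `mnist_4_9.py`: `torch.nn.CrossEntropyLoss` expects targets to be integer class indices in `[0, num_classes)`, so the raw MNIST labels 4 and 9 are shifted down before the loss is computed and can be shifted back for reporting. A minimal standalone sketch of that idea (the 6-logit model output and the example batch are illustrative assumptions, not code from this repo):

```python
import torch
import torch.nn as nn

# CrossEntropyLoss expects integer class indices in [0, num_classes),
# so the raw labels {4, 9} are shifted to {0, 5} before computing the
# loss and shifted back afterwards for human-readable reporting.
label_shift = lambda y: y - 4    # 4 -> 0, 9 -> 5
label_restore = lambda y: y + 4  # 0 -> 4, 5 -> 9

criterion = nn.CrossEntropyLoss()

logits = torch.randn(8, 6)       # assumed: batch of 8, model with 6 output logits
raw_labels = torch.tensor([4, 9, 4, 9, 9, 4, 4, 9])

loss = criterion(logits, label_shift(raw_labels))    # valid: shifted targets are 0 or 5
preds = label_restore(logits.argmax(dim=1))          # map predicted indices back to the original range
print(loss.item(), preds.tolist())
```

The `label_restore` helper mirrors the one added in the patch, which presumably exists so that shifted predictions can be mapped back to the original digit labels when results are reported.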