Changes from all commits (28 commits)
9baa17e
load_data - changed to accommodate train/val/test split, added test loop
hzavadil98 Feb 7, 2025
f840af4
ruffed and isorted
hzavadil98 Feb 7, 2025
efc78f3
fix in mnist0-3
hzavadil98 Feb 7, 2025
54f3883
Added micro/macro averaging option to MetricsWrapper and as commandli…
hzavadil98 Feb 7, 2025
a2606e1
Make separate downloader class that handles everything related to dow…
c-salomonsen Feb 8, 2025
a58e495
downloader handles whether to download data or not, so remove option
c-salomonsen Feb 8, 2025
a9e2cad
Remove downloading logic from USPS dataset
c-salomonsen Feb 8, 2025
34539b3
`load_data` now splits the data, downloads data and returns all splits
c-salomonsen Feb 8, 2025
6c6f7b5
Made a whoopsie
c-salomonsen Feb 8, 2025
20faa24
Add the size thing
c-salomonsen Feb 8, 2025
d7526bf
Actually send the indices, not labels to datasets
c-salomonsen Feb 8, 2025
bd35ae6
Format
c-salomonsen Feb 8, 2025
0f32064
More formatting
c-salomonsen Feb 8, 2025
ad15940
Adjust test to comply with new functionality
c-salomonsen Feb 8, 2025
177258b
added micro/macro to F1
sot176 Feb 10, 2025
15c99ea
added MNIST downloader, adjusted minor things for the code to run
hzavadil98 Feb 10, 2025
601caca
ruffed, isorted
hzavadil98 Feb 10, 2025
f2e14c4
Merge pull request #60 from SFI-Visual-Intelligence/christian/train-v…
hzavadil98 Feb 10, 2025
ba2212e
Merge branch 'main' into Jan-dataloader
hzavadil98 Feb 10, 2025
b7bffa3
ruffisorted :'(
hzavadil98 Feb 10, 2025
4071181
Merge branch 'Jan-dataloader' into Jan-metrics
hzavadil98 Feb 10, 2025
27f120c
Merge pull request #63 from SFI-Visual-Intelligence/Jan-metrics
hzavadil98 Feb 10, 2025
5d8309b
preparing for overall test
hzavadil98 Feb 10, 2025
0d5fc20
Merge branch 'Jan-dataloader' of https://github.com/SFI-Visual-Intell…
hzavadil98 Feb 10, 2025
19a6ea1
hopefully fixed f1 test
hzavadil98 Feb 10, 2025
68f4736
Try to fix conda error: "undefined symbol: H5Pset_fapl_ros3" by setti…
c-salomonsen Feb 10, 2025
fbe9f4f
Found that mamba was unable to use installed python and h5py with the…
c-salomonsen Feb 10, 2025
a7d51c4
Add percentage downloaded progress thing to mnist (since the reportho…
c-salomonsen Feb 10, 2025
4 changes: 3 additions & 1 deletion environment.yml
@@ -9,7 +9,8 @@ dependencies:
- sphinx-autobuild
- sphinx-rtd-theme
- pip
- h5py
- h5py==3.12.1
- hdf5==1.14.4
- black
- isort
- jupyterlab
@@ -20,6 +21,7 @@ dependencies:
- scalene
- tqdm
- scipy
- wandb
- pip:
- torch
- torchvision
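Reviewer note: the h5py/hdf5 pins correspond to commits 68f4736 and fbe9f4f, which work around the conda linker error "undefined symbol: H5Pset_fapl_ros3". A minimal sanity check for the rebuilt environment (a sketch; the expected versions are simply the pins above) could be:

import h5py

# Expected to match the pins in environment.yml (assumption: pins unchanged).
print(h5py.__version__)           # 3.12.1
print(h5py.version.hdf5_version)  # 1.14.4

# A broken HDF5 link usually fails already at import; a tiny write confirms I/O works.
with h5py.File("smoke_test.h5", "w") as f:
    f["ok"] = [1, 2, 3]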
99 changes: 62 additions & 37 deletions main.py
@@ -7,6 +7,7 @@

import wandb
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
from wandb_api import WANDB_API


def main():
@@ -29,33 +30,38 @@ def main():

device = args.device

if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
augmentations = transforms.Compose(
if args.dataset.lower() in ["usps_0-6", "usps_7-9"]:
transform = transforms.Compose(
[
transforms.Resize((16, 16)),
transforms.ToTensor(),
]
)
else:
augmentations = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([transforms.ToTensor()])

# Dataset
traindata = load_data(
traindata, validata, testdata = load_data(
args.dataset,
train=True,
data_path=args.datafolder,
download=args.download_data,
transform=augmentations,
)
validata = load_data(
args.dataset,
train=False,
data_path=args.datafolder,
download=args.download_data,
transform=augmentations,
data_dir=args.datafolder,
transform=transform,
val_size=args.val_size,
)

metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
train_metrics = MetricWrapper(
*args.metric,
num_classes=traindata.num_classes,
macro_averaging=args.macro_averaging,
)
val_metrics = MetricWrapper(
*args.metric,
num_classes=traindata.num_classes,
macro_averaging=args.macro_averaging,
)
test_metrics = MetricWrapper(
*args.metric,
num_classes=traindata.num_classes,
macro_averaging=args.macro_averaging,
)

# Find the shape of the data; if it is 2D, add a channel dimension
data_shape = traindata[0][0].shape
@@ -80,6 +86,9 @@ def main():
valiloader = DataLoader(
validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
)
testloader = DataLoader(
testdata, batch_size=args.batchsize, shuffle=False, pin_memory=True
)

criterion = nn.CrossEntropyLoss()
optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -104,22 +113,22 @@ def main():
optimizer.step()
optimizer.zero_grad(set_to_none=True)

metrics(y, logits)
train_metrics(y, logits)

break
print(metrics.accumulate())
print(train_metrics.accumulate())
print("Dry run completed successfully.")
exit()

# wandb.login(key=WANDB_API)
wandb.init(
entity="ColabCode-org",
# entity="FYS-8805 Exam",
project="Test",
tags=[args.modelname, args.dataset]
)
entity="ColabCode",
# entity="FYS-8805 Exam",
project="Jan",
tags=[args.modelname, args.dataset],
)
wandb.watch(model)
exit()

for epoch in range(args.epoch):
# Training loop start
trainingloss = []
@@ -135,33 +144,49 @@ def main():
optimizer.zero_grad(set_to_none=True)
trainingloss.append(loss.item())

metrics(y, logits)

wandb.log(metrics.accumulate(str_prefix="Train "))
metrics.reset()
train_metrics(y, logits)

evalloss = []
# Eval loop start
valloss = []
# Validation loop start
model.eval()
with th.no_grad():
for x, y in tqdm(valiloader, desc="Validation"):
x, y = x.to(device), y.to(device)
logits = model.forward(x)
loss = criterion(logits, y)
evalloss.append(loss.item())

metrics(y, logits)
valloss.append(loss.item())

wandb.log(metrics.accumulate(str_prefix="Evaluation "))
metrics.reset()
val_metrics(y, logits)

wandb.log(
{
"Epoch": epoch,
"Train loss": np.mean(trainingloss),
"Evaluation Loss": np.mean(evalloss),
"Validation loss": np.mean(valloss),
}
| train_metrics.accumulate(str_prefix="Train ")
| val_metrics.accumulate(str_prefix="Validation ")
)
train_metrics.reset()
val_metrics.reset()

testloss = []
model.eval()
with th.no_grad():
for x, y in tqdm(testloader, desc="Testing"):
x, y = x.to(device), y.to(device)
logits = model.forward(x)
loss = criterion(logits, y)
testloss.append(loss.item())

preds = th.argmax(logits, dim=1)
test_metrics(y, preds)

wandb.log(
{"Epoch": 1, "Test loss": np.mean(testloss)}
| test_metrics.accumulate(str_prefix="Test ")
)
test_metrics.reset()


if __name__ == "__main__":
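Reviewer note: per commits 34539b3 and d7526bf, load_data now downloads the data, carves a validation set of size val_size off the training split, and hands index arrays (not labels) to the datasets. A standalone sketch of that index-splitting idea (not the repo's implementation; 7291 is just the USPS training-set size, used here as an example):

import numpy as np

def split_indices(n_train: int, val_size: float, seed: int = 0):
    """Carve `val_size` of the training indices off as a validation set."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_train)
    n_val = int(n_train * val_size)
    return idx[n_val:], idx[:n_val]  # (train_ids, val_ids) to pass to the datasets

train_ids, val_ids = split_indices(7291, val_size=0.2)
print(len(train_ids), len(val_ids))  # 5833 1458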
15 changes: 11 additions & 4 deletions tests/test_dataloaders.py
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():

# Create a h5 file
with h5py.File(tf, "w") as f:
targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
indices = np.arange(len(targets))
# Populate the file with data
f["train/data"] = np.random.rand(10, 16 * 16)
f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
f["train/target"] = targets

trans = transforms.Compose(
[
transforms.Resize((16, 16)), # At least for USPS
transforms.Resize((16, 16)),
transforms.ToTensor(),
]
)
dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
dataset = USPSDataset0_6(
data_path=tempdir,
sample_ids=indices,
train=True,
transform=trans,
)
assert len(dataset) == 10
data, target = dataset[0]
assert data.shape == (1, 16, 16)
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
assert target == 6
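Reviewer note: the target assertion changes from a one-hot vector to the plain class index 6, which matches what nn.CrossEntropyLoss in main.py expects. A minimal illustration of that convention:

import torch
import torch.nn as nn

logits = torch.randn(4, 7)            # batch of 4, classes 0-6
targets = torch.tensor([6, 0, 3, 1])  # integer class indices, as the dataset now returns
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.item())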
2 changes: 1 addition & 1 deletion tests/test_metrics.py
@@ -26,7 +26,7 @@ def test_f1score():

target = torch.tensor([0, 1, 0, 2])

f1_metric.update(preds, target)
f1_metric(preds, target)
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
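Reviewer note: commits 54f3883 and 177258b add the micro/macro switch exercised by this test. Sketched from first principles (the repo's F1 class may differ in details), the two averaging modes computed from the tp/fp/fn counters the test asserts on:

import torch

def f1_from_counts(tp, fp, fn, macro_averaging=False):
    if macro_averaging:
        # Macro: per-class F1, then an unweighted mean over classes.
        per_class = 2 * tp / (2 * tp + fp + fn).clamp(min=1)
        return per_class.mean()
    # Micro: pool the counts over all classes, then a single global F1.
    return 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())

tp = torch.tensor([5.0, 1.0, 0.0])
fp = torch.tensor([1.0, 2.0, 1.0])
fn = torch.tensor([0.0, 1.0, 3.0])
print(f1_from_counts(tp, fp, fn))                        # micro: 0.60
print(f1_from_counts(tp, fp, fn, macro_averaging=True))  # macro: ~0.44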
1 change: 0 additions & 1 deletion tests/test_models.py
@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"

32 changes: 10 additions & 22 deletions utils/arg_parser.py
@@ -33,13 +33,6 @@ def get_args():
help="Whether model should be saved or not.",
)

parser.add_argument(
"--download-data",
type=bool,
default=False,
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
)

# Data/Model specific values
parser.add_argument(
"--modelname",
@@ -61,7 +54,12 @@
choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
help="Which dataset to train the model on.",
)

parser.add_argument(
"--val_size",
type=float,
default=0.2,
help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",
)
parser.add_argument(
"--metric",
type=str,
@@ -70,20 +68,10 @@ def get_args():
nargs="+",
help="Which metric to use for evaluation",
)

parser.add_argument(
'--imagesize',
type=int,
default=28,
help='Imagesize'
)

parser.add_argument(
'--nr_channels',
type=int,
default=1,
choices=[1,3],
help='Number of image channels'
"--macro_averaging",
action="store_true",
help="If the flag is included, the metrics will be calculated using macro averaging.",
)

# Training specific values
@@ -115,7 +103,7 @@
parser.add_argument(
"--dry_run",
action="store_true",
help="If true, the code will not run the training loop.",
help="If the flag is included, the code will not run the training loop.",
)
args = parser.parse_args()

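Reviewer note: a quick way to confirm the three touched options parse as intended. This is a standalone sketch that inlines a parser with the same flags; it does not call the repo's get_args (which reads sys.argv):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--val_size", type=float, default=0.2)
parser.add_argument("--macro_averaging", action="store_true")
parser.add_argument("--dry_run", action="store_true")

args = parser.parse_args(["--val_size", "0.1", "--macro_averaging"])
print(args.val_size, args.macro_averaging, args.dry_run)  # 0.1 True False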
11 changes: 9 additions & 2 deletions utils/dataloaders/__init__.py
@@ -1,6 +1,13 @@
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]
__all__ = [
"USPSDataset0_6",
"USPSH5_Digit_7_9_Dataset",
"MNISTDataset0_3",
"Downloader",
"SVHNDataset",
]

from .download import Downloader
from .mnist_0_3 import MNISTDataset0_3
from .svhn import SVHNDataset
from .usps_0_6 import USPSDataset0_6
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
from .svhn import SVHNDataset
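Reviewer note: with Downloader re-exported here, call sites can import it alongside the datasets from one place, e.g.:

from utils.dataloaders import Downloader, MNISTDataset0_3, SVHNDataset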
23 changes: 23 additions & 0 deletions utils/dataloaders/datasources.py
@@ -17,3 +17,26 @@
"8ea070ee2aca1ac39742fdd1ef5ed118",
],
}

MNIST_SOURCE = {
"train_images": [
"https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
"train-images-idx3-ubyte",
None,
],
"train_labels": [
"https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
"train-labels-idx1-ubyte",
None,
],
"test_images": [
"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
"t10k-images-idx3-ubyte",
None,
],
"test_labels": [
"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
"t10k-labels-idx1-ubyte",
None,
],
}
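Reviewer note: each MNIST_SOURCE entry appears to be [url, extracted filename, optional checksum], with None meaning no verification. The Downloader itself is not part of this diff, so the following is only a sketch of how such an entry might be consumed, with a urlretrieve progress hook in the spirit of commit a7d51c4; all names are illustrative:

import gzip
import shutil
import urllib.request
from pathlib import Path

def _progress(block_num, block_size, total_size):
    # Simple percentage progress for urlretrieve's reporthook.
    if total_size > 0:
        done = min(100, block_num * block_size * 100 // total_size)
        print(f"\rdownloading: {done}%", end="")

def fetch(entry, data_dir):
    url, fname, _md5 = entry  # _md5 is None for MNIST, i.e. no checksum verification
    out = Path(data_dir) / fname
    if out.exists():
        return out  # already downloaded and extracted
    archive = Path(data_dir) / (fname + ".gz")
    urllib.request.urlretrieve(url, archive, reporthook=_progress)
    with gzip.open(archive, "rb") as src, open(out, "wb") as dst:
        shutil.copyfileobj(src, dst)  # decompress the .gz into the raw idx file
    return out

# Usage sketch: fetch(MNIST_SOURCE["train_images"], "data/")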