Merged

56 commits
6757135
Added UV as package manager
Johanmkr Feb 6, 2025
9baa17e
load_data - changed to accommodate train/val/test split, added test loop
hzavadil98 Feb 7, 2025
f840af4
ruffed and isorted
hzavadil98 Feb 7, 2025
efc78f3
fix in mnist0-3
hzavadil98 Feb 7, 2025
54f3883
Added micro/macro averaging option to MetricsWrapper and as commandli…
hzavadil98 Feb 7, 2025
a2606e1
Make separate downloader class that handles everything related to dow…
c-salomonsen Feb 8, 2025
a58e495
downloader handles whether to download data or not, so remove option
c-salomonsen Feb 8, 2025
a9e2cad
Remove downloading logic from USPS dataset
c-salomonsen Feb 8, 2025
34539b3
`load_data` now splits the data, downloads data and returns all splits
c-salomonsen Feb 8, 2025
6c6f7b5
Made a whoopsie
c-salomonsen Feb 8, 2025
20faa24
Add the size thing
c-salomonsen Feb 8, 2025
d7526bf
Actually send the indices, not labels to datasets
c-salomonsen Feb 8, 2025
bd35ae6
Format
c-salomonsen Feb 8, 2025
0f32064
More formatting
c-salomonsen Feb 8, 2025
ad15940
Adjust test to comply with new functionality
c-salomonsen Feb 8, 2025
177258b
added micro/macro to F1
sot176 Feb 10, 2025
15c99ea
added MNIST downloader, adjusted minor things for the code to run
hzavadil98 Feb 10, 2025
601caca
ruffed, isorted
hzavadil98 Feb 10, 2025
f2e14c4
Merge pull request #60 from SFI-Visual-Intelligence/christian/train-v…
hzavadil98 Feb 10, 2025
ba2212e
Merge branch 'main' into Jan-dataloader
hzavadil98 Feb 10, 2025
b7bffa3
ruffisorted :'(
hzavadil98 Feb 10, 2025
4071181
Merge branch 'Jan-dataloader' into Jan-metrics
hzavadil98 Feb 10, 2025
27f120c
Merge pull request #63 from SFI-Visual-Intelligence/Jan-metrics
hzavadil98 Feb 10, 2025
5d8309b
preparing for overall test
hzavadil98 Feb 10, 2025
0d5fc20
Merge branch 'Jan-dataloader' of https://github.com/SFI-Visual-Intell…
hzavadil98 Feb 10, 2025
19a6ea1
hopefully fixed f1 test
hzavadil98 Feb 10, 2025
68f4736
Try to fix conda error: "undefined symbol: H5Pset_fapl_ros3" by setti…
c-salomonsen Feb 10, 2025
fbe9f4f
Found that mamba was unable to use installed python and h5py with the…
c-salomonsen Feb 10, 2025
a7d51c4
Add percentage downloaded progress thing to mnist (since the reportho…
c-salomonsen Feb 10, 2025
ff32432
added pyproject and mnist dataloader, although not finished yet
Johanmkr Feb 11, 2025
a412297
Merge commit 'refs/pull/54/head' of github.com:SFI-Visual-Intelligenc…
Johanmkr Feb 11, 2025
4c3dc32
ruffed
Johanmkr Feb 11, 2025
8d6c07a
Merge branch 'johan/dataloader' into johan/micromacro
Johanmkr Feb 11, 2025
9abaf0c
ruffed
Johanmkr Feb 11, 2025
4ad390d
isorted
Johanmkr Feb 11, 2025
a35e6ea
Removed numpy import
Johanmkr Feb 11, 2025
cdd5a4f
Merge branch 'johan/devbranch' into johan/micromacro
Johanmkr Feb 11, 2025
e10cf73
updated dataloader with micro/macro averaging
Johanmkr Feb 11, 2025
a4df0f2
updated precision test to fit new micro_averaging argument in dataloader
Johanmkr Feb 11, 2025
2a85e81
fixed bug in test-file
Johanmkr Feb 11, 2025
daf82d6
Changed order of true/false to match manually calculated precision vals
Johanmkr Feb 11, 2025
22df0a0
Updated dataloader to fit with MNIST 4-9
Johanmkr Feb 11, 2025
5d0d296
made dataloader parsable
Johanmkr Feb 11, 2025
64fac10
ruffedisorted
Johanmkr Feb 11, 2025
bf8a09f
Update recall metric with macro/micro averaging
c-salomonsen Feb 11, 2025
b9dc34e
Update tests for Recall metric
c-salomonsen Feb 11, 2025
2885a30
Took the liberty to change the F1 metric dimension to fit
c-salomonsen Feb 11, 2025
08aa876
Modified the MetricWrappers arguments being passed on
c-salomonsen Feb 11, 2025
bab6aee
Add test for metricwrapper and all metrics
c-salomonsen Feb 11, 2025
0c16ba1
updated johan_model to flatten the input in the forward loop
Johanmkr Feb 12, 2025
ba90d89
Updated precision metric with macro_averaging as argument
Johanmkr Feb 12, 2025
b9b7158
Updated precision metric and test function, need to discuss shape of …
Johanmkr Feb 13, 2025
7c7a80d
updated UV stuff
Johanmkr Feb 13, 2025
e7ba8a8
ruffedisorted
Johanmkr Feb 13, 2025
8922263
added sklearn to conda environment for github tests
Johanmkr Feb 13, 2025
97750d8
Fixed bug in JohanModel
Johanmkr Feb 13, 2025
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.12
5 changes: 4 additions & 1 deletion environment.yml
@@ -9,7 +9,8 @@ dependencies:
- sphinx-autobuild
- sphinx-rtd-theme
- pip
- h5py
- h5py==3.12.1
- hdf5==1.14.4
- black
- isort
- jupyterlab
@@ -20,6 +21,8 @@ dependencies:
- scalene
- tqdm
- scipy
- wandb
- scikit-learn
- pip:
- torch
- torchvision
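Note on the h5py/hdf5 pins above: they address the conda error from commit 68f4736 ("undefined symbol: H5Pset_fapl_ros3") by pinning h5py and the underlying HDF5 library to matching builds. A minimal sanity check, assuming a standard h5py install, to confirm the pinned versions are the ones actually loaded:

import h5py

# Both values should match the pins in environment.yml.
print(h5py.__version__)           # expected: 3.12.1
print(h5py.version.hdf5_version)  # expected: 1.14.4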
99 changes: 62 additions & 37 deletions main.py
@@ -7,6 +7,7 @@

import wandb
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
from wandb_api import WANDB_API


def main():
@@ -29,33 +30,38 @@ def main():

device = args.device

if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
augmentations = transforms.Compose(
if args.dataset.lower() in ["usps_0-6", "usps_7-9"]:
transform = transforms.Compose(
[
transforms.Resize((16, 16)),
transforms.ToTensor(),
]
)
else:
augmentations = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([transforms.ToTensor()])

# Dataset
traindata = load_data(
traindata, validata, testdata = load_data(
args.dataset,
train=True,
data_path=args.datafolder,
download=args.download_data,
transform=augmentations,
)
validata = load_data(
args.dataset,
train=False,
data_path=args.datafolder,
download=args.download_data,
transform=augmentations,
data_dir=args.datafolder,
transform=transform,
val_size=args.val_size,
)

metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
train_metrics = MetricWrapper(
*args.metric,
num_classes=traindata.num_classes,
macro_averaging=args.macro_averaging,
)
val_metrics = MetricWrapper(
*args.metric,
num_classes=traindata.num_classes,
macro_averaging=args.macro_averaging,
)
test_metrics = MetricWrapper(
*args.metric,
num_classes=traindata.num_classes,
macro_averaging=args.macro_averaging,
)

# Find the shape of the data; if it is 2D, add a channel dimension
data_shape = traindata[0][0].shape
@@ -80,6 +86,9 @@ def main():
valiloader = DataLoader(
validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
)
testloader = DataLoader(
testdata, batch_size=args.batchsize, shuffle=False, pin_memory=True
)

criterion = nn.CrossEntropyLoss()
optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -104,22 +113,22 @@ def main():
optimizer.step()
optimizer.zero_grad(set_to_none=True)

metrics(y, logits)
train_metrics(y, logits)

break
print(metrics.accumulate())
print(train_metrics.accumulate())
print("Dry run completed successfully.")
exit()

# wandb.login(key=WANDB_API)
wandb.init(
entity="ColabCode-org",
# entity="FYS-8805 Exam",
project="Test",
tags=[args.modelname, args.dataset]
)
entity="ColabCode",
# entity="FYS-8805 Exam",
project="Jan",
tags=[args.modelname, args.dataset],
)
wandb.watch(model)
exit()

for epoch in range(args.epoch):
# Training loop start
trainingloss = []
@@ -135,33 +144,49 @@ def main():
optimizer.zero_grad(set_to_none=True)
trainingloss.append(loss.item())

metrics(y, logits)

wandb.log(metrics.accumulate(str_prefix="Train "))
metrics.reset()
train_metrics(y, logits)

evalloss = []
# Eval loop start
valloss = []
# Validation loop start
model.eval()
with th.no_grad():
for x, y in tqdm(valiloader, desc="Validation"):
x, y = x.to(device), y.to(device)
logits = model.forward(x)
loss = criterion(logits, y)
evalloss.append(loss.item())

metrics(y, logits)
valloss.append(loss.item())

wandb.log(metrics.accumulate(str_prefix="Evaluation "))
metrics.reset()
val_metrics(y, logits)

wandb.log(
{
"Epoch": epoch,
"Train loss": np.mean(trainingloss),
"Evaluation Loss": np.mean(evalloss),
"Validation loss": np.mean(valloss),
}
| train_metrics.accumulate(str_prefix="Train ")
| val_metrics.accumulate(str_prefix="Validation ")
)
train_metrics.reset()
val_metrics.reset()

testloss = []
model.eval()
with th.no_grad():
for x, y in tqdm(testloader, desc="Testing"):
x, y = x.to(device), y.to(device)
logits = model.forward(x)
loss = criterion(logits, y)
testloss.append(loss.item())

preds = th.argmax(logits, dim=1)
test_metrics(y, preds)

wandb.log(
{"Epoch": 1, "Test loss": np.mean(testloss)}
| test_metrics.accumulate(str_prefix="Test ")
)
test_metrics.reset()


if __name__ == "__main__":
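Two details in the main.py diff are worth spelling out. First, load_data now returns the train, validation, and test splits in a single call, and a separate MetricWrapper is kept per split so that accumulate() and reset() never mix running state across splits. Second, each wandb.log call receives one dictionary built with the dict union operator | (Python 3.9+), merging the epoch losses with the accumulated metrics. A minimal sketch of that merge, with illustrative keys and values only:

epoch_log = {"Epoch": 0, "Train loss": 0.41, "Validation loss": 0.47}
# Dict union keeps all entries; right-hand operands win on key collisions.
epoch_log = epoch_log | {"Train accuracy": 0.88} | {"Validation accuracy": 0.85}
# {'Epoch': 0, 'Train loss': 0.41, 'Validation loss': 0.47,
#  'Train accuracy': 0.88, 'Validation accuracy': 0.85}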
26 changes: 26 additions & 0 deletions pyproject.toml
@@ -1,3 +1,29 @@
[project]
name = "collaborative-coding-exam"
version = "0.1.0"
description = "Exam project in the collaborative coding course."
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"black>=25.1.0",
"h5py>=3.12.1",
"isort>=6.0.0",
"jupyterlab>=4.3.5",
"numpy>=2.2.2",
"pandas>=2.2.3",
"pip>=25.0",
"pytest>=8.3.4",
"ruff>=0.9.4",
"scalene>=1.5.51",
"scikit-learn>=1.6.1",
"sphinx>=8.1.3",
"sphinx-autoapi>=3.4.0",
"sphinx-autobuild>=2024.10.3",
"sphinx-rtd-theme>=3.0.2",
"torch>=2.6.0",
"torchvision>=0.21.0",
"tqdm>=4.67.1",
]
[tool.isort]
profile = "black"
line_length = 88
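The [tool.isort] table pins isort to Black-compatible settings, which is what the recurring "ruffed and isorted" commits enforce. A small before/after illustration of the sorting behavior, assuming the repository defaults:

# Before sorting:
import sys
import numpy as np
import os

# After isort with profile="black": stdlib imports first, then third-party,
# alphabetized within groups and separated by a blank line.
import os
import sys

import numpy as np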
15 changes: 11 additions & 4 deletions tests/test_dataloaders.py
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():

# Create a h5 file
with h5py.File(tf, "w") as f:
targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
indices = np.arange(len(targets))
# Populate the file with data
f["train/data"] = np.random.rand(10, 16 * 16)
f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
f["train/target"] = targets

trans = transforms.Compose(
[
transforms.Resize((16, 16)), # At least for USPS
transforms.Resize((16, 16)),
transforms.ToTensor(),
]
)
dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
dataset = USPSDataset0_6(
data_path=tempdir,
sample_ids=indices,
train=True,
transform=trans,
)
assert len(dataset) == 10
data, target = dataset[0]
assert data.shape == (1, 16, 16)
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
assert target == 6
Loading