Commit c23442f

Author: Fangchang Ma
fixed bugs with resuming and evaluation
1 parent 8d750e2 · commit c23442f

3 files changed: +83, -77 lines


README.md

Lines changed: 7 additions & 8 deletions
@@ -14,9 +14,6 @@ This repo can be used for training and testing of
 
 The original Torch implementation of the paper can be found [here](https://github.com/fangchangma/sparse-to-dense).
 
-## Thanks
-Thanks to [Tim](https://github.com/timethy) and [Akari](https://github.com/AkariAsai) for their contributions.
-
 ## Contents
 0. [Requirements](#requirements)
 0. [Training](#training)
@@ -54,13 +51,15 @@ For instance, run the following command to train a network with ResNet50 as the
 python3 main.py -a resnet50 -d deconv3 -m rgbd -s 100 -data nyudepthv2
 ```
 
-Training results will be saved under the `results` folder.
-
+Training results will be saved under the `results` folder. To resume a previous training, run
+```bash
+python3 main.py --resume [path_to_previous_model]
+```
 
 ## Testing
-To test the performance of a trained model, simply run main.py with the `-e` option, along with other model options. For instance,
+To test the performance of a trained model without training, simply run main.py with the `-e` option. For instance,
 ```bash
-python3 main.py -e -a resnet50 -d deconv3 -m rgbd -s 100 -data nyudepthv2
+python3 main.py --evaluate [path_to_trained_model]
 ```
 
 ## Trained Models
@@ -124,4 +123,4 @@ If you use our code or method in your work, please consider citing the following
 year={2018}
 }
 
-Please direct any questions to [Fangchang Ma](http://www.mit.edu/~fcma) at fcma@mit.edu.
+Please create a new issue for code-related questions. Pull requests are welcome.

main.py

Lines changed: 75 additions & 66 deletions
@@ -26,86 +26,32 @@
 def main():
     global args, best_result, output_directory, train_csv, test_csv
 
-    # create results folder, if not already exists
-    output_directory = utils.get_output_directory(args)
-    if not os.path.exists(output_directory):
-        os.makedirs(output_directory)
-    train_csv = os.path.join(output_directory, 'train.csv')
-    test_csv = os.path.join(output_directory, 'test.csv')
-    best_txt = os.path.join(output_directory, 'best.txt')
-
-    # define loss function (criterion) and optimizer
-    if args.criterion == 'l2':
-        criterion = criteria.MaskedMSELoss().cuda()
-    elif args.criterion == 'l1':
-        criterion = criteria.MaskedL1Loss().cuda()
-
-    # sparsifier is a class for generating random sparse depth input from the ground truth
-    sparsifier = None
-    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
-    if args.sparsifier == UniformSampling.name:
-        sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
-    elif args.sparsifier == SimulatedStereo.name:
-        sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
-
-    # Data loading code
-    print("=> creating data loaders ...")
-    traindir = os.path.join('data', args.data, 'train')
-    valdir = os.path.join('data', args.data, 'val')
-
-    if args.data == 'nyudepthv2':
-        from dataloaders.nyu_dataloader import NYUDataset
-        train_dataset = NYUDataset(traindir, type='train',
-            modality=args.modality, sparsifier=sparsifier)
-        val_dataset = NYUDataset(valdir, type='val',
-            modality=args.modality, sparsifier=sparsifier)
-
-    elif args.data == 'kitti':
-        from dataloaders.kitti_dataloader import KITTIDataset
-        train_dataset = KITTIDataset(traindir, type='train',
-            modality=args.modality, sparsifier=sparsifier)
-        val_dataset = KITTIDataset(valdir, type='val',
-            modality=args.modality, sparsifier=sparsifier)
-
-    else:
-        raise RuntimeError('Dataset not found.' +
-                           'The dataset must be either of nyudepthv2 or kitti.')
-
-    train_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.batch_size, shuffle=True,
-        num_workers=args.workers, pin_memory=True, sampler=None,
-        worker_init_fn=lambda work_id:np.random.seed(work_id))
-        # worker_init_fn ensures different sampling patterns for each data loading thread
-
-    # set batch size to be 1 for validation
-    val_loader = torch.utils.data.DataLoader(val_dataset,
-        batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
-    print("=> data loaders created.")
-
     # evaluation mode
+    start_epoch = 0
     if args.evaluate:
-        best_model_filename = os.path.join(output_directory, 'model_best.pth.tar')
-        assert os.path.isfile(best_model_filename), \
-            "=> no best model found at '{}'".format(best_model_filename)
-        print("=> loading best model '{}'".format(best_model_filename))
-        checkpoint = torch.load(best_model_filename)
-        args.start_epoch = checkpoint['epoch']
+        assert os.path.isfile(args.evaluate), \
+            "=> no best model found at '{}'".format(args.evaluate)
+        print("=> loading best model '{}'".format(args.evaluate))
+        checkpoint = torch.load(args.evaluate)
+        args = checkpoint['args']
+        args.evaluate = True
+        start_epoch = checkpoint['epoch'] + 1
         best_result = checkpoint['best_result']
         model = checkpoint['model']
         print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
-        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
-        return
 
     # optionally resume from a checkpoint
     elif args.resume:
         assert os.path.isfile(args.resume), \
             "=> no checkpoint found at '{}'".format(args.resume)
         print("=> loading checkpoint '{}'".format(args.resume))
         checkpoint = torch.load(args.resume)
-        args.start_epoch = checkpoint['epoch']+1
+        args = checkpoint['args']
+        start_epoch = checkpoint['epoch'] + 1
         best_result = checkpoint['best_result']
         model = checkpoint['model']
         optimizer = checkpoint['optimizer']
+        output_directory, _ = os.path.split(args.resume)
         print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
 
     # create new model
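In both branches above, `start_epoch`, `best_result`, the model, and (for resume) the optimizer are read straight out of the checkpoint, and the saved `args` replace the command-line ones. The saving side is not part of this commit, so the sketch below only illustrates, under that assumption, the five keys the loading code depends on; the file name and the toy model/optimizer are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins; in the real code these are the live training state.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The keys are exactly what main.py reads back after this commit:
# checkpoint['args'], ['epoch'], ['best_result'], ['model'], ['optimizer'].
checkpoint = {
    'args': None,         # argparse Namespace in the real code, restored wholesale
    'epoch': 7,           # training resumes at checkpoint['epoch'] + 1
    'best_result': None,  # best validation metrics seen so far
    'model': model,       # the whole nn.Module is pickled, not just a state_dict
    'optimizer': optimizer,
}
torch.save(checkpoint, 'checkpoint.pth.tar')

# (On PyTorch >= 2.6 pass weights_only=False, since whole modules are pickled.)
resumed = torch.load('checkpoint.pth.tar')
print(resumed['epoch'] + 1)  # start_epoch after resuming: 8
```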
@@ -137,7 +83,70 @@ def main():
     # print(model)
     print("=> model transferred to GPU.")
 
-    for epoch in range(args.start_epoch, args.epochs):
+    # define loss function (criterion) and optimizer
+    if args.criterion == 'l2':
+        criterion = criteria.MaskedMSELoss().cuda()
+    elif args.criterion == 'l1':
+        criterion = criteria.MaskedL1Loss().cuda()
+
+    # sparsifier is a class for generating random sparse depth input from the ground truth
+    sparsifier = None
+    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
+    if args.sparsifier == UniformSampling.name:
+        sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
+    elif args.sparsifier == SimulatedStereo.name:
+        sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
+
+    # Data loading code
+    print("=> creating data loaders ...")
+    traindir = os.path.join('data', args.data, 'train')
+    valdir = os.path.join('data', args.data, 'val')
+
+    if args.data == 'nyudepthv2':
+        from dataloaders.nyu_dataloader import NYUDataset
+        if not args.evaluate:
+            train_dataset = NYUDataset(traindir, type='train',
+                modality=args.modality, sparsifier=sparsifier)
+        val_dataset = NYUDataset(valdir, type='val',
+            modality=args.modality, sparsifier=sparsifier)
+
+    elif args.data == 'kitti':
+        from dataloaders.kitti_dataloader import KITTIDataset
+        if not args.evaluate:
+            train_dataset = KITTIDataset(traindir, type='train',
+                modality=args.modality, sparsifier=sparsifier)
+        val_dataset = KITTIDataset(valdir, type='val',
+            modality=args.modality, sparsifier=sparsifier)
+
+    else:
+        raise RuntimeError('Dataset not found.' +
+                           'The dataset must be either of nyudepthv2 or kitti.')
+
+    # set batch size to be 1 for validation
+    val_loader = torch.utils.data.DataLoader(val_dataset,
+        batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
+    print("=> data loaders created.")
+
+    if args.evaluate:
+        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
+        return
+
+    # put construction of train loader here, for those who are interested in testing only
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=True,
+        num_workers=args.workers, pin_memory=True, sampler=None,
+        worker_init_fn=lambda work_id:np.random.seed(work_id))
+        # worker_init_fn ensures different sampling patterns for each data loading thread
+
+    # create results folder, if not already exists
+    output_directory = utils.get_output_directory(args)
+    if not os.path.exists(output_directory):
+        os.makedirs(output_directory)
+    train_csv = os.path.join(output_directory, 'train.csv')
+    test_csv = os.path.join(output_directory, 'test.csv')
+    best_txt = os.path.join(output_directory, 'best.txt')
+
+    for epoch in range(start_epoch, args.epochs):
         utils.adjust_learning_rate(optimizer, epoch, args.lr)
         train(train_loader, model, criterion, optimizer, epoch)  # train for one epoch
         result, img_merge = validate(val_loader, model, epoch)  # evaluate on validation set
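One detail worth noting in the relocated train-loader construction: the sparsifiers draw from numpy's global RNG, and forked DataLoader workers inherit identical numpy state, so without the `worker_init_fn` reseeding every worker would produce the same "random" sparse pattern. A minimal, self-contained sketch of that mechanism follows; `ToyDataset` is a hypothetical stand-in for the repo's datasets, and a named `reseed` function replaces the lambda from the diff.

```python
import numpy as np
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical stand-in: __getitem__ uses numpy's global RNG,
    just as UniformSampling/SimulatedStereo do."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Stands in for drawing a random sparse-depth sample.
        return np.random.randint(0, 1000)

def reseed(worker_id):
    # Same effect as the lambda in the diff: give each worker its own seed.
    np.random.seed(worker_id)

if __name__ == '__main__':
    # Without worker_init_fn, forked workers start from the same inherited
    # numpy state and draw identical values; per-worker reseeding makes
    # their sampling patterns differ.
    loader = DataLoader(ToyDataset(), batch_size=1, num_workers=2,
                        worker_init_fn=reseed)
    for batch in loader:
        print(batch)
```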

utils.py

Lines changed: 1 addition & 3 deletions
@@ -39,8 +39,6 @@ def parse_command():
                         help='number of data loading workers (default: 10)')
     parser.add_argument('--epochs', default=15, type=int, metavar='N',
                         help='number of total epochs to run (default: 15)')
-    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
-                        help='manual epoch number (useful on restarts)')
     parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1', choices=loss_names,
                         help='loss function: ' + ' | '.join(loss_names) + ' (default: l1)')
     parser.add_argument('-b', '--batch-size', default=8, type=int, help='mini-batch size (default: 8)')
@@ -54,7 +52,7 @@ def parse_command():
                         metavar='N', help='print frequency (default: 10)')
     parser.add_argument('--resume', default='', type=str, metavar='PATH',
                         help='path to latest checkpoint (default: none)')
-    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+    parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, default='',
                         help='evaluate model on validation set')
     parser.add_argument('--no-pretrain', dest='pretrained', action='store_false',
                         help='not to use ImageNet pre-trained weights')
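The behavioral change here: `-e/--evaluate` switches from a boolean flag to a string option that carries the checkpoint path, and `--start-epoch` disappears because the starting epoch now comes from the checkpoint itself. Since the new default is the empty string, truthiness checks like `if args.evaluate:` in main.py still work as a mode switch. A small isolated sketch of just this parsing behavior (the example path is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
# The two options as defined after this commit.
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, default='',
                    help='evaluate model on validation set')

# '-e' now consumes a path argument instead of acting as a boolean flag.
args = parser.parse_args(['-e', 'results/model_best.pth.tar'])  # hypothetical path
print(bool(args.evaluate))  # True -> evaluation mode
print(args.evaluate)        # the checkpoint path itself

# With no flags, the empty-string defaults are falsy, so main.py
# falls through to creating a new model and training from scratch.
args = parser.parse_args([])
print(bool(args.evaluate), bool(args.resume))  # False False
```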
