|
26 | 26 | def main(): |
27 | 27 | global args, best_result, output_directory, train_csv, test_csv |
28 | 28 |
|
29 | | - # create results folder if it does not already exist
30 | | - output_directory = utils.get_output_directory(args) |
31 | | - if not os.path.exists(output_directory): |
32 | | - os.makedirs(output_directory) |
33 | | - train_csv = os.path.join(output_directory, 'train.csv') |
34 | | - test_csv = os.path.join(output_directory, 'test.csv') |
35 | | - best_txt = os.path.join(output_directory, 'best.txt') |
36 | | - |
37 | | - # define loss function (criterion) and optimizer |
38 | | - if args.criterion == 'l2': |
39 | | - criterion = criteria.MaskedMSELoss().cuda() |
40 | | - elif args.criterion == 'l1': |
41 | | - criterion = criteria.MaskedL1Loss().cuda() |
42 | | - |
43 | | - # sparsifier is a class for generating random sparse depth input from the ground truth |
44 | | - sparsifier = None |
45 | | - max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf |
46 | | - if args.sparsifier == UniformSampling.name: |
47 | | - sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth) |
48 | | - elif args.sparsifier == SimulatedStereo.name: |
49 | | - sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth) |
50 | | - |
51 | | - # Data loading code |
52 | | - print("=> creating data loaders ...") |
53 | | - traindir = os.path.join('data', args.data, 'train') |
54 | | - valdir = os.path.join('data', args.data, 'val') |
55 | | - |
56 | | - if args.data == 'nyudepthv2': |
57 | | - from dataloaders.nyu_dataloader import NYUDataset |
58 | | - train_dataset = NYUDataset(traindir, type='train', |
59 | | - modality=args.modality, sparsifier=sparsifier) |
60 | | - val_dataset = NYUDataset(valdir, type='val', |
61 | | - modality=args.modality, sparsifier=sparsifier) |
62 | | - |
63 | | - elif args.data == 'kitti': |
64 | | - from dataloaders.kitti_dataloader import KITTIDataset |
65 | | - train_dataset = KITTIDataset(traindir, type='train', |
66 | | - modality=args.modality, sparsifier=sparsifier) |
67 | | - val_dataset = KITTIDataset(valdir, type='val', |
68 | | - modality=args.modality, sparsifier=sparsifier) |
69 | | - |
70 | | - else: |
71 | | - raise RuntimeError('Dataset not found. ' +
72 | | - 'The dataset must be either nyudepthv2 or kitti.')
73 | | - |
74 | | - train_loader = torch.utils.data.DataLoader( |
75 | | - train_dataset, batch_size=args.batch_size, shuffle=True, |
76 | | - num_workers=args.workers, pin_memory=True, sampler=None, |
77 | | - worker_init_fn=lambda work_id:np.random.seed(work_id)) |
78 | | - # worker_init_fn ensures different sampling patterns for each data loading thread |
79 | | - |
80 | | - # set batch size to be 1 for validation |
81 | | - val_loader = torch.utils.data.DataLoader(val_dataset, |
82 | | - batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True) |
83 | | - print("=> data loaders created.") |
84 | | - |
85 | 29 | # evaluation mode |
| 30 | + start_epoch = 0 |
86 | 31 | if args.evaluate: |
87 | | - best_model_filename = os.path.join(output_directory, 'model_best.pth.tar') |
88 | | - assert os.path.isfile(best_model_filename), \ |
89 | | - "=> no best model found at '{}'".format(best_model_filename) |
90 | | - print("=> loading best model '{}'".format(best_model_filename)) |
91 | | - checkpoint = torch.load(best_model_filename) |
92 | | - args.start_epoch = checkpoint['epoch'] |
| 32 | + assert os.path.isfile(args.evaluate), \ |
| 33 | + "=> no best model found at '{}'".format(args.evaluate) |
| 34 | + print("=> loading best model '{}'".format(args.evaluate)) |
| 35 | + checkpoint = torch.load(args.evaluate) |
| 36 | + args = checkpoint['args'] |
| 37 | + args.evaluate = True |
| 38 | + start_epoch = checkpoint['epoch'] + 1 |
93 | 39 | best_result = checkpoint['best_result'] |
94 | 40 | model = checkpoint['model'] |
95 | 41 | print("=> loaded best model (epoch {})".format(checkpoint['epoch'])) |
96 | | - validate(val_loader, model, checkpoint['epoch'], write_to_file=False) |
97 | | - return |
98 | 42 |
|
99 | 43 | # optionally resume from a checkpoint |
100 | 44 | elif args.resume: |
101 | 45 | assert os.path.isfile(args.resume), \ |
102 | 46 | "=> no checkpoint found at '{}'".format(args.resume) |
103 | 47 | print("=> loading checkpoint '{}'".format(args.resume)) |
104 | 48 | checkpoint = torch.load(args.resume) |
105 | | - args.start_epoch = checkpoint['epoch']+1 |
| 49 | + args = checkpoint['args'] |
| 50 | + start_epoch = checkpoint['epoch'] + 1 |
106 | 51 | best_result = checkpoint['best_result'] |
107 | 52 | model = checkpoint['model'] |
108 | 53 | optimizer = checkpoint['optimizer'] |
| 54 | + output_directory, _ = os.path.split(args.resume) |
109 | 55 | print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch'])) |
110 | 56 |
|
111 | 57 | # create new model |
@@ -137,7 +83,70 @@ def main(): |
137 | 83 | # print(model) |
138 | 84 | print("=> model transferred to GPU.") |
139 | 85 |
|
140 | | - for epoch in range(args.start_epoch, args.epochs): |
| 86 | + # define loss function (criterion) and optimizer |
| 87 | + if args.criterion == 'l2': |
| 88 | + criterion = criteria.MaskedMSELoss().cuda() |
| 89 | + elif args.criterion == 'l1': |
| 90 | + criterion = criteria.MaskedL1Loss().cuda() |
| 91 | + |
| 92 | + # sparsifier is a class for generating random sparse depth input from the ground truth |
| 93 | + sparsifier = None |
| 94 | + max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf |
| 95 | + if args.sparsifier == UniformSampling.name: |
| 96 | + sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth) |
| 97 | + elif args.sparsifier == SimulatedStereo.name: |
| 98 | + sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth) |
| 99 | + |
| 100 | + # Data loading code |
| 101 | + print("=> creating data loaders ...") |
| 102 | + traindir = os.path.join('data', args.data, 'train') |
| 103 | + valdir = os.path.join('data', args.data, 'val') |
| 104 | + |
| 105 | + if args.data == 'nyudepthv2': |
| 106 | + from dataloaders.nyu_dataloader import NYUDataset |
| 107 | + if not args.evaluate: |
| 108 | + train_dataset = NYUDataset(traindir, type='train', |
| 109 | + modality=args.modality, sparsifier=sparsifier) |
| 110 | + val_dataset = NYUDataset(valdir, type='val', |
| 111 | + modality=args.modality, sparsifier=sparsifier) |
| 112 | + |
| 113 | + elif args.data == 'kitti': |
| 114 | + from dataloaders.kitti_dataloader import KITTIDataset |
| 115 | + if not args.evaluate: |
| 116 | + train_dataset = KITTIDataset(traindir, type='train', |
| 117 | + modality=args.modality, sparsifier=sparsifier) |
| 118 | + val_dataset = KITTIDataset(valdir, type='val', |
| 119 | + modality=args.modality, sparsifier=sparsifier) |
| 120 | + |
| 121 | + else: |
| 122 | + raise RuntimeError('Dataset not found. ' +
| 123 | + 'The dataset must be either nyudepthv2 or kitti.')
| 124 | + |
| 125 | + # set batch size to be 1 for validation |
| 126 | + val_loader = torch.utils.data.DataLoader(val_dataset, |
| 127 | + batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True) |
| 128 | + print("=> data loaders created.") |
| 129 | + |
| 130 | + if args.evaluate: |
| 131 | + validate(val_loader, model, checkpoint['epoch'], write_to_file=False) |
| 132 | + return |
| 133 | + |
| 134 | + # construct the train loader here, after the evaluate branch, so evaluation-only runs skip it
| 135 | + train_loader = torch.utils.data.DataLoader( |
| 136 | + train_dataset, batch_size=args.batch_size, shuffle=True, |
| 137 | + num_workers=args.workers, pin_memory=True, sampler=None, |
| 138 | + worker_init_fn=lambda work_id:np.random.seed(work_id)) |
| 139 | + # worker_init_fn ensures different sampling patterns for each data loading thread |
| 140 | + |
| 141 | + # create results folder if it does not already exist
| 142 | + output_directory = utils.get_output_directory(args) |
| 143 | + if not os.path.exists(output_directory): |
| 144 | + os.makedirs(output_directory) |
| 145 | + train_csv = os.path.join(output_directory, 'train.csv') |
| 146 | + test_csv = os.path.join(output_directory, 'test.csv') |
| 147 | + best_txt = os.path.join(output_directory, 'best.txt') |
| 148 | + |
| 149 | + for epoch in range(start_epoch, args.epochs): |
141 | 150 | utils.adjust_learning_rate(optimizer, epoch, args.lr) |
142 | 151 | train(train_loader, model, criterion, optimizer, epoch) # train for one epoch |
143 | 152 | result, img_merge = validate(val_loader, model, epoch) # evaluate on validation set |
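
A minimal sketch (not part of the diff above) of the checkpoint round trip that the resume and evaluate branches rely on. The key names 'epoch', 'args', 'best_result', 'model', and 'optimizer' are the ones main() reads back; the helper names below are illustrative assumptions, not taken from the repository.

import torch

def save_checkpoint_sketch(state, path):
    # state is assumed to be a dict of the form read back in main(), e.g.
    # {'epoch': epoch, 'args': args, 'best_result': result,
    #  'model': model, 'optimizer': optimizer}
    torch.save(state, path)

def load_checkpoint_sketch(path):
    checkpoint = torch.load(path)
    start_epoch = checkpoint['epoch'] + 1  # continue at the next epoch
    return checkpoint, start_epoch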
|