2323best_result = Result ()
2424best_result .set_to_worst ()
2525
26+ def create_data_loaders (args ):
27+ # Data loading code
28+ print ("=> creating data loaders ..." )
29+ traindir = os .path .join ('data' , args .data , 'train' )
30+ valdir = os .path .join ('data' , args .data , 'val' )
31+ train_loader = None
32+ val_loader = None
33+
34+ # sparsifier is a class for generating random sparse depth input from the ground truth
35+ sparsifier = None
36+ max_depth = args .max_depth if args .max_depth >= 0.0 else np .inf
37+ if args .sparsifier == UniformSampling .name :
38+ sparsifier = UniformSampling (num_samples = args .num_samples , max_depth = max_depth )
39+ elif args .sparsifier == SimulatedStereo .name :
40+ sparsifier = SimulatedStereo (num_samples = args .num_samples , max_depth = max_depth )
41+
42+ if args .data == 'nyudepthv2' :
43+ from dataloaders .nyu_dataloader import NYUDataset
44+ if not args .evaluate :
45+ train_dataset = NYUDataset (traindir , type = 'train' ,
46+ modality = args .modality , sparsifier = sparsifier )
47+ val_dataset = NYUDataset (valdir , type = 'val' ,
48+ modality = args .modality , sparsifier = sparsifier )
49+
50+ elif args .data == 'kitti' :
51+ from dataloaders .kitti_dataloader import KITTIDataset
52+ if not args .evaluate :
53+ train_dataset = KITTIDataset (traindir , type = 'train' ,
54+ modality = args .modality , sparsifier = sparsifier )
55+ val_dataset = KITTIDataset (valdir , type = 'val' ,
56+ modality = args .modality , sparsifier = sparsifier )
57+
58+ else :
59+ raise RuntimeError ('Dataset not found.' +
60+ 'The dataset must be either of nyudepthv2 or kitti.' )
61+
62+ # set batch size to be 1 for validation
63+ val_loader = torch .utils .data .DataLoader (val_dataset ,
64+ batch_size = 1 , shuffle = False , num_workers = args .workers , pin_memory = True )
65+
66+ # put construction of train loader here, for those who are interested in testing only
67+ if not args .evaluate :
68+ train_loader = torch .utils .data .DataLoader (
69+ train_dataset , batch_size = args .batch_size , shuffle = True ,
70+ num_workers = args .workers , pin_memory = True , sampler = None ,
71+ worker_init_fn = lambda work_id :np .random .seed (work_id ))
72+ # worker_init_fn ensures different sampling patterns for each data loading thread
73+
74+ print ("=> data loaders created." )
75+ return train_loader , val_loader
76+
2677def main ():
2778 global args , best_result , output_directory , train_csv , test_csv
2879
@@ -33,12 +84,16 @@ def main():
3384 "=> no best model found at '{}'" .format (args .evaluate )
3485 print ("=> loading best model '{}'" .format (args .evaluate ))
3586 checkpoint = torch .load (args .evaluate )
87+ output_directory = os .path .dirname (args .evaluate )
3688 args = checkpoint ['args' ]
37- args .evaluate = True
3889 start_epoch = checkpoint ['epoch' ] + 1
3990 best_result = checkpoint ['best_result' ]
4091 model = checkpoint ['model' ]
4192 print ("=> loaded best model (epoch {})" .format (checkpoint ['epoch' ]))
93+ _ , val_loader = create_data_loaders (args )
94+ args .evaluate = True
95+ validate (val_loader , model , checkpoint ['epoch' ], write_to_file = False )
96+ return
4297
4398 # optionally resume from a checkpoint
4499 elif args .resume :
@@ -51,93 +106,35 @@ def main():
51106 best_result = checkpoint ['best_result' ]
52107 model = checkpoint ['model' ]
53108 optimizer = checkpoint ['optimizer' ]
54- output_directory , _ = os .path .split ( args .resume )
109+ output_directory = os .path .dirname ( os . path . abspath ( args .resume ) )
55110 print ("=> loaded checkpoint (epoch {})" .format (checkpoint ['epoch' ]))
111+ train_loader , val_loader = create_data_loaders (args )
112+ args .resume = True
56113
57114 # create new model
58115 else :
59- # define model
116+ train_loader , val_loader = create_data_loaders ( args )
60117 print ("=> creating Model ({}-{}) ..." .format (args .arch , args .decoder ))
61118 in_channels = len (args .modality )
62119 if args .arch == 'resnet50' :
63- model = ResNet (layers = 50 , decoder = args .decoder , output_size = train_dataset .output_size ,
120+ model = ResNet (layers = 50 , decoder = args .decoder , output_size = train_loader . dataset .output_size ,
64121 in_channels = in_channels , pretrained = args .pretrained )
65122 elif args .arch == 'resnet18' :
66- model = ResNet (layers = 18 , decoder = args .decoder , output_size = train_dataset .output_size ,
123+ model = ResNet (layers = 18 , decoder = args .decoder , output_size = train_loader . dataset .output_size ,
67124 in_channels = in_channels , pretrained = args .pretrained )
68125 print ("=> model created." )
69-
70126 optimizer = torch .optim .SGD (model .parameters (), args .lr , \
71127 momentum = args .momentum , weight_decay = args .weight_decay )
72128
73- # create new csv files with only header
74- with open (train_csv , 'w' ) as csvfile :
75- writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
76- writer .writeheader ()
77- with open (test_csv , 'w' ) as csvfile :
78- writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
79- writer .writeheader ()
80-
81- # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
82- model = model .cuda ()
83- # print(model)
84- print ("=> model transferred to GPU." )
129+ # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
130+ model = model .cuda ()
85131
86132 # define loss function (criterion) and optimizer
87133 if args .criterion == 'l2' :
88134 criterion = criteria .MaskedMSELoss ().cuda ()
89135 elif args .criterion == 'l1' :
90136 criterion = criteria .MaskedL1Loss ().cuda ()
91137
92- # sparsifier is a class for generating random sparse depth input from the ground truth
93- sparsifier = None
94- max_depth = args .max_depth if args .max_depth >= 0.0 else np .inf
95- if args .sparsifier == UniformSampling .name :
96- sparsifier = UniformSampling (num_samples = args .num_samples , max_depth = max_depth )
97- elif args .sparsifier == SimulatedStereo .name :
98- sparsifier = SimulatedStereo (num_samples = args .num_samples , max_depth = max_depth )
99-
100- # Data loading code
101- print ("=> creating data loaders ..." )
102- traindir = os .path .join ('data' , args .data , 'train' )
103- valdir = os .path .join ('data' , args .data , 'val' )
104-
105- if args .data == 'nyudepthv2' :
106- from dataloaders .nyu_dataloader import NYUDataset
107- if not args .evaluate :
108- train_dataset = NYUDataset (traindir , type = 'train' ,
109- modality = args .modality , sparsifier = sparsifier )
110- val_dataset = NYUDataset (valdir , type = 'val' ,
111- modality = args .modality , sparsifier = sparsifier )
112-
113- elif args .data == 'kitti' :
114- from dataloaders .kitti_dataloader import KITTIDataset
115- if not args .evaluate :
116- train_dataset = KITTIDataset (traindir , type = 'train' ,
117- modality = args .modality , sparsifier = sparsifier )
118- val_dataset = KITTIDataset (valdir , type = 'val' ,
119- modality = args .modality , sparsifier = sparsifier )
120-
121- else :
122- raise RuntimeError ('Dataset not found.' +
123- 'The dataset must be either of nyudepthv2 or kitti.' )
124-
125- # set batch size to be 1 for validation
126- val_loader = torch .utils .data .DataLoader (val_dataset ,
127- batch_size = 1 , shuffle = False , num_workers = args .workers , pin_memory = True )
128- print ("=> data loaders created." )
129-
130- if args .evaluate :
131- validate (val_loader , model , checkpoint ['epoch' ], write_to_file = False )
132- return
133-
134- # put construction of train loader here, for those who are interested in testing only
135- train_loader = torch .utils .data .DataLoader (
136- train_dataset , batch_size = args .batch_size , shuffle = True ,
137- num_workers = args .workers , pin_memory = True , sampler = None ,
138- worker_init_fn = lambda work_id :np .random .seed (work_id ))
139- # worker_init_fn ensures different sampling patterns for each data loading thread
140-
141138 # create results folder, if not already exists
142139 output_directory = utils .get_output_directory (args )
143140 if not os .path .exists (output_directory ):
@@ -146,6 +143,15 @@ def main():
146143 test_csv = os .path .join (output_directory , 'test.csv' )
147144 best_txt = os .path .join (output_directory , 'best.txt' )
148145
146+ # create new csv files with only header
147+ if not args .resume :
148+ with open (train_csv , 'w' ) as csvfile :
149+ writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
150+ writer .writeheader ()
151+ with open (test_csv , 'w' ) as csvfile :
152+ writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
153+ writer .writeheader ()
154+
149155 for epoch in range (start_epoch , args .epochs ):
150156 utils .adjust_learning_rate (optimizer , epoch , args .lr )
151157 train (train_loader , model , criterion , optimizer , epoch ) # train for one epoch
0 commit comments