@@ -31,8 +31,8 @@ def plot2DScatter(data, outputfile):
3131 '''
3232 x = data [:, 0 ]
3333 y = data [:, 1 ]
34- print "The mean vector is %s" % numpy .mean (data , 0 )
35- print "The std vector is %s" % numpy .std (data , 0 )
34+ logger . info ( "The mean vector is %s" % numpy .mean (data , 0 ) )
35+ logger . info ( "The std vector is %s" % numpy .std (data , 0 ) )
3636
3737 heatmap , xedges , yedges = numpy .histogram2d (x , y , bins = 50 )
3838 extent = [xedges [0 ], xedges [- 1 ], yedges [0 ], yedges [- 1 ]]
@@ -192,42 +192,42 @@ def get_layer_size(model_conf, layer_name):
192192
193193def main ():
194194 parser = argparse .ArgumentParser ()
195- parser .add_argument ("-d" , "--dataSource " , help = "mnist or cifar or uniform" )
196- parser .add_argument ("--useGpu " , default = "1" ,
195+ parser .add_argument ("-d" , "--data_source " , help = "mnist or cifar or uniform" )
196+ parser .add_argument ("--use_gpu " , default = "1" ,
197197 help = "1 means use gpu for training" )
198- parser .add_argument ("--gpuId " , default = "0" ,
198+ parser .add_argument ("--gpu_id " , default = "0" ,
199199 help = "the gpu_id parameter" )
200200 args = parser .parse_args ()
201- dataSource = args .dataSource
202- useGpu = args .useGpu
203- assert dataSource in ["mnist" , "cifar" , "uniform" ]
204- assert useGpu in ["0" , "1" ]
201+ data_source = args .data_source
202+ use_gpu = args .use_gpu
203+ assert data_source in ["mnist" , "cifar" , "uniform" ]
204+ assert use_gpu in ["0" , "1" ]
205205
206- if not os .path .exists ("./%s_samples/" % dataSource ):
207- os .makedirs ("./%s_samples/" % dataSource )
206+ if not os .path .exists ("./%s_samples/" % data_source ):
207+ os .makedirs ("./%s_samples/" % data_source )
208208
209- if not os .path .exists ("./%s_params/" % dataSource ):
210- os .makedirs ("./%s_params/" % dataSource )
209+ if not os .path .exists ("./%s_params/" % data_source ):
210+ os .makedirs ("./%s_params/" % data_source )
211211
212- api .initPaddle ('--use_gpu=' + useGpu , '--dot_period=10' , '--log_period=100' ,
213- '--gpu_id=' + args .gpuId , '--save_dir=' + "./%s_params/" % dataSource )
212+ api .initPaddle ('--use_gpu=' + use_gpu , '--dot_period=10' , '--log_period=100' ,
213+ '--gpu_id=' + args .gpu_id , '--save_dir=' + "./%s_params/" % data_source )
214214
215- if dataSource == "uniform" :
215+ if data_source == "uniform" :
216216 conf = "gan_conf.py"
217217 num_iter = 10000
218218 else :
219219 conf = "gan_conf_image.py"
220220 num_iter = 1000
221221
222- gen_conf = parse_config (conf , "mode=generator_training,data=" + dataSource )
223- dis_conf = parse_config (conf , "mode=discriminator_training,data=" + dataSource )
224- generator_conf = parse_config (conf , "mode=generator,data=" + dataSource )
222+ gen_conf = parse_config (conf , "mode=generator_training,data=" + data_source )
223+ dis_conf = parse_config (conf , "mode=discriminator_training,data=" + data_source )
224+ generator_conf = parse_config (conf , "mode=generator,data=" + data_source )
225225 batch_size = dis_conf .opt_config .batch_size
226226 noise_dim = get_layer_size (gen_conf .model_config , "noise" )
227227
228- if dataSource == "mnist" :
228+ if data_source == "mnist" :
229229 data_np = load_mnist_data ("./data/mnist_data/train-images-idx3-ubyte" )
230- elif dataSource == "cifar" :
230+ elif data_source == "cifar" :
231231 data_np = load_cifar_data ("./data/cifar-10-batches-py/" )
232232 else :
233233 data_np = load_uniform_data ()
@@ -308,18 +308,20 @@ def main():
308308 else :
309309 curr_train = "gen"
310310 curr_strike = 1
311- gen_trainer .trainOneDataBatch (batch_size , data_batch_gen )
311+ gen_trainer .trainOneDataBatch (batch_size , data_batch_gen )
312+ # TODO: add API for paddle to allow true parameter sharing between different GradientMachines
313+ # so that we do not need to copy shared parameters.
312314 copy_shared_parameters (gen_training_machine , dis_training_machine )
313315 copy_shared_parameters (gen_training_machine , generator_machine )
314316
315317 dis_trainer .finishTrainPass ()
316318 gen_trainer .finishTrainPass ()
317319 # At the end of each pass, save the generated samples/images
318320 fake_samples = get_fake_samples (generator_machine , batch_size , noise )
319- if dataSource == "uniform" :
320- plot2DScatter (fake_samples , "./%s_samples/train_pass%s.png" % (dataSource , train_pass ))
321+ if data_source == "uniform" :
322+ plot2DScatter (fake_samples , "./%s_samples/train_pass%s.png" % (data_source , train_pass ))
321323 else :
322- save_images (fake_samples , "./%s_samples/train_pass%s.png" % (dataSource , train_pass ))
324+ save_images (fake_samples , "./%s_samples/train_pass%s.png" % (data_source , train_pass ))
323325 dis_trainer .finishTrain ()
324326 gen_trainer .finishTrain ()
325327
0 commit comments