@@ -90,10 +90,8 @@ def load_mnist_data(imageFile):
9090 data = numpy .zeros ((n , 28 * 28 ), dtype = "float32" )
9191
9292 for i in range (n ):
93- pixels = []
94- for j in range (28 * 28 ):
95- pixels .append (float (ord (f .read (1 ))) / 255.0 * 2.0 - 1.0 )
96- data [i , :] = pixels
93+ pixels = numpy .fromfile (f , 'ubyte' , count = 28 * 28 )
94+ data [i , :] = pixels / 255.0 * 2.0 - 1.0
9795
9896 f .close ()
9997 return data
@@ -129,7 +127,7 @@ def merge(images, size):
129127 ((images [idx , :].reshape ((h , w , c ), order = "F" ).transpose (1 , 0 , 2 ) + 1.0 ) / 2.0 * 255.0 )
130128 return img .astype ('uint8' )
131129
132- def saveImages (images , path ):
130+ def save_images (images , path ):
133131 merged_img = merge (images , [8 , 8 ])
134132 if merged_img .shape [2 ] == 1 :
135133 im = Image .fromarray (numpy .squeeze (merged_img )).convert ('RGB' )
@@ -207,9 +205,15 @@ def main():
207205 useGpu = args .useGpu
208206 assert dataSource in ["mnist" , "cifar" , "uniform" ]
209207 assert useGpu in ["0" , "1" ]
210-
208+
209+ if not os .path .exists ("./%s_samples/" % dataSource ):
210+ os .makedirs ("./%s_samples/" % dataSource )
211+
212+ if not os .path .exists ("./%s_params/" % dataSource ):
213+ os .makedirs ("./%s_params/" % dataSource )
214+
211215 api .initPaddle ('--use_gpu=' + useGpu , '--dot_period=10' , '--log_period=100' ,
212- '--gpu_id=' + args .gpuId )
216+ '--gpu_id=' + args .gpuId , '--save_dir=' + "./%s_params/" % dataSource )
213217
214218 if dataSource == "uniform" :
215219 conf = "gan_conf.py"
@@ -231,9 +235,6 @@ def main():
231235 else :
232236 data_np = load_uniform_data ()
233237
234- if not os .path .exists ("./%s_samples/" % dataSource ):
235- os .makedirs ("./%s_samples/" % dataSource )
236-
237238 # this create a gradient machine for discriminator
238239 dis_training_machine = api .GradientMachine .createFromConfigProto (
239240 dis_conf .model_config )
@@ -321,7 +322,7 @@ def main():
321322 if dataSource == "uniform" :
322323 plot2DScatter (fake_samples , "./%s_samples/train_pass%s.png" % (dataSource , train_pass ))
323324 else :
324- saveImages (fake_samples , "./%s_samples/train_pass%s.png" % (dataSource , train_pass ))
325+ save_images (fake_samples , "./%s_samples/train_pass%s.png" % (dataSource , train_pass ))
325326 dis_trainer .finishTrain ()
326327 gen_trainer .finishTrain ()
327328
0 commit comments