Skip to content

Commit 531e835

Browse files
author
wangyang59
committed
changes to demo/gan following lzhao4ever comments
1 parent 9a02bd4 commit 531e835

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

demo/gan/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
output/
2+
uniform_params/
3+
cifar_params/
4+
mnist_params/
25
*.png
36
.pydevproject
47
.project

demo/gan/gan_conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
is_generator = mode == "generator"
2525
is_discriminator = mode == "discriminator"
2626

27+
# The network structure below follows the ref https://arxiv.org/abs/1406.2661
28+
# Here we used two hidden layers and batch_norm
29+
2730
print('mode=%s' % mode)
2831
# the dim of the noise (z) as the input of the generator network
2932
noise_dim = 10

demo/gan/gan_trainer.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,8 @@ def load_mnist_data(imageFile):
9090
data = numpy.zeros((n, 28*28), dtype = "float32")
9191

9292
for i in range(n):
93-
pixels = []
94-
for j in range(28 * 28):
95-
pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0)
96-
data[i, :] = pixels
93+
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
94+
data[i, :] = pixels / 255.0 * 2.0 - 1.0
9795

9896
f.close()
9997
return data
@@ -129,7 +127,7 @@ def merge(images, size):
129127
((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
130128
return img.astype('uint8')
131129

132-
def saveImages(images, path):
130+
def save_images(images, path):
133131
merged_img = merge(images, [8, 8])
134132
if merged_img.shape[2] == 1:
135133
im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB')
@@ -207,9 +205,15 @@ def main():
207205
useGpu = args.useGpu
208206
assert dataSource in ["mnist", "cifar", "uniform"]
209207
assert useGpu in ["0", "1"]
210-
208+
209+
if not os.path.exists("./%s_samples/" % dataSource):
210+
os.makedirs("./%s_samples/" % dataSource)
211+
212+
if not os.path.exists("./%s_params/" % dataSource):
213+
os.makedirs("./%s_params/" % dataSource)
214+
211215
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
212-
'--gpu_id=' + args.gpuId)
216+
'--gpu_id=' + args.gpuId, '--save_dir=' + "./%s_params/" % dataSource)
213217

214218
if dataSource == "uniform":
215219
conf = "gan_conf.py"
@@ -231,9 +235,6 @@ def main():
231235
else:
232236
data_np = load_uniform_data()
233237

234-
if not os.path.exists("./%s_samples/" % dataSource):
235-
os.makedirs("./%s_samples/" % dataSource)
236-
237238
# this create a gradient machine for discriminator
238239
dis_training_machine = api.GradientMachine.createFromConfigProto(
239240
dis_conf.model_config)
@@ -321,7 +322,7 @@ def main():
321322
if dataSource == "uniform":
322323
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
323324
else:
324-
saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
325+
save_images(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
325326
dis_trainer.finishTrain()
326327
gen_trainer.finishTrain()
327328

0 commit comments

Comments
 (0)