Skip to content

Commit 5aa5979

Browse files
author
wangyang59
committed
minor changes on demo/gan following lzhao4ever comments
1 parent 531e835 commit 5aa5979

File tree

4 files changed

+8
-10
lines changed

4 files changed

+8
-10
lines changed

demo/gan/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ Then you can run the command below. The flag -d specifies the training data (cif
99

1010
$python gan_trainer.py -d cifar --useGpu 1
1111

12-
The generated images will be stored in ./cifar_samples/
12+
The generated images will be stored in ./cifar_samples/
13+
The corresponding models will be stored in ./cifar_params/

demo/gan/data/get_mnist_data.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env sh
2-
# This scripts downloads the mnist data and unzips it.
2+
# This script downloads the mnist data and unzips it.
33
set -e
44
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
55
rm -rf "$DIR/mnist_data"

demo/gan/gan_conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
settings(
3939
batch_size=128,
4040
learning_rate=1e-4,
41-
learning_method=AdamOptimizer(beta1=0.7)
41+
learning_method=AdamOptimizer(beta1=0.5)
4242
)
4343

4444
def discriminator(sample):

demo/gan/gan_trainer.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,8 @@ def load_mnist_data(imageFile):
8787
else:
8888
n = 10000
8989

90-
data = numpy.zeros((n, 28*28), dtype = "float32")
91-
92-
for i in range(n):
93-
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
94-
data[i, :] = pixels / 255.0 * 2.0 - 1.0
90+
data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28))
91+
data = data / 255.0 * 2.0 - 1.0
9592

9693
f.close()
9794
return data
@@ -235,15 +232,15 @@ def main():
235232
else:
236233
data_np = load_uniform_data()
237234

238-
# this create a gradient machine for discriminator
235+
# this creates a gradient machine for discriminator
239236
dis_training_machine = api.GradientMachine.createFromConfigProto(
240237
dis_conf.model_config)
241238
# this create a gradient machine for generator
242239
gen_training_machine = api.GradientMachine.createFromConfigProto(
243240
gen_conf.model_config)
244241

245242
# generator_machine is used to generate data only, which is used for
246-
# training discrinator
243+
# training discriminator
247244
logger.info(str(generator_conf.model_config))
248245
generator_machine = api.GradientMachine.createFromConfigProto(
249246
generator_conf.model_config)

0 commit comments

Comments
 (0)