Skip to content

Commit c3ebff5

Browse files
author
wangyang59
committed
modified demo/gan following xuwei's email comments
1 parent 5aa5979 commit c3ebff5

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

demo/gan/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The general training procedures are implemented in gan_trainer.py. The neural ne
77
In order to run the model, first download the corresponding data by running the shell script in ./data.
88
Then you can run the command below. The flag -d specifies the training data (cifar, mnist or uniform) and flag --use_gpu specifies whether to use gpu for training (0 is cpu, 1 is gpu).
99

10-
$python gan_trainer.py -d cifar --useGpu 1
10+
$python gan_trainer.py -d cifar --use_gpu 1
1111

1212
The generated images will be stored in ./cifar_samples/
1313
The corresponding models will be stored in ./cifar_params/

demo/gan/gan_trainer.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def plot2DScatter(data, outputfile):
3131
'''
3232
x = data[:, 0]
3333
y = data[:, 1]
34-
print "The mean vector is %s" % numpy.mean(data, 0)
35-
print "The std vector is %s" % numpy.std(data, 0)
34+
logger.info("The mean vector is %s" % numpy.mean(data, 0))
35+
logger.info("The std vector is %s" % numpy.std(data, 0))
3636

3737
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
3838
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
@@ -192,42 +192,42 @@ def get_layer_size(model_conf, layer_name):
192192

193193
def main():
194194
parser = argparse.ArgumentParser()
195-
parser.add_argument("-d", "--dataSource", help="mnist or cifar or uniform")
196-
parser.add_argument("--useGpu", default="1",
195+
parser.add_argument("-d", "--data_source", help="mnist or cifar or uniform")
196+
parser.add_argument("--use_gpu", default="1",
197197
help="1 means use gpu for training")
198-
parser.add_argument("--gpuId", default="0",
198+
parser.add_argument("--gpu_id", default="0",
199199
help="the gpu_id parameter")
200200
args = parser.parse_args()
201-
dataSource = args.dataSource
202-
useGpu = args.useGpu
203-
assert dataSource in ["mnist", "cifar", "uniform"]
204-
assert useGpu in ["0", "1"]
201+
data_source = args.data_source
202+
use_gpu = args.use_gpu
203+
assert data_source in ["mnist", "cifar", "uniform"]
204+
assert use_gpu in ["0", "1"]
205205

206-
if not os.path.exists("./%s_samples/" % dataSource):
207-
os.makedirs("./%s_samples/" % dataSource)
206+
if not os.path.exists("./%s_samples/" % data_source):
207+
os.makedirs("./%s_samples/" % data_source)
208208

209-
if not os.path.exists("./%s_params/" % dataSource):
210-
os.makedirs("./%s_params/" % dataSource)
209+
if not os.path.exists("./%s_params/" % data_source):
210+
os.makedirs("./%s_params/" % data_source)
211211

212-
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
213-
'--gpu_id=' + args.gpuId, '--save_dir=' + "./%s_params/" % dataSource)
212+
api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10', '--log_period=100',
213+
'--gpu_id=' + args.gpu_id, '--save_dir=' + "./%s_params/" % data_source)
214214

215-
if dataSource == "uniform":
215+
if data_source == "uniform":
216216
conf = "gan_conf.py"
217217
num_iter = 10000
218218
else:
219219
conf = "gan_conf_image.py"
220220
num_iter = 1000
221221

222-
gen_conf = parse_config(conf, "mode=generator_training,data=" + dataSource)
223-
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + dataSource)
224-
generator_conf = parse_config(conf, "mode=generator,data=" + dataSource)
222+
gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
223+
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source)
224+
generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
225225
batch_size = dis_conf.opt_config.batch_size
226226
noise_dim = get_layer_size(gen_conf.model_config, "noise")
227227

228-
if dataSource == "mnist":
228+
if data_source == "mnist":
229229
data_np = load_mnist_data("./data/mnist_data/train-images-idx3-ubyte")
230-
elif dataSource == "cifar":
230+
elif data_source == "cifar":
231231
data_np = load_cifar_data("./data/cifar-10-batches-py/")
232232
else:
233233
data_np = load_uniform_data()
@@ -308,18 +308,20 @@ def main():
308308
else:
309309
curr_train = "gen"
310310
curr_strike = 1
311-
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
311+
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
312+
# TODO: add API for paddle to allow true parameter sharing between different GradientMachines
313+
# so that we do not need to copy shared parameters.
312314
copy_shared_parameters(gen_training_machine, dis_training_machine)
313315
copy_shared_parameters(gen_training_machine, generator_machine)
314316

315317
dis_trainer.finishTrainPass()
316318
gen_trainer.finishTrainPass()
317319
# At the end of each pass, save the generated samples/images
318320
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
319-
if dataSource == "uniform":
320-
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
321+
if data_source == "uniform":
322+
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass))
321323
else:
322-
save_images(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
324+
save_images(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass))
323325
dis_trainer.finishTrain()
324326
gen_trainer.finishTrain()
325327

paddle/gserver/tests/test_BatchNorm.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ P_DECLARE_double(checkgrad_eps);
3333
P_DECLARE_bool(thread_local_rand_use_global_seed);
3434
P_DECLARE_bool(prev_batch_state);
3535

36-
// Test that the convTrans forward is the same as conv backward
36+
// Test that the batchNormLayer can be followed by a ConvLayer
3737
TEST(Layer, batchNorm) {
3838
FLAGS_use_gpu = false;
3939
TestConfig configBN;
@@ -104,7 +104,6 @@ TEST(Layer, batchNorm) {
104104
LayerPtr convLayer;
105105
initTestLayer(config, &layerMap, &parameters2, &convLayer);
106106

107-
// Set convLayer outputGrad as convTransLayer input value
108107
bnLayer->forward(PASS_GC);
109108
convLayer->forward(PASS_GC);
110109

0 commit comments

Comments
 (0)