Skip to content

Commit ea5594d

Browse files
author
wangyang59
committed
modification of gan tutorial following luotao01 comments
1 parent 0186ede commit ea5594d

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

doc/tutorials/gan/index_en.md

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
# Generative Adversarial Networks (GAN)
22

3-
This demo implements GAN training described in the original GAN paper (https://arxiv.org/abs/1406.2661) and DCGAN (https://arxiv.org/abs/1511.06434).
3+
This demo implements GAN training described in the original [GAN paper](https://arxiv.org/abs/1406.2661) and deep convolutional generative adversarial networks [DCGAN paper](https://arxiv.org/abs/1511.06434).
44

55
The high-level structure of GAN is shown in Figure. 1 below. It is composed of two major parts: a generator and a discriminator, both of which are based on neural networks. The generator takes in some kind of noise with a known distribution and transforms it into an image. The discriminator takes in an image and determines whether it is artificially generated by the generator or a real image. So the generator and the discriminator are in a competitive game in which generator is trying to generate image to look as real as possible to fool the discriminator, while the discriminator is trying to distinguish between real and fake images.
66

77
<center>![](./gan.png)</center>
8-
<center>Figure 1. GAN-Model-Structure Source: ishmaelbelghazi.github.io/ALI/</center>
8+
<center>Figure 1. GAN-Model-Structure [Source](https://ishmaelbelghazi.github.io/ALI/)</center>
99

1010
The generator and discriminator take turn to be trained using SGD. The objective function of the generator is for its generated images being classified as real by the discriminator, and the objective function of the discriminator is to correctly classify real and fake images. When the GAN model is trained to converge to the equilibrium state, the generator will transform the given noise distribution to the distribution of real images, and the discriminator will not be able to distinguish between real and fake images at all.
1111

1212
## Implementation of GAN Model Structure
1313
Since GAN model involves multiple neural networks, it requires to use paddle python API. So the code walk-through below can also partially serve as an introduction to the usage of Paddle Python API.
1414

15-
There are three networks defined in gan_conf.py, namely **generator_training**, **discriminator_training** and **generator**. The relationship to the model structure we defined above is that **discriminator_training** is the discriminator, **generator** is the generator, and the **generator_training** combined the generator and discriminator since training generator would require the discriminator to provide loss function. This relationship is described in the following code
15+
There are three networks defined in gan_conf.py, namely **generator_training**, **discriminator_training** and **generator**. The relationship to the model structure we defined above is that **discriminator_training** is the discriminator, **generator** is the generator, and the **generator_training** combined the generator and discriminator since training generator would require the discriminator to provide loss function. This relationship is described in the following code:
1616
```python
1717
if is_generator_training:
1818
noise = data_layer(name="noise", size=noise_dim)
@@ -34,7 +34,7 @@ if is_generator:
3434
outputs(generator(noise))
3535
```
3636

37-
In order to train the networks defined in gan_conf.py, one first needs to initialize a Paddle environment, parse the config, create GradientMachine from the config and create trainer from GradientMachine as done in the code chunk below.
37+
In order to train the networks defined in gan_conf.py, one first needs to initialize a Paddle environment, parse the config, create GradientMachine from the config and create trainer from GradientMachine as done in the code chunk below:
3838
```python
3939
import py_paddle.swig_paddle as api
4040
# init paddle environment
@@ -60,7 +60,7 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
6060
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
6161
```
6262

63-
In order to balance the strength between generator and discriminator, we schedule to train whichever one is performing worse by comparing their loss function value. The loss function value can be calculated by a forward pass through the GradientMachine
63+
In order to balance the strength between generator and discriminator, we schedule to train whichever one is performing worse by comparing their loss function value. The loss function value can be calculated by a forward pass through the GradientMachine.
6464
```python
6565
def get_training_loss(training_machine, inputs):
6666
outputs = api.Arguments.createArguments(0)
@@ -69,7 +69,7 @@ def get_training_loss(training_machine, inputs):
6969
return numpy.mean(loss)
7070
```
7171

72-
After training one network, one needs to sync the new parameters to the other networks. The code below demonstrates one example of such use case.
72+
After training one network, one needs to sync the new parameters to the other networks. The code below demonstrates one example of such use case:
7373
```python
7474
# Train the gen_training
7575
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
@@ -84,13 +84,13 @@ copy_shared_parameters(gen_training_machine, generator_machine)
8484
## A Toy Example
8585
With the infrastructure explained above, we can now walk you through a toy example of generating two dimensional uniform distribution using 10 dimensional Gaussian noise.
8686

87-
The Gaussian noises are generated using the code below
87+
The Gaussian noises are generated using the code below:
8888
```python
8989
def get_noise(batch_size, noise_dim):
9090
return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
9191
```
9292

93-
The real samples (2-D uniform) are generated using the code below
93+
The real samples (2-D uniform) are generated using the code below:
9494
```python
9595
# synthesize 2-D uniform data in gan_trainer.py:114
9696
def load_uniform_data():
@@ -106,12 +106,16 @@ $python gan_trainer.py -d uniform --useGpu 1
106106
```
107107
The generated samples can be found in ./uniform_samples/ and one example is shown below as Figure 2. One can see that it roughly recovers the 2D uniform distribution.
108108

109-
<center>![](./uniform_sample.png)</center>
110-
<center>Figure 2. Uniform Sample</center>
109+
<p align="center">
110+
<img src="./uniform_sample.png" width="256" height="256">
111+
</p>
112+
<p align="center">
113+
Figure 2. Uniform Sample
114+
</p>
111115

112116
## MNIST Example
113117
### Data preparation
114-
To download the MNIST data, one can use the following commands.
118+
To download the MNIST data, one can use the following commands:
115119
```bash
116120
$cd data/
117121
$./get_mnist_data.sh
@@ -121,10 +125,10 @@ $./get_mnist_data.sh
121125
Following the DC-Gan paper (https://arxiv.org/abs/1511.06434), we use convolution/convolution-transpose layer in the discriminator/generator network to better deal with images. The details of the network structures are defined in gan_conf_image.py.
122126

123127
### Training the model
124-
To train the GAN model on mnist data, one can use the following command
128+
To train the GAN model on mnist data, one can use the following command:
125129
```bash
126130
$python gan_trainer.py -d mnist --useGpu 1
127131
```
128132
The generated sample images can be found at ./mnist_samples/ and one example is shown below as Figure 3.
129133
<center>![](./mnist_sample.png)</center>
130-
<center>Figure 2. MNIST Sample</center>
134+
<center>Figure 3. MNIST Sample</center>

0 commit comments

Comments
 (0)