# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *

mode = get_config_arg("mode", str, "generator")
assert mode in set(["generator",
                    "discriminator",
                    "generator_training",
                    "discriminator_training"])

is_generator_training = mode == "generator_training"
is_discriminator_training = mode == "discriminator_training"
is_generator = mode == "generator"
is_discriminator = mode == "discriminator"
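
# How the mode gets set is an assumption here, not part of this file: with the
# classic PaddlePaddle trainer CLI, get_config_arg() reads key=value pairs
# passed through --config_args, so the two training phases would be launched
# along the lines of
#
#   paddle train --config=gan_conf.py --config_args=mode=generator_training ...
#   paddle train --config=gan_conf.py --config_args=mode=discriminator_training ...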

# The network structure below follows the original GAN paper
# (https://arxiv.org/abs/1406.2661); here we use two hidden layers
# and batch normalization.

print('mode=%s' % mode)
# dimension of the noise (z) fed into the generator network
noise_dim = 10
# dimension of the hidden layers
hidden_dim = 10
# dimension of the generated sample
sample_dim = 2

settings(
    batch_size=128,
    learning_rate=1e-4,
    learning_method=AdamOptimizer(beta1=0.5)
)
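
# Context note (not from this file): Adam with beta1 = 0.5 is the setting
# recommended for GAN training in the DCGAN paper
# (https://arxiv.org/abs/1511.06434), where the default beta1 = 0.9 was
# observed to make the adversarial game unstable.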

def discriminator(sample):
    """
    The discriminator outputs the probability that a sample comes from the
    generator or from real data.
    The output has two dimensions: dimension 0 is the probability that the
    sample comes from the generator and dimension 1 is the probability that
    it comes from real data.
    """
    # While the generator is being trained, the discriminator's parameters
    # are frozen (is_static=True) so that only generator gradients are applied.
    param_attr = ParamAttr(is_static=is_generator_training)
    bias_attr = ParamAttr(is_static=is_generator_training,
                          initial_mean=1.0,
                          initial_std=0)

    hidden = fc_layer(input=sample, name="dis_hidden", size=hidden_dim,
                      bias_attr=bias_attr,
                      param_attr=param_attr,
                      act=ReluActivation())

    hidden2 = fc_layer(input=hidden, name="dis_hidden2", size=hidden_dim,
                       bias_attr=bias_attr,
                       param_attr=param_attr,
                       act=LinearActivation())

    hidden_bn = batch_norm_layer(hidden2,
                                 act=ReluActivation(),
                                 name="dis_hidden_bn",
                                 bias_attr=bias_attr,
                                 param_attr=ParamAttr(is_static=is_generator_training,
                                                      initial_mean=1.0,
                                                      initial_std=0.02),
                                 use_global_stats=False)

    return fc_layer(input=hidden_bn, name="dis_prob", size=2,
                    bias_attr=bias_attr,
                    param_attr=param_attr,
                    act=SoftmaxActivation())

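# A hypothetical sketch (not part of this config; names and shapes assumed) of
# how a driver script could label a half-real, half-generated batch for
# discriminator training, following the output convention documented above
# (dimension 0 = generated, dimension 1 = real):
#
#   import numpy as np
#   batch = np.concatenate([real_samples, generated_samples], axis=0)
#   labels = np.concatenate([np.ones(len(real_samples), dtype="int32"),
#                            np.zeros(len(generated_samples), dtype="int32")])
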
def generator(noise):
    """
    The generator generates a sample given the input noise.
    """
    # While the discriminator is being trained, the generator's parameters
    # are frozen (is_static=True).
    param_attr = ParamAttr(is_static=is_discriminator_training)
    bias_attr = ParamAttr(is_static=is_discriminator_training,
                          initial_mean=1.0,
                          initial_std=0)

    hidden = fc_layer(input=noise,
                      name="gen_layer_hidden",
                      size=hidden_dim,
                      bias_attr=bias_attr,
                      param_attr=param_attr,
                      act=ReluActivation())

    hidden2 = fc_layer(input=hidden, name="gen_hidden2", size=hidden_dim,
                       bias_attr=bias_attr,
                       param_attr=param_attr,
                       act=LinearActivation())

    hidden_bn = batch_norm_layer(hidden2,
                                 act=ReluActivation(),
                                 name="gen_layer_hidden_bn",
                                 bias_attr=bias_attr,
                                 param_attr=ParamAttr(is_static=is_discriminator_training,
                                                      initial_mean=1.0,
                                                      initial_std=0.02),
                                 use_global_stats=False)

    return fc_layer(input=hidden_bn,
                    name="gen_layer1",
                    size=sample_dim,
                    bias_attr=bias_attr,
                    param_attr=param_attr,
                    act=LinearActivation())

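# A minimal sketch of how a driver script might draw the generator's input
# noise (an assumption: the noise distribution is not defined in this file):
#
#   import numpy as np
#   noise = np.random.normal(size=(batch_size, noise_dim)).astype("float32")
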
if is_generator_training:
    # Generator training: noise -> generated sample, which is then scored
    # by the (frozen) discriminator defined below.
    noise = data_layer(name="noise", size=noise_dim)
    sample = generator(noise)

if is_discriminator_training:
    # Discriminator training: the sample (real or generated) is fed in
    # directly as data.
    sample = data_layer(name="sample", size=sample_dim)

if is_generator_training or is_discriminator_training:
    label = data_layer(name="label", size=1)
    prob = discriminator(sample)
    cost = cross_entropy(input=prob, label=label)
    classification_error_evaluator(
        input=prob, label=label, name=mode + '_error')
    outputs(cost)

if is_generator:
    # Inference-only configuration: output generated samples given noise.
    noise = data_layer(name="noise", size=noise_dim)
    outputs(generator(noise))
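
# A rough sketch, not taken from this file, of the alternating training
# schedule these four configurations support (driver-script details assumed):
#
#   repeat:
#       1. run the "discriminator_training" config on batches of real and
#          generated samples (generator parameters frozen via is_static);
#       2. run the "generator_training" config on noise batches
#          (discriminator parameters frozen via is_static);
#   the two configs share weights through identical layer names,
#   e.g. "dis_hidden" and "gen_layer_hidden".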