Generative adversarial network and its conditional case
本来想讨论InfoGAN的,先留个坑吧。今天先讨论标准的GAN,即生成对抗网络。GAN最先由Ian Goodfellow在2014年提出,跟Variational Auto-Encoder (VAE)的时间差不多,二者都是非常好的生成网络,在无监督学习中发挥了重要的作用。
Basic theory
生成对抗网络属于一个minmax game,其目标是Learn a generator, whose distribution matches the real data distribution ,即实现一个生成器用于从随机分布 (一般为高斯) 的噪声中生成与目标样本类似的模拟$x_g$。

为了衡量与的相似性,设计一个称为Discriminator的对抗网络,该网络的输入可以是真实的,也可以是有generator生成的伪造的样本。我们要求能够分辨出输入给它的样本的真实性。因此discriminator的输出应该是样本真实的概率,
可以看出,当与接近时,,即discriminator无法判断其输入来自真实样本还是伪造的样本。因此与构成了一对相互对抗的网络,即生成对抗。相应的目标函数为,
参考Goodfellow的论文,该目标函数的优化求解分为两步,
Step1
- 随机生成minibatch个z noise作为generator的输入,得到对应的输出;
- 随机选取minibatch个real data样本;
- 计算;
- 将沿梯度方向传递给Discriminator的参数,进行参数学习。
Step2
- 随机生成minibatch个z noise作为generator的输入,得到对应的输出;
- 计算;
- 将沿梯度方向传递给Generator的参数,进行参数学习。
交替重复以上两步,直到和收敛,即 After several steps of training, if G and D have enough capacity, they will reach a point at which both cannot improve.
Conditional case
以上是最原始的GAN,是一种无监督的网络。那么,以MNIST手写体数据库为例,如果我们想得到一个能够生成特定数字的生成器,应该如何做? 这一问题可以理解为GAN的有监督学习,即conditional GAN。这里插一句,InfoGAN可以在无标签的情况下通过将$z$分解为noise+latent两部分,无监督地学到具有样本的context semantic representations.绝对秒杀原始的GAN。。。

条件GAN的思路是什么呢?如上图所示,从左往右分别是原始的GAN、条件GAN两个网络。conditional GAN在generator和discriminator的输入部分均添加了一个新的变量$y$,即样本的标签,作为一种固定的约束,指导网络学习到样本内部的区别。但是网络的目标函数和训练过程的变化很小。 这个思路,和conditional variational auto-encoder非常相似。更新后的目标函数如下所示,
GAN的TensorFlow实现
首先吐个槽,GAN的调参真的不是一般的复杂,我发现Generator和Discriminator互搏的时候,经常训着训着,loss就偏了,最后的输出也很诡异。然后是激活函数的选择、dropout的keep_prob,以及z的长度,batch的大小都会有影响,我提供了GAN和Conditional GAN的notebook,感兴趣可以自己去试试看。。。下面我分步骤介绍一下Conditional GAN的实现。
Initialization
123456batch_size = 64z_len = 100z_noise = tf.placeholder(dtype=tf.float32, shape=[None, z_len], name='z_noise')y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='y')x_data = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='x_data')keep_prob = tf.placeholder(dtype=tf.float32, shape=(), name='keep_prob')Generator
12345678910def generator(z_noise, y, keep_prob, namescope='generator'):"""The generator"""with tf.name_scope(namescope):net = tf.concat([z_noise, y], axis=1)net = tf.layers.dense(net, units=150, activation=tf.nn.relu, name='g_fc1')net = tf.nn.dropout(net, keep_prob=keep_prob)net = tf.layers.dense(net, units=300, activation=tf.nn.relu, name='g_fc2')net = tf.nn.dropout(net, keep_prob=keep_prob)net = tf.layers.dense(net, units=784, activation=tf.nn.sigmoid, name='g_fc3')return netDiscriminator
12345678910def discriminator(d_in, y, z_len, keep_prob, namescope='discriminator', reuse=True):"""The discriminator"""with tf.name_scope(namescope):net = tf.concat([d_in, y], axis=1)net = tf.layers.dense(net, units=300, activation=tf.nn.relu, name='d_fc1', reuse=reuse)net = tf.nn.dropout(net, keep_prob=keep_prob)net = tf.layers.dense(net, units=150, activation=tf.nn.relu, name='d_fc2', reuse=reuse)net = tf.nn.dropout(net, keep_prob=keep_prob)net = tf.layers.dense(net, units=1, activation=tf.nn.sigmoid, name='d_fc4', reuse=reuse)return netNetwork instance
1234# generate the networkx_g = generator(z_noise, y, keep_prob)d_g = discriminator(x_g, y, z_len, keep_prob, reuse=False)d_data = discriminator(x_data, y, z_len, keep_prob)Loss and optimizer
12345678# get variablesvarlist = tf.trainable_variables() # 查看待训练的参数,为了获取G和D两个网络的参数列表# The objectivewith tf.name_scope("loss"):loss_d = - (tf.reduce_mean(tf.log(1e-8 + d_data)) + tf.reduce_mean(tf.log(1e-8 + 1 - d_g)))loss_g = - tf.reduce_mean(tf.log(1e-8 + d_g))train_op_g = tf.train.AdamOptimizer(0.0001).minimize(loss_g, var_list=varlist[0:6])train_op_d = tf.train.AdamOptimizer(0.0001).minimize(loss_d, var_list=varlist[6:])
Conditional GAN在MNIST上的测试结果
针对MNIST手写体数据库,实现了一个可以根据标签生成指定数字的Conditional GAN,网络的配置如下表,参考了这篇博客。
Subnet | Layer | Nodes | Activation | Dropout |
---|---|---|---|---|
Generator | input [z, y] | 32+10 | —- | —- |
Generator | FC | 150 | relu | T |
Generator | FC | 300 | relu | T |
Generator | FC | 784 | sigmoid | F |
Discriminator | input | 784+10 | —- | —- |
Discriminator | FC | 300 | relu | T |
Discriminator | FC | 150 | relu | T |
Discriminator | FC | 1 | sigmoid | F |
下面贴一下实验结果 (生成的手写体图像,每行对应一个数字),可以看出,随着迭代次数的增加,生成的数字越来约清晰,且准确性在提升。
