本来想讨论InfoGAN的,先留个坑吧。今天先讨论标准的GAN,即生成对抗网络。GAN最先由Ian Goodfellow在2014年提出,跟Variational Auto-Encoder (VAE)的时间差不多,二者都是非常好的生成网络,在无监督学习中发挥了重要的作用。

Basic theory

生成对抗网络属于一个minmax game,其目标是Learn a generator, whose distribution matches the real data distribution ,即实现一个生成器用于从随机分布 (一般为高斯) 的噪声中生成与目标样本类似的模拟$x_g$。

为了衡量的相似性,设计一个称为Discriminator的对抗网络,该网络的输入可以是真实的,也可以是有generator生成的伪造的样本。我们要求能够分辨出输入给它的样本的真实性。因此discriminator的输出应该是样本真实的概率,

可以看出,当接近时,,即discriminator无法判断其输入来自真实样本还是伪造的样本。因此构成了一对相互对抗的网络,即生成对抗。相应的目标函数为,

参考Goodfellow的论文,该目标函数的优化求解分为两步,

Step1

  • 随机生成minibatch个z noise作为generator的输入,得到对应的输出
  • 随机选取minibatch个real data样本
  • 计算
  • 沿梯度方向传递给Discriminator的参数,进行参数学习。

Step2

  • 随机生成minibatch个z noise作为generator的输入,得到对应的输出
  • 计算
  • 沿梯度方向传递给Generator的参数,进行参数学习。

交替重复以上两步,直到收敛,即 After several steps of training, if G and D have enough capacity, they will reach a point at which both cannot improve.

Conditional case

以上是最原始的GAN,是一种无监督的网络。那么,以MNIST手写体数据库为例,如果我们想得到一个能够生成特定数字的生成器,应该如何做? 这一问题可以理解为GAN的有监督学习,即conditional GAN。这里插一句,InfoGAN可以在无标签的情况下通过将$z$分解为noise+latent两部分,无监督地学到具有样本的context semantic representations.绝对秒杀原始的GAN。。。

条件GAN的思路是什么呢?如上图所示,从左往右分别是原始的GAN、条件GAN两个网络。conditional GAN在generator和discriminator的输入部分均添加了一个新的变量$y$,即样本的标签,作为一种固定的约束,指导网络学习到样本内部的区别。但是网络的目标函数和训练过程的变化很小。 这个思路,和conditional variational auto-encoder非常相似。更新后的目标函数如下所示,

GAN的TensorFlow实现

首先吐个槽,GAN的调参真的不是一般的复杂,我发现Generator和Discriminator互搏的时候,经常训着训着,loss就偏了,最后的输出也很诡异。然后是激活函数的选择、dropout的keep_prob,以及z的长度,batch的大小都会有影响,我提供了GANConditional GAN的notebook,感兴趣可以自己去试试看。。。下面我分步骤介绍一下Conditional GAN的实现。

  • Initialization

    1
    2
    3
    4
    5
    6
    batch_size = 64
    z_len = 100
    z_noise = tf.placeholder(dtype=tf.float32, shape=[None, z_len], name='z_noise')
    y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='y')
    x_data = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='x_data')
    keep_prob = tf.placeholder(dtype=tf.float32, shape=(), name='keep_prob')
  • Generator

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    def generator(z_noise, y, keep_prob, namescope='generator'):
    """The generator"""
    with tf.name_scope(namescope):
    net = tf.concat([z_noise, y], axis=1)
    net = tf.layers.dense(net, units=150, activation=tf.nn.relu, name='g_fc1')
    net = tf.nn.dropout(net, keep_prob=keep_prob)
    net = tf.layers.dense(net, units=300, activation=tf.nn.relu, name='g_fc2')
    net = tf.nn.dropout(net, keep_prob=keep_prob)
    net = tf.layers.dense(net, units=784, activation=tf.nn.sigmoid, name='g_fc3')
    return net
  • Discriminator

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    def discriminator(d_in, y, z_len, keep_prob, namescope='discriminator', reuse=True):
    """The discriminator"""
    with tf.name_scope(namescope):
    net = tf.concat([d_in, y], axis=1)
    net = tf.layers.dense(net, units=300, activation=tf.nn.relu, name='d_fc1', reuse=reuse)
    net = tf.nn.dropout(net, keep_prob=keep_prob)
    net = tf.layers.dense(net, units=150, activation=tf.nn.relu, name='d_fc2', reuse=reuse)
    net = tf.nn.dropout(net, keep_prob=keep_prob)
    net = tf.layers.dense(net, units=1, activation=tf.nn.sigmoid, name='d_fc4', reuse=reuse)
    return net
  • Network instance

    1
    2
    3
    4
    # generate the network
    x_g = generator(z_noise, y, keep_prob)
    d_g = discriminator(x_g, y, z_len, keep_prob, reuse=False)
    d_data = discriminator(x_data, y, z_len, keep_prob)
  • Loss and optimizer

    1
    2
    3
    4
    5
    6
    7
    8
    # get variables
    varlist = tf.trainable_variables() # 查看待训练的参数,为了获取G和D两个网络的参数列表
    # The objective
    with tf.name_scope("loss"):
    loss_d = - (tf.reduce_mean(tf.log(1e-8 + d_data)) + tf.reduce_mean(tf.log(1e-8 + 1 - d_g)))
    loss_g = - tf.reduce_mean(tf.log(1e-8 + d_g))
    train_op_g = tf.train.AdamOptimizer(0.0001).minimize(loss_g, var_list=varlist[0:6])
    train_op_d = tf.train.AdamOptimizer(0.0001).minimize(loss_d, var_list=varlist[6:])

Conditional GAN在MNIST上的测试结果

针对MNIST手写体数据库,实现了一个可以根据标签生成指定数字的Conditional GAN,网络的配置如下表,参考了这篇博客

Subnet Layer Nodes Activation Dropout
Generator input [z, y] 32+10 —- —-
Generator FC 150 relu T
Generator FC 300 relu T
Generator FC 784 sigmoid F
Discriminator input 784+10 —- —-
Discriminator FC 300 relu T
Discriminator FC 150 relu T
Discriminator FC 1 sigmoid F

下面贴一下实验结果 (生成的手写体图像,每行对应一个数字),可以看出,随着迭代次数的增加,生成的数字越来约清晰,且准确性在提升。

References