一文通俗了解对抗生成网络(GAN)核心思想

下面我们来介绍一下最基本的gan的训练算法,不够严谨,但是容易接受。

首先,跟任何网络训练一样,我们需要初始化生成器G,和判别器D的参数

形式化公式就是如下:

1、然后在每一轮中,首先固定住G,训练D,具体怎么训练呢?

我们任意选取一些向量,送给G,同时从database中挑选出一些数据,使得判别器学会从database挑选出真实的图片打分高,任意选取向量从G中生成的图片,打分低。这样就是在训练判别器:

形式化公式如下:

稍微解释一下图片中的公式,训练判别器就是希望它对于真实的图片打分高,生成的图片打分低。而公式中是最大化那个式子,分解来看完全对应文字的解释:

对应真正的图片打分高,也就是最大化如下公式,公式如下:

2、第二步是固定判别器D,训练生成器G,我们还是任意给定一些向量,这些向量送给G,生成一些图片,然后喂进判别器进行判别。

首先我们的目的是使得生成器能够生成非常真实的图片,对于真正真实的图片来说,判别器的打分是高的,那么也就是说,我们需要训练生成器,使得通过生成器生成出来的图片让判别器打分高,尽可能的迷惑判别器,这样通过生成器生成出来的图片就是接近“真实的”,土话说就是跟真的好像啊。

形式化公式如下:

公式解释为:最大化使得判别器对于生成器生成的图片打分。

这里需要注意的是训练生成器的时候,一定要固定判别器的参数,因为在实际实现中,生成器和判别器会构成一个大网络,如果不固定判别器的参数去训练生成器的话。

因为目标是使得最后的得分高,网络这个时候仅仅更新最后一层的参数就能让最后的输出标量非常大,很显然这不是我们希望的,如果固定了判别器后面几个layer,训练前面生成器的参数就能正常学习。

实例