从损失函数角度分析GAN和WGAN等经典的网络模型;分析GAN评价的两个指标Inception Score 和Frechet Inception Distance(FID)。

损失函数角度分析GAN模型

对抗生成网络是尝试使用无监督的方式学习原数据的概率分布。在GAN的训练过程中,使用损失函数反映真实图像分布和生成图像分布之间的差异。本文主要介绍两种GAN损失函数:

  • 最大最小化损失函数(minmax loss)
  • Wasserstein损失函数(Wasserstein loss)

除了本文介绍的两个损失函数,还有许多其他不同的损失函数。

GAN的训练过程中有两个损失函数:一个作为网络$G$的损失函数,一个作为网络$D$的损失函数。问题是,如果使用两个损失函数反映真实图像分布和生成图像分布之间的差异呢?

(1) 最大最小化损失函数 $$ \begin{equation}\label{eq1} \min_G \max_D V(D,G) =E_{x \sim p_{data}(x)}[ \log D(x)] + E_{z \sim p_z (z)}[log (1- D(G(z)))] \tag{1} \end{equation} $$

其中,$D(x)$ 是$D$网络对真实样本$x$ 预测为真的概率估计, $E_x$是所有的样本的期望值,$G(z)$ 是网络$G$的生成结果, $D(G(z))$是$D$网络对生成样本判别为真的概率估计, $E_z$ 是对所有生成样本判别为真的概率估计。

$E_{x \sim p_{data}(x)}[ \log D(x)]$ 表示网络$D$判别真实的样本是真,$D(G(z))$表示网络$D$判别网路$G$生成的样本为真,$E_{z \sim p_z (z)}[log (1- D(G(z)))]$表示网络$D$判别网络$G$生成的样本为假。在网络$D$训练的过程中,需要最大化两个子式子;在网络$G$训练的过程中是需要最小化右边的式子,即$ E_{z \sim p_z (z)}[log (1- D(G(z)))]$。上述的计算都需要进行log运算。$D$网络承担的是分类器的功能,网络的输出是在$[0, 1]$之间,一般来说以0.5为界限,如果大于0.5,那么$D$网络认为该样本是真实的样本,否则是虚假的样本。

该损失函数来源于交叉熵。如果两个分布$p$ 和$q$,那么两者之间的交叉熵为: $$ H(p, q) =-\sum_{i} p_{i} \log q_{i} $$ 其中$p$ 和$q$分别是真实和预测的分布。两者都是离散的分布,如果是连续的话,那么应该使用积分符号而非求和符合。最初GAN的设计是一个二分类,判别图像是真实的$1$还是虚假的$0$。GAN的训练目标可以表示为以下形式:

$$ H((x_{1}, y_{1}), D)=-y_{1} \log D(x_{1})-(1-y_{1}) \log (1-D(x_{1})) $$

对于真实样本$x$,网络$D$可以映射到 $[0,1]$区间,对于生成样本的判别,需要转化一下 $1-D(z)$,所以输出就约束于$[0, 1]$。上述是一个样本,然后将所有的样本加起来求期望得

$$ H((x _{i}, y _{i}) _{i=1}^{N}, D)=-\sum _{i=1}^{N} y _{i} \log D(x _{i})-\sum _{i=1}^{N}(1-y _{i}) \log (1-D(x _{i})) $$

输入样本到$D$网络中的样本一般是来自真实数据集,一半是来自$G$网络,并且把求和符号转换成期望,写成如下的形式:

$$ H((x _{i}, y _{i}) _{i=1}^{\infty}, D)=-\frac{1}{2} E _{x \sim p _{\text {data }}}[\log D(x)]-\frac{1}{2} E _{z}[\log (1-D(G(z)))] $$ 是不是和公式$(1)$很类似。

(2)修正后的最大最小化损失函数

最初的论文发现,对抗生成网路训练的开始几个阶段,最大最小化损失函数容易停止不动,这可能是$D$网络判别太容易。所以论文中修改原来损失函数中的$G$网络为最大化$ \log D(G(z))$。

(3)Wasserstein损失函数

在Wasserstein GAN(WGAN)中网络$D$不再承担分类样本的功能。这里的$D$只要能使得真实的样本的输出大于$G$生成样本就可以。所以这里的额$D$网路更像是一种"critic"的角色而不是一个"discriminator"的角色。损失函数就变得异常简单

Critic Loss: $D(x) - D(G(z))$

在训练过程中,$D$网络去最大化上述式子。也就是说,$D$网络努力最大化真实样本输出和生成样本输出之间的差距。

Generator Loss: $D(G(z))$

训练过程中,$G$网络最大化上述式子。也就是说,$G$网络的输出要被$D$网络认为是“真”。$D$网络的输出不必约束于$[0, 1]$之间。

两个评价指标

两个指标的计算都是基于Inception network预训练模型,然后得到某一层的信息,最后进行计算。Inception Score是基于最后一层softmax层,Frechet Inception Distance是基于中间的某个特征层(v3 pool3)得到的特征向量。

(1)Inception Score

在最后softmax层中计算两个分布之间的距离 $$ \begin{equation} IS=exp[E_xD_{KL}(p(y|x)||p(y))] \end{equation} $$

原作者认为$p(y|x)$ 是图像真实度的衡量,$p(y)$是由前者积分得到, 是图像多样性的衡量。所以IS是计算的$p(y|x)$和$p(y)$两个分布之间的KL 散度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import math
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import inception_v3
net = inception_v3(pretrained=True).cuda()

def inception_score(images, batch_size=5):
    scores = []
    for i in range(int(math.ceil(float(len(images)) / float(batch_size)))):
        batch = Variable(torch.cat(images[i * batch_size: (i + 1) * batch_size], 0))
        s, _ = net(batch)  # skipping aux logits
        scores.append(s)
    p_yx = F.softmax(torch.cat(scores, 0), 1)
    p_y = p_yx.mean(0).unsqueeze(0).expand(p_yx.size(0), -1)
    KL_d = p_yx * (torch.log(p_yx) - torch.log(p_y))
    final_score = KL_d.mean()
    return final_score

使用Inception Score进行评价的时候,只需要输入一个生成图像路径即可。

(2)Frechet Inception Distance

在特征维度计算两个分布的距离 $$ \begin{equation} {d}^{2}((m, C),(m_{w}, C_{w}))=|m-m_{2}|+{Tr}(C+C_{w}-2(C C_{w})^{1 / 2}) \end{equation} $$

其中$m$ 和$m_w$表示真实图像和生成图像在特征空间的均值,$C$和$C_w$表示真实图像和生成图像在特征空间向量的协方差矩阵中的协方差(在二维中就是方差)。$Tr$表示矩阵中的迹(主对角线元素之和)。

基于numpy的代码实现。两个正太分布之间的距离。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# calculate frechet inception distance
def calculate_fid(act1, act2):
	# calculate mean and covariance statistics
	mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
	mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
	ssdiff = numpy.sum((mu1 - mu2)**2.0)
	covmean = sqrtm(sigma1.dot(sigma2))
	if iscomplexobj(covmean):
		covmean = covmean.real
	fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
	return fid

使用FID进行GAN的评价时候,需要给定两个路径分别指向真实的图像和生成的图像。

参考文献

Understanding Generative Adversarial Networks Loss Functions