⚠️⚠️⚠️本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
最近在更新"CV大模型系列"的过程中,收到了一些朋友的私信,问能不能把"GAN-VAE-扩散模型"这条线再展开来说说 。所以在今天的这篇文章中,我就来和大家一起重新细读这篇CV领域的经典制作 :GAN(Generative Adversarial Nets,生成式对抗网络) ,同时也理一遍它和我们之前讲的扩散模型基石DDPM间的原理差异。
虽然GAN发表的年份较早(2014年),但直到今天,它的核心idea的影响力还是巨大的。它挖下了一个巨大的坑,让人后在这10年里争相围绕着它做改进研究。
讲一个小八卦 ,当年GAN刚发表的时候,曾卷入一场"抄袭"的风波中,投诉方叫Jürgen Schmidhuber(LSTM之父),在2016年的NIPS大会上,他公开谴责GAN和自己1992年提出的PM(Predictability Minimization)模型非常相似。GAN的作者也反驳,声明自己并不知道PM模型,并且相关的解释在邮件里已经说得非常明白了。
在写这篇文章之前,我把PM模型翻出来扫了一遍,个人的想法是 ,PM确实在1992年就创新性地提出了"对抗训练"这种思想,但模型本质还是和GAN有一定差别的。遇上撞idea这种事时,起决定作用的,也许是运气吧(Jürgen有很多超前的思想,但当时的数据和硬件质量,不足以证明他的模型的有效性。甚至是LSTM在1997年刚被提出时,也因为算力问题,并没有受到很强的重视)。
哎呀,扯远了,我们赶紧进入正文吧。
CV大模型系列文章导航(持续更新中):
🌸CV大模型系列之:扩散模型基石DDPM(人人都能看懂的数学原理篇)🌸
🌸CV大模型系列之:扩散模型基石DDPM(源码解读与实操篇)🌸
🌸CV大模型系列之:全面解读VIT,它到底给植树人挖了多少坑🌸
🌸CV大模型系列之:多模态经典之作CLIP,探索图文结合的奥秘🌸
🌸CV大模型系列之:MoCo v1,利用对比学习在CV任务上做无监督训练🌸
🌸CV大模型系列之:DALLE2,OpenAI文生图代表作解读🌸
🌸CV大模型系列之:GAN,博弈论下的一个实例
一、GAN在做一件什么事
在经济学里,我们经常讨论一个名词,叫博弈论(Game Theory) 。假设有一场双人游戏(Two Player's game) ,参赛的双方彼此是竞争或对抗 的角色,拥有不同的目标或利益。双方为了实现自己的目标,就必须要揣测对方可能会采用的所有行为,然后选取对自己最有利的方案,如此交手,最终使得整个游戏系统达到一种均衡。
我们平常和人打牌、打麻将就是博弈论的一个实例;大家耳熟能详的名词"囚徒困境"也是博弈论的一个实例。同样,今天我们要谈的GAN,也是博弈论的一个实例。
我们先来看GAN想做一件什么事 。如果你读过之前讲解扩散模型DDPM的这篇文章,你一定对下面这张图很熟悉。其实GAN和DDPM的目标是一致的 :我想训练一个模型,在训练的时候,喂给它一堆图片,让它去学习这堆图片的分布。这样,在推理阶段,我喂给模型一个随机噪声,它就可以帮我"生成"一张和它吃进去的那些图片风格相似的图出来了。一句话总结:GAN和扩散模型,终极目标都是学习训练数据的分布。
但是GAN学习数据分布的方法和DDPM不同:
-
DDPM是通过重参数等技巧,去学习数据的均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> \m \m </math>\m和方差\sigma^{2,相当于真得是把这个分布具象化地学出来了。
-
GAN则是采用一种更简便的方式,它才不去预先假设好什么数学分布,它也不关心这个分布具象化长什么样子 。它就是粗暴地放一个模型,用某些技巧(即后面我们要说的对抗性),强迫模型能学出长得像真实图片的数据就行(有点半玄学的意味在里面,至于为什么是"半",我们后面讲理论部分的时候会说)。
你可能会想,那乍一看,GAN好像比DDPM方便多了。但也正是因为这种不依赖具象分布的方便性,导致了GAN在训练过程中难收敛的问题,这也是后人对其着重做的改进点之一。
好,明确了GAN的目标后,我们现在来细看,GAN到底是通过何种方法,来强迫模型产出像模像样的图片的?
二、GAN的原理
2.1 GAN架构
上图刻画了GAN模型的主体架构,最关键的是两部分:
- Generator(生成器) :用于学习数据的分布,并输出其学习成果。
- Discriminator(鉴别器) :用于评估生成器的学习成果。
在代码实现中,这两者都是常见的MLP。
我们对这两部分做详细的解释:
(1)首先,你得有一些真实图片(real images),你从中筛选m张出来,组成一个sample
(2)然后,你可以从某个分布中(例如高斯分布),随机筛选m个噪声
(3)接着,你把m个噪声喂给Generator,得到它的学习成果。
(4)然后,你把真实图片 和Generator的学习成果,一起输入给Discriminator。
(5)最后,作为鉴别器,Discriminator需要去鉴别一张图片是来自真实图片,还是来自Generator杜撰的学习成果。
欸,从这个过程里,你是不是能体会到 "对抗" 的含义了:在这场游戏中,Generator的目标是训练出尽可能逼真的图片,Discriminator的目标则是打假。所以我们不仅要提升Generator的造假能力,也要训练Discriminator的辨假能力,只有当这两者都足够强时,模型才能达到一种最优的均衡,产生理想的结果。
2.2 GAN Loss
看完了模型构造,我们来看一下GAN的损失函数设计。
和大部分损失函数不同,GAN的损失函数分成两部分:Generator(以下简称G)的损失和Discriminator(以下简称D)的损失。
我们先不看前面那个min和max,直接看主体式子部分(式子中V表示Value,GAN不将其称呼为损失函数,而是称为价值函数):
-
<math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x): <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata是客观存在的、真实世界的数据分布,也就是我们希望让模型学到的分布。x是指真实世界中的某张图片。 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∼ p d a t a ( x ) x \sim p_{data}(x) </math>x∼pdata(x)指随机抽取真实世界的图片。
-
<math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( x ) p_{z}(x) </math>pz(x): <math xmlns="http://www.w3.org/1998/Math/MathML"> p z p_{z} </math>pz是指喂给G的噪声的分布,这个分布是我们定好的(例如定为高斯分布)。 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∼ p z ( z ) z \sim p_{z}(z) </math>z ∼ pz(z)是指从定好的噪声分布中随机抽取噪声。
-
<math xmlns="http://www.w3.org/1998/Math/MathML"> G ( z ) G(z) </math>G(z):表示随机噪声经过G后的输出结果,也就是2.1中所说的G的学习成果
-
<math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x ) logD(x) </math>logD(x):指真实世界图片x过D的输出结果,表示根据D的鉴别,x属于真实世界图片的概率。
-
<math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( G ( z ) ) log(1-D(G(z)) </math>log(1−D(G(z)): <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( G ( z ) D(G(z) </math>D(G(z)表示经过D的鉴别,噪声z属于真实世界图片的概率。因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − D ( G ( z ) 1-D(G(z) </math>1−D(G(z)自然表示"噪声G不属于真实世界图片的概率"
-
<math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_{g}(x) </math>pg(x):这一项在上式中没有出现,但是后文我们会经常用到,所以这里也作下说明。 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg表示G学出来的数据分布 ,换句话说,我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg能够逼近 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata。
明确了主体部分的含义,我们再来细看min和max
2.2.1 max
max的意思是,假设我们把G固定下来(G的参数不再变了) 。那么此时,我们的优化目标就变成:
- 当输入数据是真实世界图片x时,我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x ) logD(x) </math>logD(x)尽量大
- 当输入数据是由G加工产生的噪声z时,我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( G ( z ) ) log(1-D(G(z)) </math>log(1−D(G(z))也尽量大。
这样就能提升D的鉴别能力。这就是式子中max的含义。
2.2.2 min
min的意思是,假设我们把D固定下来,那么此时G要做的事情就是生成尽可能逼真的图片,去欺骗D 。此时我们的优化目标就变成:
- 当D固定时,使得 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( G ( z ) ) log(1-D(G(z)) </math>log(1−D(G(z))这一项尽可能小。这项变小,意味着 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( G ( z ) ) D(G(z)) </math>D(G(z))这项变大,表示此时D难以分辨真假图片。
你可能想问,那为啥不管 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x ) log D(x) </math>logD(x)这项了呢?因为此时D是固定的,这一项相当于是个常数,对我们的优化目标没有影响。
如此一来,我们就能提升G的造假能力,这就是式子中min的含义。
2.3 GAN训练分布变动可视化
通过对训练函数的解释,现在你可能对GAN的运作原理有更深刻的理解了 :它是从一个确定的噪声分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z p_{z} </math>pz(例如高斯分布)中,随机抽取一个噪声喂给G,然后通过加强D的辨别能力,一步步迫使G能从噪声中还原数据真实的分布,产生以假乱真的图片。也即,在这个过程里,让G学到的分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg去逼近真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata。
你看,在这个过程中,各种p都是一个抽象的代表,我们并没有用数学语言明确写出它的样子。如果你读过这个系列之前的DDPM,你就能感受到GAN和它的差异:DDPM的本质是去拟合一个数学分布中的各种参数的。
接下来,我们把GAN在训练过程中, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg和 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata的变动可视化地表达出来,更方便我们理解GAN的训练过程:
图中:
-
黑色:真实图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata的分布
-
绿色:G学出的 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg的分布
-
蓝色:D的输出分布
(a)表示在GAN训练的初始阶段 ,此时G和D都不强。因此绿色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg距离黑色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata相差甚远。蓝色的D输出也十分混乱,没有规律性。但总体来看,D还是可以区分出真实世界图片和G造假图片的(表现在靠近黑色的部分较高,靠近绿色的部分较低)。
(b)模型学习的中间阶段,D在慢慢变强,可以发现这时蓝色的曲线更加稳定。
(c)同样是模型学习的中间阶段 ,G也在慢慢变强,可以发现绿色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg正在向黑色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata靠近。
(d)模型学习的最终阶段, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg和 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata吻合(理想情况) ,这时D已经分不清真假数据了,因此它的输出分布变成一条蓝色的直线(理想情况下,它输出的概率为1/2,即每张图片都有一半的概率是真实图片)。
2.4 GAN整体训练流程(必看)
到这一步为止,我们就把GAN的核心技术讲完了,现在,我们将整个训练流程用伪代码的形式表达出来,并完整地串讲一次:
如上图,在每个step中,我们做以下几件事:
总体来说,GAN的训练遵循先"固定G训练D",再"固定D训练G"的模式。
(1)首先,设定一个超参数k,表示在一个step中,我们要更新k次D。
(2)在每一个k的循环中, 我们从minibatch里随机抽取m张真实图片x,再从确定的噪声分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z p_{z} </math>pz中随机抽取m个噪声z。在保持G不变的情况下 ,执行2.2.1中max的逻辑,更新k次D,增强D的鉴别能力。
(3)结束k次循环后,我们将D固定下来 ,随机抽取m个噪声z,执行2.2.2中min的逻辑,更新1次G,增强G的造假能力。
(4)如此循环,直到模型收敛。
(4)这句话其实挺耐人寻味的,因为实践中来看,GAN的收敛是很难的。这里我们相当于有G和D的两个目标函数,那怎么判断是否收敛呢? 很可能的情况是,训练过程就像个翘翘板,一会你收敛,一会我收敛 ,导致整个系统很难达到平衡。这个坑也是GAN挖给后人来填的,后续有不少工作就是对Loss函数做改进,让其能更好收敛。
三、GAN的全局最优解与收敛性
尽管在实操中,GAN被证明是难收敛的。但是作者在论文中给出了详细的证明:虽然实操难收敛,但理论上,我们的GAN是有全局唯一的最优解,并且一定是可以收敛到最优解的。 我们来详细看一下作者的证明。
3.1 全局最优解
为了找出 <math xmlns="http://www.w3.org/1998/Math/MathML"> V ( G , D V(G, D </math>V(G,D这个总价值函数(损失函数)的全局最优解,由于我们的目标是让G去拟合真实数据的分布,因此,这个全局最优解也可以理解成当V满足前面所说的min和max条件时, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg究竟长什么样子。 为了从理论上解答这个问题,作者按照训练流程,分成了两步:
-
先在固定G的情况下,找出D的全局最优解
-
再在固定D的情况下,找出G的全局最优解,也即求出 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg的最优解
3.1.1 D的全局最优解
根据作者的理论推理,在固定G的情况下,D的全局最优解就是(2)中列出的样子。那么具体是怎么推导出(2)的呢?我们来看(3)
在(3)的第一行,作者首先根据"期望的定义 ",将原来用期望E表达的式子改写成了积分的形式(有疑惑的朋友,可以百度一下期望的定义复习下)。然后,在固定G的情况下,对于D来说,它接受到的真实世界图片就可以表示成 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x),接受到的G造假的图片就可以表示成 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_{g}(x) </math>pg(x)。接下来我们求个导,就可以得到(2)中的结果啦,是不是很简单。
3.1.2 G的全局最优解
知道了D的全局最优解,我们就把D固定下来,把这个全局最优解带入回V中,我们就会得到:
好,根据2.2.2的min逻辑,我们就要来minimize C(G),那这里作者也采用了一个很巧妙的证明,即对于(4)中的分母,我们把它先乘上1/2,再乘上2,然后便可改写成下图中(5)的形式:
(5)中的两个KL散度又可以被写成(6)中Jensen-Shannon散度的形式(简称JSD),而JSD永远是非负的,且仅当 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g = p d a t a p_{g} = p_{data} </math>pg=pdata时,它才为0。因此C(G)的最小值必然就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g = p d a t a p_{g} = p_{data} </math>pg=pdata。
经过这一番推到,我们知道我们构造的损失函数是有全局最优解的,且刚好就是我们想要的 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g = p d a t a p_{g} = p_{data} </math>pg=pdata。所以现在我们要进一步证明,模型是可以收敛到这个全局最优解的。
3.2 模型收敛到全局最优解
作者在这里说,假设整个训练过程都比较稳健,同时在固定G的情况下,D可以收敛到它的最优解,那么最终 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg也可以收敛到 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata。作者下面给出了一堆证明,总结起来,这个证明的思路就是:convex函数的上限函数还是convex,因此它的上限函数依然可以收敛。
但是你可能也注意到了,这个证明的成立,是在作者说的一堆假设的前提下,而其实在实践中,作者所说的训练健壮性,是很难满足的(欸又有坑可以填)。所以,这也是文章开篇所说的,"半玄学"的"半"字的来由。
好啦,关于GAN的介绍,我们就讲到这里了(实验部分就不讲了,因为感觉不是很有意思😅)。是不是比大家想象得简单多了?在看完这篇文章后,建议大家再看一遍DDPM的数学原理篇,将两者比较阅读,可以方便大家感受GAN和扩散模型之间的异同之处。