CV大模型系列之:GAN,博弈论下的一个实例

⚠️⚠️⚠️本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

最近在更新"CV大模型系列"的过程中,收到了一些朋友的私信,问能不能把"GAN-VAE-扩散模型"这条线再展开来说说 。所以在今天的这篇文章中,我就来和大家一起重新细读这篇CV领域的经典制作 :GAN(Generative Adversarial Nets,生成式对抗网络) ,同时也理一遍它和我们之前讲的扩散模型基石DDPM间的原理差异。

虽然GAN发表的年份较早(2014年),但直到今天,它的核心idea的影响力还是巨大的。它挖下了一个巨大的坑,让人后在这10年里争相围绕着它做改进研究。

讲一个小八卦 ,当年GAN刚发表的时候,曾卷入一场"抄袭"的风波中,投诉方叫Jürgen Schmidhuber(LSTM之父),在2016年的NIPS大会上,他公开谴责GAN和自己1992年提出的PM(Predictability Minimization)模型非常相似。GAN的作者也反驳,声明自己并不知道PM模型,并且相关的解释在邮件里已经说得非常明白了。

在写这篇文章之前,我把PM模型翻出来扫了一遍,个人的想法是PM确实在1992年就创新性地提出了"对抗训练"这种思想,但模型本质还是和GAN有一定差别的。遇上撞idea这种事时,起决定作用的,也许是运气吧(Jürgen有很多超前的思想,但当时的数据和硬件质量,不足以证明他的模型的有效性。甚至是LSTM在1997年刚被提出时,也因为算力问题,并没有受到很强的重视)。

哎呀,扯远了,我们赶紧进入正文吧。

CV大模型系列文章导航(持续更新中):

🌸CV大模型系列之:扩散模型基石DDPM(模型架构篇)🌸

🌸CV大模型系列之:扩散模型基石DDPM(人人都能看懂的数学原理篇)🌸

🌸CV大模型系列之:扩散模型基石DDPM(源码解读与实操篇)🌸

🌸CV大模型系列之:全面解读VIT,它到底给植树人挖了多少坑🌸

🌸CV大模型系列之:多模态经典之作CLIP,探索图文结合的奥秘🌸

🌸CV大模型系列之:MAE,实现像素级图像重建🌸

🌸CV大模型系列之:MoCo v1,利用对比学习在CV任务上做无监督训练🌸

🌸CV大模型系列之:DALLE2,OpenAI文生图代表作解读🌸

🌸CV大模型系列之:GAN,博弈论下的一个实例

一、GAN在做一件什么事

在经济学里,我们经常讨论一个名词,叫博弈论(Game Theory) 。假设有一场双人游戏(Two Player's game) ,参赛的双方彼此是竞争或对抗 的角色,拥有不同的目标或利益。双方为了实现自己的目标,就必须要揣测对方可能会采用的所有行为,然后选取对自己最有利的方案,如此交手,最终使得整个游戏系统达到一种均衡

我们平常和人打牌、打麻将就是博弈论的一个实例;大家耳熟能详的名词"囚徒困境"也是博弈论的一个实例。同样,今天我们要谈的GAN,也是博弈论的一个实例

我们先来看GAN想做一件什么事 。如果你读过之前讲解扩散模型DDPM的这篇文章,你一定对下面这张图很熟悉。其实GAN和DDPM的目标是一致的 :我想训练一个模型,在训练的时候,喂给它一堆图片,让它去学习这堆图片的分布。这样,在推理阶段,我喂给模型一个随机噪声,它就可以帮我"生成"一张和它吃进去的那些图片风格相似的图出来了。一句话总结:GAN和扩散模型,终极目标都是学习训练数据的分布。

但是GAN学习数据分布的方法和DDPM不同:

  • DDPM是通过重参数等技巧,去学习数据的均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> \m \m </math>\m和方差\sigma^{2,相当于真得是把这个分布具象化地学出来了

  • GAN则是采用一种更简便的方式,它才不去预先假设好什么数学分布,它也不关心这个分布具象化长什么样子 。它就是粗暴地放一个模型,用某些技巧(即后面我们要说的对抗性),强迫模型能学出长得像真实图片的数据就行(有点半玄学的意味在里面,至于为什么是"半",我们后面讲理论部分的时候会说)。

你可能会想,那乍一看,GAN好像比DDPM方便多了。但也正是因为这种不依赖具象分布的方便性,导致了GAN在训练过程中难收敛的问题,这也是后人对其着重做的改进点之一。

好,明确了GAN的目标后,我们现在来细看,GAN到底是通过何种方法,来强迫模型产出像模像样的图片的?

二、GAN的原理

2.1 GAN架构

上图刻画了GAN模型的主体架构,最关键的是两部分:

  • Generator(生成器) :用于学习数据的分布,并输出其学习成果。
  • Discriminator(鉴别器) :用于评估生成器的学习成果。

在代码实现中,这两者都是常见的MLP。

我们对这两部分做详细的解释:

(1)首先,你得有一些真实图片(real images),你从中筛选m张出来,组成一个sample

(2)然后,你可以从某个分布中(例如高斯分布),随机筛选m个噪声

(3)接着,你把m个噪声喂给Generator,得到它的学习成果。

(4)然后,你把真实图片Generator的学习成果,一起输入给Discriminator。

(5)最后,作为鉴别器,Discriminator需要去鉴别一张图片是来自真实图片,还是来自Generator杜撰的学习成果

欸,从这个过程里,你是不是能体会到 "对抗" 的含义了:在这场游戏中,Generator的目标是训练出尽可能逼真的图片,Discriminator的目标则是打假。所以我们不仅要提升Generator的造假能力,也要训练Discriminator的辨假能力,只有当这两者都足够强时,模型才能达到一种最优的均衡,产生理想的结果。

2.2 GAN Loss

看完了模型构造,我们来看一下GAN的损失函数设计。

和大部分损失函数不同,GAN的损失函数分成两部分:Generator(以下简称G)的损失和Discriminator(以下简称D)的损失

我们先不看前面那个min和max,直接看主体式子部分(式子中V表示Value,GAN不将其称呼为损失函数,而是称为价值函数):

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x): <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata是客观存在的、真实世界的数据分布,也就是我们希望让模型学到的分布。x是指真实世界中的某张图片。 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∼ p d a t a ( x ) x \sim p_{data}(x) </math>x∼pdata(x)指随机抽取真实世界的图片。

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p z ( x ) p_{z}(x) </math>pz(x): <math xmlns="http://www.w3.org/1998/Math/MathML"> p z p_{z} </math>pz是指喂给G的噪声的分布,这个分布是我们定好的(例如定为高斯分布)。 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∼ p z ( z ) z \sim p_{z}(z) </math>z ∼ pz(z)是指从定好的噪声分布中随机抽取噪声。

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> G ( z ) G(z) </math>G(z):表示随机噪声经过G后的输出结果,也就是2.1中所说的G的学习成果

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x ) logD(x) </math>logD(x):指真实世界图片x过D的输出结果,表示根据D的鉴别,x属于真实世界图片的概率

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( G ( z ) ) log(1-D(G(z)) </math>log(1−D(G(z)): <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( G ( z ) D(G(z) </math>D(G(z)表示经过D的鉴别,噪声z属于真实世界图片的概率。因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − D ( G ( z ) 1-D(G(z) </math>1−D(G(z)自然表示"噪声G不属于真实世界图片的概率"

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_{g}(x) </math>pg(x):这一项在上式中没有出现,但是后文我们会经常用到,所以这里也作下说明。 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg表示G学出来的数据分布 ,换句话说,我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg能够逼近 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata。

明确了主体部分的含义,我们再来细看min和max

2.2.1 max

max的意思是,假设我们把G固定下来(G的参数不再变了) 。那么此时,我们的优化目标就变成

  • 当输入数据是真实世界图片x时,我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x ) logD(x) </math>logD(x)尽量大
  • 当输入数据是由G加工产生的噪声z时,我们希望 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( G ( z ) ) log(1-D(G(z)) </math>log(1−D(G(z))也尽量大。

这样就能提升D的鉴别能力。这就是式子中max的含义。

2.2.2 min

min的意思是,假设我们把D固定下来,那么此时G要做的事情就是生成尽可能逼真的图片,去欺骗D 。此时我们的优化目标就变成:

  • 当D固定时,使得 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( G ( z ) ) log(1-D(G(z)) </math>log(1−D(G(z))这一项尽可能小。这项变小,意味着 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( G ( z ) ) D(G(z)) </math>D(G(z))这项变大,表示此时D难以分辨真假图片。

你可能想问,那为啥不管 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x ) log D(x) </math>logD(x)这项了呢?因为此时D是固定的,这一项相当于是个常数,对我们的优化目标没有影响。

如此一来,我们就能提升G的造假能力,这就是式子中min的含义。

2.3 GAN训练分布变动可视化

通过对训练函数的解释,现在你可能对GAN的运作原理有更深刻的理解了 :它是从一个确定的噪声分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z p_{z} </math>pz(例如高斯分布)中,随机抽取一个噪声喂给G,然后通过加强D的辨别能力,一步步迫使G能从噪声中还原数据真实的分布,产生以假乱真的图片。也即,在这个过程里,让G学到的分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg去逼近真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata。

你看,在这个过程中,各种p都是一个抽象的代表,我们并没有用数学语言明确写出它的样子。如果你读过这个系列之前的DDPM,你就能感受到GAN和它的差异:DDPM的本质是去拟合一个数学分布中的各种参数的。

接下来,我们把GAN在训练过程中, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg和 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata的变动可视化地表达出来,更方便我们理解GAN的训练过程:

图中:

  • 黑色:真实图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata的分布

  • 绿色:G学出的 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg的分布

  • 蓝色:D的输出分布

(a)表示在GAN训练的初始阶段 ,此时G和D都不强。因此绿色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg距离黑色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata相差甚远。蓝色的D输出也十分混乱,没有规律性。但总体来看,D还是可以区分出真实世界图片和G造假图片的(表现在靠近黑色的部分较高,靠近绿色的部分较低)。

(b)模型学习的中间阶段,D在慢慢变强,可以发现这时蓝色的曲线更加稳定。

(c)同样是模型学习的中间阶段 ,G也在慢慢变强,可以发现绿色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg正在向黑色 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata靠近。

(d)模型学习的最终阶段, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata吻合(理想情况) ,这时D已经分不清真假数据了,因此它的输出分布变成一条蓝色的直线(理想情况下,它输出的概率为1/2,即每张图片都有一半的概率是真实图片)。

2.4 GAN整体训练流程(必看)

到这一步为止,我们就把GAN的核心技术讲完了,现在,我们将整个训练流程用伪代码的形式表达出来,并完整地串讲一次:

如上图,在每个step中,我们做以下几件事:

总体来说,GAN的训练遵循先"固定G训练D",再"固定D训练G"的模式

(1)首先,设定一个超参数k,表示在一个step中,我们要更新k次D。

(2)在每一个k的循环中, 我们从minibatch里随机抽取m张真实图片x,再从确定的噪声分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p z p_{z} </math>pz中随机抽取m个噪声z。在保持G不变的情况下 ,执行2.2.1中max的逻辑,更新k次D,增强D的鉴别能力

(3)结束k次循环后,我们将D固定下来 ,随机抽取m个噪声z,执行2.2.2中min的逻辑,更新1次G,增强G的造假能力

(4)如此循环,直到模型收敛。

(4)这句话其实挺耐人寻味的,因为实践中来看,GAN的收敛是很难的。这里我们相当于有G和D的两个目标函数,那怎么判断是否收敛呢? 很可能的情况是,训练过程就像个翘翘板,一会你收敛,一会我收敛导致整个系统很难达到平衡。这个坑也是GAN挖给后人来填的,后续有不少工作就是对Loss函数做改进,让其能更好收敛。

三、GAN的全局最优解与收敛性

尽管在实操中,GAN被证明是难收敛的。但是作者在论文中给出了详细的证明:虽然实操难收敛,但理论上,我们的GAN是有全局唯一的最优解,并且一定是可以收敛到最优解的。 我们来详细看一下作者的证明。

3.1 全局最优解

为了找出 <math xmlns="http://www.w3.org/1998/Math/MathML"> V ( G , D V(G, D </math>V(G,D这个总价值函数(损失函数)的全局最优解,由于我们的目标是让G去拟合真实数据的分布,因此,这个全局最优解也可以理解成当V满足前面所说的min和max条件时, <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg究竟长什么样子。 为了从理论上解答这个问题,作者按照训练流程,分成了两步:

  • 先在固定G的情况下,找出D的全局最优解

  • 再在固定D的情况下,找出G的全局最优解,也即求出 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg的最优解

3.1.1 D的全局最优解

根据作者的理论推理,在固定G的情况下,D的全局最优解就是(2)中列出的样子。那么具体是怎么推导出(2)的呢?我们来看(3)

在(3)的第一行,作者首先根据"期望的定义 ",将原来用期望E表达的式子改写成了积分的形式(有疑惑的朋友,可以百度一下期望的定义复习下)。然后,在固定G的情况下,对于D来说,它接受到的真实世界图片就可以表示成 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a ( x ) p_{data}(x) </math>pdata(x),接受到的G造假的图片就可以表示成 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g ( x ) p_{g}(x) </math>pg(x)。接下来我们求个导,就可以得到(2)中的结果啦,是不是很简单。

3.1.2 G的全局最优解

知道了D的全局最优解,我们就把D固定下来,把这个全局最优解带入回V中,我们就会得到:

好,根据2.2.2的min逻辑,我们就要来minimize C(G),那这里作者也采用了一个很巧妙的证明,即对于(4)中的分母,我们把它先乘上1/2,再乘上2,然后便可改写成下图中(5)的形式:

(5)中的两个KL散度又可以被写成(6)中Jensen-Shannon散度的形式(简称JSD),而JSD永远是非负的,且仅当 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g = p d a t a p_{g} = p_{data} </math>pg=pdata时,它才为0。因此C(G)的最小值必然就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g = p d a t a p_{g} = p_{data} </math>pg=pdata。

经过这一番推到,我们知道我们构造的损失函数是有全局最优解的,且刚好就是我们想要的 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g = p d a t a p_{g} = p_{data} </math>pg=pdata。所以现在我们要进一步证明,模型是可以收敛到这个全局最优解的。

3.2 模型收敛到全局最优解

作者在这里说,假设整个训练过程都比较稳健,同时在固定G的情况下,D可以收敛到它的最优解,那么最终 <math xmlns="http://www.w3.org/1998/Math/MathML"> p g p_{g} </math>pg也可以收敛到 <math xmlns="http://www.w3.org/1998/Math/MathML"> p d a t a p_{data} </math>pdata。作者下面给出了一堆证明,总结起来,这个证明的思路就是:convex函数的上限函数还是convex,因此它的上限函数依然可以收敛。

但是你可能也注意到了,这个证明的成立,是在作者说的一堆假设的前提下,而其实在实践中,作者所说的训练健壮性,是很难满足的(欸又有坑可以填)。所以,这也是文章开篇所说的,"半玄学"的"半"字的来由。

好啦,关于GAN的介绍,我们就讲到这里了(实验部分就不讲了,因为感觉不是很有意思😅)。是不是比大家想象得简单多了?在看完这篇文章后,建议大家再看一遍DDPM的数学原理篇,将两者比较阅读,可以方便大家感受GAN和扩散模型之间的异同之处。

相关推荐
偶信科技6 分钟前
国产极细拖曳线列阵:16mm“水下之耳”如何撬动智慧海洋新蓝海?
人工智能·科技·偶信科技·海洋设备·极细拖曳线列阵
Java后端的Ai之路28 分钟前
【神经网络基础】-神经网络学习全过程(大白话版)
人工智能·深度学习·神经网络·学习
庚昀◟43 分钟前
用AI来“造AI”!Nexent部署本地智能体的沉浸式体验
人工智能·ai·nlp·持续部署
喜欢吃豆1 小时前
OpenAI Realtime API 深度技术架构与实现指南——如何实现AI实时通话
人工智能·语言模型·架构·大模型
数据分析能量站1 小时前
AI如何重塑个人生产力、组织架构和经济模式
人工智能
wscats2 小时前
Markdown 编辑器技术调研
前端·人工智能·markdown
AI科技星2 小时前
张祥前统一场论宇宙大统一方程的求导验证
服务器·人工智能·科技·线性代数·算法·生活
GIS数据转换器2 小时前
基于知识图谱的个性化旅游规划平台
人工智能·3d·无人机·知识图谱·旅游
EnoYao2 小时前
Markdown 编辑器技术调研
前端·javascript·人工智能
TMT星球2 小时前
曹操出行上市后首次战略并购,进军万亿to B商旅市场
人工智能·汽车