【硬核科普】一文读懂生成对抗网络GAN

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文基于Ian在2014年发表在NIPS的论文Generative Adversarial Nets来说明生成对抗网络。正如其名,这篇论文是GAN的开山之作。

我们之前学习过LeNet,AlexNet,ResNet,R-CNN等经典视觉深度学习模型,这些模型处理的任务是对图像中的物体进行分类 ,以及对物体的边界进行回归 ,但本质来说都是一个判别 模型,模型根据输入来提取并返回特征值。而GAN所包含的不仅仅是判别 ,更重要的,也是更有难度的是在生成

自从2023年,AIGC开始迅速爆火,AIGC中的"G"和GAN中的"G"都是生成的含义,GAN(的改进模型)也在AIGC中扮演着重要的部分。

为什么说生成是比判别更有难度的任务呢?这里我想举一个电影的例子:我们普通人在看过一场电影后,可以很快地总结一场电影的内容,并判别这场电影的好坏。而判别电影好坏的能力也不难掌握,只要我们稍微看过一些同类的电影,并了解一些专业的影评人的评价就够了。

但是要我们作为电影导演生成一部电影,这样的难度就不可同日而语了。如果我们只具备鉴赏电影好坏的能力,这是远远不够的,而且即便对于一个非常优秀的导演来说,生成一部好电影也是一件非常非常艰难的事!

1. 预备知识

在正式讲GAN之前需要先铺垫点博弈论的基础知识。

1.1 Minimax两人博弈

在博弈论中,一个 Minimax 两人博弈是指一种具有零和特性的两个玩家轮流进行决策的游戏,在这种游戏中,每个玩家都试图最大化自己的期望收益(或者说是最小化损失)。在这种情况下,每个玩家都知道对方会做出最优反应。

具体来说,Minimax 算法是一种确定性决策过程,用于寻找博弈树(即表示所有可能的走步序列及对应结果的游戏结构)上的最优策略。在两玩家零和游戏中,对于先行者(通常称为 MAX 玩家),目标是最大化其收益;而对于后行者(通常称为 MIN 玩家),目标则是最小化先行者的最大收益,也就是最大化自己的最小收益。

例如,在市场上砍价,老板(MAX玩家)会说"我的货物质量非常好,市场上独一无二",来尽量提升货物的售价。而卖家(MIN玩家)会说"我以前在其他商店买过此类的货物,你的货物并非独一无二,而且还比你的便宜的多"来压低售价。老板(MAX玩家)会再反驳说"我们的货物看起来很像,但是完全不是一个东西,我的要牛逼得多"......如此往复,直到达成两个人都满意得价格。

1.2 纳什均衡(Nash Equilibrium)

纳什均衡是博弈论中的核心概念之一,由美国数学家约翰·福布斯·纳什(John Forbes Nash Jr.)在20世纪50年代提出。在博弈论中,纳什均衡是指在一个非合作博弈中,各个参与者都不愿单方面改变策略,因为一旦他们改变自己的策略,他们的支付(即利益或效用)将会降低。换句话说,在纳什均衡状态下,每个参与者都选择了给定其他参与者策略情况下的最优策略。

更形式化的表述是:在一个多人博弈中,如果存在一组策略组合(每个参与者都有一个策略),使得对于博弈中的任意一个参与者,在其他人策略保持不变的情况下,更改自己的策略并不能获得更高的收益,那么这组策略就构成了一个纳什均衡。

例如,在经典的囚徒困境博弈中,尽管两个囚犯之间没有合作协定,但他们各自独立做出的决定形成了一种纳什均衡------都选择背叛对方,这样对于任何一个囚徒来说,改变自己的策略都不能获得更高得收益,因为对方已经背叛了自己。

2. GAN的设计

2.1 GAN的设计理念

GAN的设计理念借鉴了博弈论中的两人Minimax博弈过程。在生成对抗网络中,有两个主要的模型参与了一场类似于博弈的过程:

  1. 生成器(Generator):负责学习数据分布 ,并尝试根据随机噪声 生成接近真实数据的新样本。

  2. 判别器(Discriminator):任务是区分真实数据和生成器生成的虚假数据。

这个过程可以用"制作假币"来比拟:生成器是制作假币的铸币机,判别器是验钞机。

这两个网络通过迭代地进行对抗性训练来优化自身的性能。在每一轮训练中:

  • 生成器试图更好地模仿真实数据分布,生成足以骗过判别器的样本。
  • 判别器则努力学习如何准确地区分真实样本和生成器生成的样本
  • 为了骗过验钞机,铸币机要不断改进。
  • 而为了应对造假技术不断精湛的铸币机,验钞机也要不断提升检验假币的水平。

当GAN达到理想状态时,可以认为它达到了一个纳什均衡点:

  • 生成器生成的数据已经不能再被判别器有效地辨别出与真实数据的区别,即判别器对真实数据和生成数据的分类正确率均为50%,它无法单方面通过改变策略(提高判别能力)而获得更多优势。
  • 生成器也不能通过进一步改变策略生成更难以辨别的样本,因为它已经在最大可能的程度上模拟了真实数据分布。

这个均衡点即:

  • 验钞机不能区分假币和真币了,因为铸币机已经完全学习到真币的图案分布,已经能造出"真币"了,验钞机不能把"真币"当成假币,即无法区分真币和伪造的"真币"。
  • 铸币机的"造假"水平也不能进一步再提高了,因为不能造出"比真币还真"的真币。

在这个状态下,生成器和判别器都不再有动机去改变它们的策略,从而形成了一个类似博弈论中的纳什均衡。此时,生成器可以用来生成高质量的新颖样本,而判别器则在理论上失去了进一步区分的能力。

这里需要注意的一点是:GAN的生成器只是学习到了真实数据的分布规律,而并不知道这个规律是怎样的。(这也算是GAN的创新点)

举一个简单的例子,如果数据的真实分布规律是:

那生成器G学习到的规律只能是:

我们都知道后者是前者的泰勒展开,两者等价。但是生成器G却不可能知道数据的真实分布规律就是sin(x)。

这个真实的分布规律即马尔可夫链(Markov Chain),它是一种在概率论和统计学中广泛使用的数学模型,用于描述一个系统随时间演化的随机过程。它的重要特征是马尔可夫性质(Markov property),又称为无后效性或无记忆性,这意味着系统的未来状态仅依赖于当前状态,而不受以往状态的影响,即给定当前状态后,未来的演变过程与过去的状态无关。

2.2 GAN的模型结构

GAN的结构模型如下:

其中:

  • z 是输入的随机向量(噪声),严格来说z 应该是从一个满足某种特定先验的中随机生成的,它并不是完全随机的;
  • 是生成器,它是一个多层感知机MLP,其中z 是输入,是生成器的参数;
  • 是判别器,它也是一个MLP,其输入是生成器G 的输出以及真实数据xx 是从真实数据分布中生成的,是判别器的参数;
  • 输出概率是一个[0,1]的参数,结果接近0则表示判别器认为数据是生成器G 生成的,结果接近1则表示判别器认为输入是真实数据x

3. GAN的训练过程(核心)

3.1 GAN的价值函数

GAN的价值函数(value function)是用来同时训练生成器G 和判别器D的数学表达式:

上式中代表数学期望,举个例子求一个骰子掷出点数的数学期望,即,其中x 为骰子掷出点数,P ( x ) 为对应x 点的概率(当然,肯定都是)最终求得:

这个价值函数表达了两者的对抗过程:

  • 判别器D 的目标是最大化这个价值函数,它试图正确地区分真实数据样本x (来自真实数据分布)和生成器生成的假样本G (z) ,其中z 是从先验噪声分布中采样的随机向量。更具体来说,对于公式的前面一半,判别器D 要让它越大越好,因为x 是来自真实数据的样本,判断的结果应该接近1。而对于公式后面一半,判别器D 也是要让它越大越好,也就是说越小越好,即生成器G 产生的样本G (z) 经过判别器D后产生的结果应该接近0;
  • 生成器G 只能影响价值函数的后面一半,它的目标则是最小化这个价值函数,它希望通过学习调整自身参数,使得其生成的样本能够欺骗判别器,即让判别器认为生成样本G (z) 同样来自真实数据分布,即接近1。

简而言之,GAN的价值函数体现了一种博弈论上的极小极大游戏:判别器D努力学习最好的区分策略(最大化正确分类的概率),而生成器G则努力学习生成最真实的样本以混淆判别器(使生成样本被误判为真实的可能性增大)。当达到纳什均衡时,生成器能够生成与真实数据分布难以区分的样本。

3.2 GAN的训练过程

GAN的训练过程可以简单化为以下4个步骤:

图中黑色圆点代表真实数据分布,即;绿色线代表生成器G 生成的数据分布G (z) ;蓝色线是判别器D生成的最终概率。

(a)初始状态:生成器生成的数据与真实分布相差较远,判别器输出的概率波动也很大,不能很好区分真实分布和生成分布;

(b)判别器学习过程:判别器能较好区分真实分布和生成分布;

(c)生成器学习过程:生成的分布情况逐渐贴近真实分布;

不断重复(b)和(c)......

(d)达到纳什均衡:生成器学习到了真实分布,判别器无法判断出数据来源,输出的概率结果为0.5

3.3 GAN的训练过程算法

整个训练的具体算法过程如下:

其训练过程就是价值函数对生成器G 的权重,以及判别器D 的权重求偏导,即生成器G 和判别器D的模型反向传播。

需要注意的是判别器D 和生成器G 并不是1:1交替训练的!而是每隔k步才训练一次判别器D

k是一个大于1的超参数,需要按经验设定(ˉ▽ˉ;)...

这样做的理由正是本文开头就提及的:生成 是比判别 更复杂的任务,即造假要比鉴假更难。更具体来说有以下原因需要每隔k步才训练一次判别器D

  1. 判别器过强:如果判别器训练得太频繁而生成器跟不上,可能会导致判别器变得过于强大,以至于它可以轻易区分出真实数据和生成器产生的伪数据,这会使生成器收到的梯度信号过于强烈且不稳定,从而难以提升生成数据的质量。

  2. 生成器梯度消失或爆炸:由于GAN训练的本质是对抗性过程,若判别器过于优秀,则生成器可能无法获得有用的梯度来改进自身,因为判别器几乎总是能完美地区分真假样本,这样生成器就接收到接近零或者非常大的梯度,不利于训练。

  3. 训练稳定性:通过控制判别器和生成器的训练频率,可以调整二者之间的能力差距,维持一种动态平衡,有助于提高整个系统的训练稳定性,并促进收敛。

  4. 避免过度拟合:限制判别器的训练次数也有助于防止其过早地过拟合训练数据中的细节,促使它保持一定的泛化能力,这对于GAN的整体性能至关重要!

4. GAN的改进版本

GAN自从2014年由Ian Goodfellow等人首次提出以来,已经发展出了众多改进版本,以克服原始GAN的一些问题,如模式塌陷(mode collapse)、训练不稳定性和收敛困难等。以下是一些重要的GAN改进版本:

1. Wasserstein GAN (WGAN):

WGAN引入了 Wasserstein 距离(也称为Earth Mover's Distance, EMD),通过优化判别器来最小化这种距离而不是传统的交叉熵损失,从而改善了训练稳定性并缓解了模式塌陷的问题。

2. Wasserstein GAN with Gradient Penalty (WGAN-GP):

WGAN-GP是对WGAN的进一步改进,它通过添加判别器梯度范数的惩罚项来约束判别器保持接近1-Lipschitz连续性,从而更严格地遵循Wasserstein距离的定义。

3. Deep Convolutional GAN (DCGAN):

DCGAN将卷积神经网络引入到GAN的生成器和判别器中,提高了图像生成的质量和效率。

4. Least Squares GAN (LSGAN):

LSGAN使用最小二乘损失替代原来的sigmoid交叉熵损失函数,以减小训练过程中的梯度消失和爆炸问题。

5. Conditional GAN (cGAN):

cGAN增加了条件信息到GAN的输入,使得生成的数据可以根据特定的标签或条件进行控制。

6. InfoGAN:

InfoGAN通过最大化生成器产生的样本的隐变量与观察到的数据之间的互信息,实现了在无监督条件下学习有意义的隐变量表示。

7. Progressive Growing GAN (PGGAN):

PGGAN通过逐步增加网络分辨率的方式来训练GAN,从较低分辨率开始逐渐提升,以提高生成图像的质量和细节。

8. StyleGAN:

StyleGAN采用了风格迁移的思想,通过分离样式和内容表示,可以更好地控制生成图像的高级属性和细节。

9. BigGAN:

BigGAN通过大规模批量训练和一些架构改进,显著提升了生成高分辨率图像的质量和多样性。

10. CycleGAN:

CycleGAN用于无配对的图像到图像转换任务,引入了循环一致性损失,能够在没有成对训练数据的情况下进行图像翻译。

以上仅列举了一部分GAN的改进版本,实际上GAN家族非常庞大,还包括许多其他的变种和应用领域的扩展,如BEGAN、Energy-Based GAN (EBGAN)、Spectral Normalization GAN (SN-GAN)、Self-Attention GAN (SAGAN)等。随着研究的不断推进,GAN的改进和新版本持续涌现。

最后再点一下题,为什么"生成"如此重要?因为没有生成就没有本文的插图。

(本文插图均由Midjourney生成)

相关推荐
slomay2 小时前
关于对比学习(简单整理
经验分享·深度学习·学习·机器学习
AI完全体3 小时前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
sp_fyf_20243 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
卷心菜小温4 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
陈苏同学4 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
FL16238631295 小时前
[深度学习][python]yolov11+bytetrack+pyqt5实现目标追踪
深度学习·qt·yolo
羊小猪~~5 小时前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
龙的爹23335 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
工业机器视觉设计和实现5 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
醒了就刷牙5 小时前
58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习