第十九周机器学习笔记:GAN的数学理论知识与实际应用的操作

第十九周周报

  • 摘要
  • Abstratc
  • [一、机器学习------GAN Basic Theory](#一、机器学习——GAN Basic Theory)
    • [1. Maximum Likelihood Estimation](#1. Maximum Likelihood Estimation)
    • [2. 复习训练GAN的过程](#2. 复习训练GAN的过程)
    • [3. Objective function与JS散度相关性推导](#3. Objective function与JS散度相关性推导)
    • [4. GAN的实际做法](#4. GAN的实际做法)
  • 总结

摘要

本周周报主要围绕生成对抗网络(GAN)的基础知识和理论进行深入探讨。首先回顾了GAN的基本概念、训练原理和应用场景。随后,周报详细分析了GAN背后的理论基础,包括如何通过高维空间中的点来理解图像生成,以及如何通过生成模型来寻找数据的分布。然后周报还描述了最大似然估计(MLE)在生成任务中的应用,并对比了传统方法与GAN的不同。然后复习了训练GAN的过程,包括如何通过判别器(Discriminator)来衡量两个分布之间的差异。最后,我们探讨了GAN的目标函数与JS散度(Jensen-Shannon Divergence)之间的关系,并讨论了在实际训练中如何通过样本来近似期望值。

Abstratc

In the weekly report, the basic concepts of Generative Adversarial Networks (GANs) are reviewed, followed by an in-depth discussion of the theoretical underpinnings of GANs. This includes an analysis of how image generation can be understood through points in high-dimensional spaces and the search for data distributions via generative models. The application of Maximum Likelihood Estimation (MLE) in generative tasks is described, with a comparison made between traditional methods and GANs. The training process of GANs is also reviewed, highlighting how discriminators are utilized to assess the divergence between two distributions. Lastly, the relationship between the objective function of GANs and Jensen-Shannon Divergence (JS divergence) is explored, along with a discussion on approximating expectations with samples during actual training.

一、机器学习------GAN Basic Theory

在之前GAN的学习中,我们了解了GAN的概念,训练的原理以及应用,这只是GAN的基础内容,接下来我们将详细的了解GAN背后的理论知识。

假设我们要生成的东西是image,我们用x呢来代表一张image
(每一个image都是高维空间中的一个点,假设产生64×64的image,那它是64×64维度空间中的一个点)

如下图所示:

为了方便解释我们将其视为二维空间中的一个点,所以它实际上是高维空间中的一个点。

我们要产生的image,它其实有一个固定的distribution,记为成P~data~。
即在这整个image space里面只有非常少的部分sample出来的image看起来像是人脸,在多数的space中sample出来image它都不像是人脸。

举例来说,在下图的例子里面,可能只有蓝色的这个区域sample的image,它看起来像是人脸啊。在其他地方simple看起来的图片看起来就不像是人脸。
所以假设我们要生成的是人脸的话,它有一个固定的distribution ,这个distribution在蓝色的这个区域,它的几率是高的;在蓝色区域以外,它的几率是低的。

机器做的事情是什么呢?
我们要机器去找出这一个distribution,而这个distribution实际上我们是不知道的

我们可以搜集很多的x(image)去了解x可能在某些地方分布比较高,但是要我们把它的function找出来是做不到的。
所以现在generated model(GAN)做的事情是------找出这个x的distribution。

1. Maximum Likelihood Estimation

那在有GAN之前,我们怎么做generative这件事呢?
我们使用最大似然估计(Maximum Likelihood Estimation)来完成

如下图示:
最大似然估计

  • 给定一个数据分布 P~data~(x)(这是我们采样得到,因为我们并不能列出这个distribution的式子)
  • 我们有一个由参数θ参数化的分布P~c~(x~i~;θ)。
  • 我们希望找到θ使得 P~c~(x~i~;θ)接近P~data~(x)。
  • 注意:P~c~(x~i~;θ) 是一个Gaussian Mixture Model(高斯混合模型), (θ) 是Gaussian的均值(average)和方差(mean)

    步骤如下:
  1. 从 P~data~(x) 中采样 (x~1~, x~2~, ..., x~m~) 。
  2. 计算P~G~(x~i~;θ)
  3. 计算生成样本的似然度(Likelihood)
    L = ∏ i = 1 m P G ( x i ; θ ) L = \prod_{i=1}^{m} P_G(x_i; \theta) L=i=1∏mPG(xi;θ)
  4. 找到最大化似然的 (θ^*^)去maximize L。

    其中θ*转化Maximum Likelihood Estimation为minimize KL 散度的推导如下:
    其中需要注意的是解释中的
    3与4,这是转化的关键

    在推导的时候,我有个疑问就是为什么直接转化为KL,减一个东西不影响原来的结果吗?
    其实我们θ只影响被减去的那一项,另外一个是常数项。
    因此我们把max转化为min减去那一项结果(加个负号由max变为min)其实是不影响的。
    如下图所示:

    但是问题在于P~G~,其也许不是高斯分布模型(使用高斯分布模型,给定一个x可以计算其被sample出来的几率)
    而是比高斯分布更加复杂的分布,例如,它是一个neural network,那你就没有办法计算P~G~(x~i~;θ)。

2. 复习训练GAN的过程

那要怎么办呢?

于是就有了一个新的想法:

因为Generator就是一个network,而我们把一个network看作是一个probability distribution。

回顾一下Gnerator的运作过程:

1.每次sample出一个z,它丢到这个generator里面,你就会得到一个x
(把Generator看作一个function,那个结果就是G(z))

2.sample不同的z得到的x呢就不一样。(z是从一个gaussian distribution里面sample出来的)

把这些从gaussian distribution里面sample出来的z通过Generator得到另外一大堆sample,把这些sample统统集合起来,你得到的就会是另外一个distribution。

那接下来目标是什么?

接下来目标是希望generator根据这个generator所定义出来的这个distribution P~G~它跟我们的目标跟我们的P~data~的越接近越好。

写一个optimization的formulation,这个formulation看起来是这个样子:
G ∗ = arg ⁡ min ⁡ G Div ⁡ ( P G , P data ) G^{*}=\arg \min {G} {\operatorname{Div}\left(P{G}, P_{\text {data }}\right)} G∗=argGminDiv(PG,Pdata )
就是求P~G~与P~data~的散度

补充:arg的含义

公式有了,但是现在的问题就是
P~G~跟P~data~它们的formulation我们是不知道的,我们无法计算Divergence,所以怎么办呢?

这个就是GAN神奇的地方
在进入比较多的数学推导之前我们复习一下GAN到底是怎么做到minimize divergence这件事情

虽然不知道P~G~跟P~data~的distribution长什么样子,但是我们可以从这两个distribution里去sample一些data出来形成一个代表性的distribution。

  1. 把DataBase拿出来,假设我们做二次元人物头像生成的话,就把你的DataBase拿出来,然后从里面sample很多image,这个就是从P~data~这个distribution里面做sample。
  2. 从P~G~里面做sample其实就是random sample一个vector,然后把这个vector丢到Generator里面产生一image 。因为P~G~由的generator所定义的,那我们在使用这个generator的时候,我们是从某一个部分distribution里面去sample 一大堆的vector,每一个vector就会产生一张image

所以我们可以从P~data~里面做sample,我们也可以从P~G~里面做sample。

接下来的问题是我们可以从P~G~与P~data~做sample,根据这个sample,要怎么知道这两个distribution的divergence呢?
GAN神奇的地方就是透过discriminator,我们可以来量这两个distribution间的divergence

蓝色星星是从P~data~中sample出来的,让其分数越大越好。

红色星星是从P~G~中sample出来的,让其分数越小越好。

然后我们训练Discriminator

discriminator训练的结果就会告诉我们P~data~跟P~G~它们之间的divergence有多大

我们会写一个objective function。

它跟两项有关,一个是跟generator有关,一个是跟discriminator有关。

在train这个discriminator的时候呢,会固定住generator。所以只跟我们的Discriminator有关。

V ( G , D ) = E x ∼ P data [ log ⁡ D ( x ) ] + E x ∼ P G [ log ⁡ ( 1 − D ( x ) ) ] V(G, D)=E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))] V(G,D)=Ex∼Pdata [logD(x)]+Ex∼PG[log(1−D(x))]

我们希望从P~data~采样的D(x)值越大越好,从P~G~采用出来的D(x)越小越好。

所以要 maximize V(G,D)。

3. Objective function与JS散度相关性推导

然后我们上一周提到了Objective function与JS散度是有关联的,其实我们看下图中的,我们也可以直观的感受到:

它们的具体表达如下:

E x ∼ P data [ log ⁡ D ( x ) ] + E x ∼ P G [ log ⁡ ( 1 − D ( x ) ) ] = − 2 log ⁡ 2 + 2 J S D ( P data II P G ) E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))] =-2 \log 2+2 J S D\left(P_{\text {data}} \text { II } P_{G}\right) Ex∼Pdata [logD(x)]+Ex∼PG[log(1−D(x))]=−2log2+2JSD(Pdata II PG)

数学公式推导如下:


所以我们就可以把DIV的问题转化为 max V(G,D)

如下图中G~3~就代表了最好的Div,因为G~3~表示P~data~与P~G~的divergence是最小的

那接下来呢,我们就是要想办法解这个min & max 的问题

GAN在train的时候:

  1. 固定住generator去update discriminator
  2. 固定住discriminator接下来去update generator

要解这个optimization问题,要怎么做呢?

在maxV(G,D)中需要找一个D去maximize V(G,D)看起来有点复杂,所以我们把它用L(G)来代替,因为跟D没有关系的,我们只需要最后得到的值越大越好。
处理过程如下:


==这边有一个问题,现在做的事未必等同于真的在minimize JS Divergence

说假设一个Generator就是G~0~,那V(G~0~,D)假设如下图左边的样子。

找到一个D~0~^*^,这个D~0~^*^的值就是G~0~跟Data之间的JS divergence。

但是当update G~0~变成G~1~的时候,这个时候呢,V(G~1~,D~0~^*^)的function可能就会变了。

本来V(G~1~,D~0~^*^)如下图左边图像所示
(V(G~1~,D~0~^*^)就是G~0~跟Data的JS Diversion。)

但是updateG~0~变成G~1~,这个时候就算不是在evaluate JS Dvergence。这是因为你的D~0~^*^仍然是固定的,但是V(G~1~,D~0~^*^)就不是在evaluate JS Divergence
因为估算JS Divergence的是要求最大的值,所以今天当你的G变了,function就变了,当function变的时候,同样的D^*^就不是在evaluate JS Dvergence。

但是为什么我们在进行参数θ~G~优化的时候,是在减少JS Divergence呢?
一个前提的假设就这两个function可能是非常的像的
假设只update一点点的G,例如,从G~0~变到G~1~,,G的参数只动了一点点 ,那这两个function长相可能是比较像的。

所以他一样用,一样用D~0~^*^,仍然是在量JS Divergence这样的情形
(如下图的两个曲线图像所示,这边本来值很小,突然变很高的情形,可能是不会发生的。
因为G~0~与G~1~是很像的,所以这两个function是比较接近,所以只同样用固定的D~0~^*^就可以evaluate JS Divergence。)

所以在train这个GAN的时候,tips就是Generator不能够一次update太多。但是在train Discriminator的时候,理论上你应该把它train到底。
原因如下:

  1. 因为对于Generator的话,你应该只要跑比较少的iteration,以免上述的假设不成立。
  2. 对于Discriminator在train时候你其实会需要比较多的iteration,把它train到底,因为我们需要找到MAX的值,才算evaluate JS Divergence。

4. GAN的实际做法

以上都是假设的,那么实际上你在做GAN的时候,其实是怎么做的呢?

之前说过要计算objective function就要对里面的x取期望(E~x~),但是在实际上你没有办法获取其期望,所以我们都是用sample n个data来代替期望。
实际上我们在做的时候,我们就是在maximize如图的式子,而不是真的去maximize它的期望

即把sample出来的这n个data的通统算出来,然后再把它统统平均起来,就当做是expectation

所以在train Discriminator的时候,就是在train一个binary classifier(二元分类器),说明如下:

  1. 实际上Discriminator是一个binary classified
  2. 这个binary classified是一个这个logistics regression
  3. 它的output有接一个sigmoid(即output的值是介于0到1之间的)
  4. 然后从P~data~里面里sample n个data出来,这n个data就是Positive examples或者class 1 examples。
  5. 然后呢你从P~G~里面再sample另外 n个data出来,这n个data就当做是negative examples或者class 2 examples
  6. 接下来就train binary classified(即Discriminator),会minimize cross entropy。然后发现如果你在minimize cross entropy,把式子写出来,它会等同于上面maximize objective function。

总结:

我们复习一下以上的过程,算法分为两步
第一步是maxV(G,D),即train Discriminator,以下是我个人的觉得重要的总结

  1. 我们train Discriminator的目的是为了量evaluate James divergence 。当V(G,D)的值最大的时候,Discriminator才是在evaluates diverges,所以V的值要被maximize,为了让V的值最大,所以一定要对Discriminator train 很多次(虽然很难达到,但是可以train个接近的值)。

    第二步就是min maxV(G,D),即train Generator。
  2. 我们train Generator是为了要minimize JS Divergence即,减少JS Divergence的值 ,minimize下图的式子的时,第一项呢是可以不用考虑它的,所以你把第一项拿掉,只去minimize第二项之前说过Generator你不能够train太多 ,因为一旦train太多的话,你的Discriminator就没有办法evaluate James divergence。

    目前为止,我们讲说今天在train generator的时候,实际不是下图上半部分,而是如下图下半部分:

    原因如下而在一篇paper里面,一开始就不是在minimize这个式子。
    log(1-D(x)),它长的是这个样子:

    而一开始在做training的时候,D(x)的值通常是很小的,因为Discriminator会知道说你的generator产生出来的image它是fake的,所以它会给它很小的值。所以一开始D(x)的值会落在上图中靠近坐标轴左边的地方,那它的微分是很小的。所以在training的时候会造成你在training的一些问题

所以他说呢,作者把它把它换成-log(D(x)),-log(D(x))它长的是这个样子:

这两个式子的趋势是一样的,但是他们在同一个位置的斜率就变得不一样。

-log(D(x))在一开始设置D(x)还很小的时候,你算出来的微分会比较大,所以觉得说这样子training是比较容易的。

最后再来直观感受一下Discriminator 和 Generator互动的过程:

重复步骤交替训练 Discriminator 和 Generator。每次迭代,Generator 尝试生成更真实的数据,而 Discriminator 则尝试更好地区分真实和假数据。这个过程可以看作是两个模型之间的"对抗"或"游戏",其中 Generator 试图生成越来越好的数据,而 Discriminator 则试图更好地区分。随着训练的进行,Generator 生成的数据分布将越来越接近真实数据分布。最终,Discriminator 将无法区分真实数据和生成数据,或者达到一个平衡点,此时 Generator 生成的数据在统计上与真实数据无法区分。

总结

本周因为要准备考试和课程论文,进度缓慢,之后需要加快进度。

本周的学习了生成对抗网络(GAN)的理论基础和训练过程。首先了解了GAN如何通过高维空间中的点来模拟图像生成,并探讨了如何通过生成模型来寻找数据的分布。其中学习了最大似然估计(MLE)在生成任务中的应用,并理解了GAN与传统方法的不同之处。我复习了训练GAN的过程,包括如何通过判别器(Discriminator)来衡量两个分布之间的差异。对GAN的目标函数与JS散度之间的关系进行了推导。学习了如何通过样本来近似期望值,并理解了在训练判别器和生成器时的不同策略。在训练判别器时需要多次迭代以接近最大值,而在训练生成器时则应避免过大的更新步长,以保持判别器能够有效地评估JS散度。最后,我直观地感受到了判别器和生成器之间的互动过程,这是一个"对抗"或"游戏"的过程,其中生成器试图生成越来越真实的数据,而判别器则试图更好地区分真实和假数据。随着训练的进行,生成器生成的数据分布将越来越接近真实数据分布,最终达到一个平衡点,此时生成器生成的数据在统计上与真实数据无法区分。通过本周的学习,我对GAN的理论知识和实际应用有了更深入的理解,为进一步的研究和实践打下了坚实的基础。

下一周打算先把GAN的拓展放一放,学习一些新内容,因为前面的部分遗忘的七七八八了,打算以以往的内容复习为主。

相关推荐
AI狂热爱好者3 分钟前
Meta 上周宣布正式开源小型语言模型 MobileLLM 系列
人工智能·ai·语言模型·自然语言处理·gpu算力
光锥智能4 分钟前
腾讯混元宣布大语言模型和3D模型正式开源
人工智能·语言模型·自然语言处理
新手小白勇闯新世界7 分钟前
论文阅读-用于图像识别的深度残差学习
论文阅读·人工智能·深度学习·学习·计算机视觉
大拨鼠9 分钟前
【多模态读论文系列】LLaMA-Adapter V2论文笔记
论文阅读·人工智能·llama
小嗷犬12 分钟前
【论文笔记】Dense Connector for MLLMs
论文阅读·人工智能·语言模型·大模型·多模态
新手小白勇闯新世界15 分钟前
论文阅读- --DeepI2P:通过深度分类进行图像到点云配准
论文阅读·深度学习·算法·计算机视觉
子午22 分钟前
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
人工智能·python·深度学习
AI_小站23 分钟前
【AI工作流】FastGPT - 深入解析FastGPT工作流编排:从基础到高级应用的全面指南
人工智能·程序人生·语言模型·大模型·llm·fastgpt·大模型应用
chan_lay33 分钟前
图论导引 - 目录、引言、第一章 - 11/05
笔记·图论
B站计算机毕业设计超人38 分钟前
计算机毕业设计Hadoop+大模型地震预测系统 地震数据分析可视化 地震爬虫 大数据毕业设计 Spark 机器学习 深度学习 Flink 大数据
大数据·hadoop·爬虫·深度学习·机器学习·数据分析·课程设计