深度学习笔记——生成对抗网络GAN

本文详细介绍早期生成式AI的代表性模型:生成对抗网络GAN。

文章目录

生成对抗网络 (Generative Adversarial Network, GAN)是一种生成模型 ,由 Ian Goodfellow 等人于 2014 年提出。GAN 通过两个网络------生成器 (Generator)和判别器 (Discriminator)之间的对抗训练 ,使得生成器能够生成逼真的数据,从而被判别器难以区分。GAN 已广泛应用于图像生成、图像修复、风格迁移、文本生成等任务。

论文:Generative Adversarial Nets


一、基本结构

GAN 包含两个核心部分:生成器和判别器。

生成器

  • 功能 :生成器接收一个随机噪声向量 (通常是高斯分布或均匀分布),并将其映射到数据空间,使生成的数据尽可能接近真实数据
  • 目标 :生成器的目标是 "欺骗"判别器,使其无法区分生成数据和真实数据。
  • 网络结构 :生成器通常由一系列反卷积(或上采样 )层组成,以逐步生成更高分辨率的图像

判别器

  • 功能 :判别器接收输入样本 ,并判断该样本真假
  • 目标 :判别器的目标是尽可能准确地分辨出真假样本
  • 网络结构 :判别器通常是一个卷积神经网络(CNN),将输入数据压缩为 一个概率值,表示该样本属于真实数据的概率

二、损失函数

GAN 的训练是一个生成器判别器 相互博弈 的过程,通过对抗训练 逐步提高生成器的生成质量。训练过程主要包括以下步骤:

判别器

  • 训练判别器时,其输入是真实数据 和生成器的生成数据
  • 判别器的目标区分真实数据和生成数据,即使得判别器输出接近 1 的概率表示真实数据,接近 0 的概率表示生成数据。
  • 判别器的损失函数通常使用二元交叉熵(Binary Cross-Entropy):
    L D = − E x ∼ p data [ log ⁡ D ( x ) ] − E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_D = - \mathbb{E}{x \sim p{\text{data}}} [\log D(x)] - \mathbb{E}_{z \sim p_z} [\log(1 - D(G(z)))] LD=−Ex∼pdata[logD(x)]−Ez∼pz[log(1−D(G(z)))]

参数含义

  1. x x x:真实数据样本 ,来自于真实数据分布 p data p_{\text{data}} pdata。
  2. z z z:生成器 输入的噪声向量,通常从均匀分布或正态分布中采样。
  3. D ( x ) D(x) D(x):判别器真实样本 x x x 的输出,表示判别器认为该样本是真实数据的概率。
  4. D ( G ( z ) ) D(G(z)) D(G(z)):判别器生成数据 G ( z ) G(z) G(z) 的输出,表示判别器认为该样本为真实数据的概率

判别器损失的计算过程

  1. 第一部分
    − E x ∼ p data [ log ⁡ D ( x ) ] - \mathbb{E}{x \sim p{\text{data}}} [\log D(x)] −Ex∼pdata[logD(x)]

    • 表示对真实样本的损失。
    • 判别器希望尽量将真实数据的输出 D ( x ) D(x) D(x) 接近 1,因此这部分的目标是最小化 log ⁡ D ( x ) \log D(x) logD(x)。
  2. 第二部分
    − E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] - \mathbb{E}_{z \sim p_z} [\log(1 - D(G(z)))] −Ez∼pz[log(1−D(G(z)))]

    • 表示对生成样本的损失。
    • 判别器希望尽量将生成数据的输出 D ( G ( z ) ) D(G(z)) D(G(z)) 接近 0,因此这部分的目标是最小化 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1−D(G(z)))。

生成器

  • 训练生成器时,其输入是一个 随机噪声向量,通常记为 z。
  • 生成器的目标是生成逼真的样本,"欺骗"判别器,使判别器无法分辨生成数据和真实数据 ,因此生成器希望判别器输出接近 1(让判别器以为生成的图像是真实的)
  • 生成器的损失函数与判别器的损失类似,但这里生成器 希望最大化判别器对生成数据的输出 ,即让判别器认为生成数据为真实数据

L G = − E z ∼ p z [ log ⁡ D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim p_z} [\log D(G(z))] LG=−Ez∼pz[logD(G(z))]

参数含义

  1. z z z:生成器输入的噪声向量,通常从均匀分布或正态分布中采样。
  2. D ( G ( z ) ) D(G(z)) D(G(z)):判别器对生成数据 G ( z ) G(z) G(z) 的输出,表示判别器认为生成样本为真实数据的概率

生成器损失的计算过程

  • 损失形式
    − E z ∼ p z [ log ⁡ D ( G ( z ) ) ] -\mathbb{E}_{z \sim p_z} [\log D(G(z))] −Ez∼pz[logD(G(z))]

    生成器的目标是使 D ( G ( z ) ) D(G(z)) D(G(z)) 接近 1,也就是希望判别器对生成的样本做出"真实"的判断

  • 目标

    生成器通过最大化判别器对生成样本的输出,使得判别器无法区分生成样本和真实样本。

生成器的目标是最小化该损失 ,即最大化判别器对生成样本的输出

交替优化

在每轮训练中:

  1. 固定生成器,训练判别器
  2. 固定判别器,训练生成器
  3. 通过交替优化两者不断改进,生成器的生成样本越来越逼真,而判别器的分辨能力也不断提高。

目标函数

GAN 的目标是找到一个平衡点 ,使生成器生成的样本和真实数据在分布上尽可能接近 。它是一个极小极大(minimax)损失函数 ,表达了生成器和判别器的博弈

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]

参数含义:

  1. G G G:生成器的参数。
  2. D D D:判别器的参数。
  3. x x x:真实样本,来自真实数据分布 p data p_{\text{data}} pdata。
  4. z z z:噪声输入,通常从均匀分布或正态分布中采样。
  5. D ( x ) D(x) D(x):判别器对真实样本 x x x 的输出概率。
  6. D ( G ( z ) ) D(G(z)) D(G(z)):判别器对生成样本的输出概率。

这个目标函数包含两个部分:

  • 最大化判别器的目标 :判别器希望最大化 log ⁡ D ( x ) \log D(x) logD(x) 和 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1−D(G(z))),即尽可能将真实数据判断为真实样本、生成数据判断为生成样本
  • 最小化生成器的目标 :生成器希望最小化 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1−D(G(z))),即生成器希望生成的样本尽可能接近真实样本,以欺骗判别器。

实际的 GAN 训练过程中,不会直接出现极小极大损失函数 ,而是通过优化生成器和判别器的各自损失函 数来间接实现 这个目标。

极小极大损失函数的目标通过分解判别器损失生成器损失 来实现,二者的对抗优化就是对极小极大目标函数的间接实现 。通过交替优化生成器和判别器的损失 。伪代码实现如下:

三、GAN 的训练过程

训练流程概述

GAN 的训练是一个极小极大(minimax)的博弈过程

  1. 生成器的目标是生成逼真的样本来"欺骗"判别器,使得判别器无法分辨生成样本与真实样本;
  2. 判别器的目标是尽可能准确地区分真实样本和生成样本。

这种对抗关系通过两个网络交替训练来实现。

GAN 的训练过程分为以下几个主要步骤。

训练流程步骤

1. 初始化参数和超参数
  1. 初始化生成器 G G G 和判别器 D D D 的网络参数。
  2. 设定超参数,如学习率、训练轮数、批量大小、优化器等。
  3. 通常选择的优化器为 Adam 优化器,初始学习率一般设为较小值,以确保训练过程稳定。
2. 定义损失函数

GAN 的损失函数由生成器和判别器的对抗损失组成。目标是找到一个平衡点,使生成器能够生成与真实样本分布相近的样本。

  1. 判别器损失

    (参考上节)

  2. 生成器损失

    (参考上节)

3. 训练过程的迭代

训练过程的每一轮迭代中,生成器和判别器会交替优化。一般的训练过程如下:

判别器训练步骤

判别器的目标是区分真实样本和生成样本。每轮判别器训练分为以下步骤:

  1. 从真实数据分布中采样一批真实样本 x x x
  2. 生成器从随机噪声分布 z z z(通常为正态分布或均匀分布)中采样一批噪声向量 ,并生成对应的样本 G ( z ) G(z) G(z)
  3. 将真实样本 x x x 和生成样本 G ( z ) G(z) G(z) 分别输入判别器 D D D 。判别器是一个二分类神经网络,通常由卷积层 构成,以提取样本的特征输出 是一个概率值 ,表示输入样本为真实样本的概率。计算出对真实样本的输出 D ( x ) D(x) D(x) 和生成样本的输出 D ( G ( z ) ) D(G(z)) D(G(z))。
  4. 根据判别器损失函数 L D L_D LD计算判别器的损失,更新判别器的参数,使其能够更好地区分真实样本和生成样本。

判别器训练的目的是让它尽可能区分真实样本和生成样本,鼓励其将真实样本判断为 1,生成样本判断为 0。

生成器训练步骤

生成器的目标是生成能够"欺骗"判别器的样本。生成器的训练步骤如下:

  1. 随机噪声分布 z z z 中采样一批噪声向量。
  2. 将噪声向量输入生成器 G G G ,得到生成的样本 G ( z ) G(z) G(z)。生成器是一个神经网络,通常由多层神经网络构成。在图像生成任务中,生成器通常采用反卷积 (转置卷积)或上采样层来逐步生成高分辨率图像
  3. 将生成样本 G ( z ) G(z) G(z) 输入判别器 D D D ,计算判别器对生成样本的输出 D ( G ( z ) ) D(G(z)) D(G(z))。
  4. 根据生成器损失函数 L G L_G LG计算生成器的损失,通过反向传播更新生成器的参数,使生成器生成的样本更加逼真,以"欺骗"判别器。

生成器的优化目标是使得判别器的输出 D ( G ( z ) ) D(G(z)) D(G(z)) 越接近 1 越好,即让判别器认为生成样本是真实样本。生成器训练时 ,这里的 D D D 的参数是不可训练的

4. 交替优化

训练过程中,生成器和判别器会不断地交替优化 。通常,每轮训练中会多次优化判别器(如更新判别器参数数次,再更新生成器参数一次),以确保判别器的分辨能力。这种交替优化的过程被称为 GAN 的"对抗训练"。

5. 收敛判别

GAN 的训练目标是找到生成器和判别器之间的平衡点,但收敛难以判断。通常可以通过以下方法判别 GAN 是否趋于收敛:

  1. 生成样本质量

    观察生成样本的视觉质量,当生成样本变得清晰且真实时,说明生成器已经学到了接近真实数据分布的特征。

  2. 判别器输出的均衡

    在理想的情况下,判别器真实 样本和生成 样本的输出概率应接近 0.5,表示判别器很难区分真假样本。

  3. 损失变化

    监控生成器和判别器的损失,若损失趋于平稳,说明两者逐渐达到平衡状态。


GAN 训练过程的挑战

GAN 的训练通常存在一些挑战,需要在训练过程中进行调试和优化。

  1. 训练不稳定

    GAN 的训练过程可能会发生梯度消失或梯度爆炸,导致生成效果不佳。可以使用 WGAN、谱归一化等方法来提升稳定性。

  2. 模式崩溃(Mode Collapse)

    生成器可能会陷入"模式崩溃"现象,即只生成相似的样本而缺乏多样性。可以通过多样性损失(如 Minibatch Discrimination)或训练策略(如添加噪声)来缓解模式崩溃。

  3. 对抗关系平衡

    判别器和生成器的能力需要平衡。若判别器太强,生成器难以改进;若生成器太强,判别器很快失去判断能力。可以适当调整判别器和生成器的训练频率,保持两者的平衡。

四、GAN 的常见变体

原始的 GAN 结构存在一些问题 ,例如训练不稳定、容易陷入模式崩溃(mode collapse)等。为了克服这些问题,出现了多种改进的 GAN 变体

变体 主要改进 应用场景 优点
DCGAN 使用深度卷积结构 生成图像,批量归一化等改进 高质量图像生成、人脸生成、艺术风格图像生成 生成质量高,结构简单
CycleGAN 无需配对数据,使用循环一致性损失双生成器双判别器实现图像转换 图像风格迁移(照片转素描、白天转夜晚)等 无需配对数据,双向转换
BigGAN 标签嵌入、谱归一化、大型网络结构(更多层次的卷积) 高清图像生成、超高分辨率生成 生成质量极高,适合大规模数据
StyleGAN 样式映射网络,自适应实例归一化(AdaIN),多尺度控制 高质量人脸生成、风格迁移 细节控制能力强,质量极高
cGAN 在生成器和判别器中加入条件输入,实现特定属性生成 根据类别标签生成图像(不同表情、年龄等) 可控性高,适合特定属性生成

1. DCGAN(Deep Convolutional GAN)

DCGAN 是最早将卷积神经网络(CNN)应用于 GAN 的变体之一,被广泛应用于图像生成任务。DCGAN 的目标是提升图像生成质量,使生成器可以生成更高分辨率和更具细节的图像。

主要改进

  1. 使用卷积层替代 GAN 中传统的全连接层 ,生成器使用反卷积(或转置卷积)逐步生成图像 ,判别器使用标准的卷积层提取图像特征
  2. 移除池化层用步幅卷积(stride convolution)来减小分辨率,从而保留更多细节。
  3. 在生成器和判别器中使用批量归一化(Batch Normalization)来加速训练和提升稳定性。
  4. 生成器使用 ReLU 激活,判别器使用 Leaky ReLU 激活。

应用场景

  • 高质量图像生成、人脸生成、艺术风格图像生成等。

优点

  • 结构简单,训练稳定性较高,生成图像质量较好,是后续很多 GAN 变体的基础。

2. CycleGAN

CycleGAN 是一种专注于未配对数据集的图像到图像转换的 GAN 变体,用于解决当缺少成对样本时的风格迁移和图像转换任务。

主要改进

  1. 循环一致性损失(Cycle Consistency Loss) :CycleGAN 在生成器中引入了循环一致性损失,使得图像从一种风格转换为另一种风格后,还可以还原回原始风格 。这种双向映射的设计让 CycleGAN
    能够在没有配对数据的情况下进行训练。
  2. 双生成器双判别器 :CycleGAN 使用两个生成器两个判别器,分别负责从源域到目标域的转换,以及从目标域到源域的逆向转换。

应用场景

  • 图像风格转换(如将马变为斑马、照片转为素描、白天转为夜晚)、图像修复和艺术创作等。

优点

  • 无需配对训练数据,即可实现高质量的图像到图像转换,特别适合风格迁移任务。

3. BigGAN

BigGAN 是一种大规模、高分辨率的 GAN 变体,主要用于生成高质量、高分辨率的图像。BigGAN 在生成效果上达到了新的高度,但其训练难度和计算资源要求也较高。

主要改进

  1. 标签嵌入(Label Embedding) :在生成器和判别器中加入标签嵌入 ,以便在条件 GAN 中生成具有特定类别的图像
  2. 正则化和归一化技术 :BigGAN 使用谱归一化(Spectral Normalization)来控制判别器的梯度,使得训练更稳定。此外,使用大批量训练、渐进式训练等手段来提高生成图像的分辨率和质量。
  3. 生成器架构调整 :生成器中使用了更多层次的卷积,使得生成的图像更细腻,并利用大型网络结构提升生成质量。

应用场景

  • 高清图像生成、图像生成研究等,适用于需要超高质量和分辨率的生成任务。

优点

  • 生成质量出色,可以生成高分辨率的图像,但训练成本较高,适合有强大计算资源支持的场景。

4. StyleGAN

StyleGAN 是由 NVIDIA 提出的基于样式控制的 GAN 变体,提供了更高的图像生成质量和样式可控性。它被广泛用于高质量人脸生成和其他具有分层样式控制的生成任务中。

主要改进

  1. 样式映射网络(Style Mapping Network) :StyleGAN 使用一个样式映射网络,将潜在空间中的输入向量映射到样式空间,然后将这些样式控制应用于生成器的不同层次。
  2. 自适应实例归一化(Adaptive Instance Normalization, AdaIN) :生成器在每一层应用 AdaIN 操作 ,使样式向量可以控制每层的特征分布,实现不同分辨率下的细节和整体风格调整。
  3. 多尺度控制:通过在不同生成层次中应用样式向量,StyleGAN 可以在不同层次上控制图像的特征,从而生成更细腻、更有层次感的图像。

应用场景

  • 高质量人脸生成、艺术风格迁移、细节编辑等。

优点

  • 生成图像质量极高,具有强大的样式控制能力,能够在不同层次上调整生成样本的特征。

5. cGAN(Conditional GAN)

cGAN(条件 GAN) 是在生成器和判别器中引入条件信息(如类别标签、属性标签等)的 GAN 变体,使得 GAN 生成的图像可以带有特定的属性或类别。cGAN 的基本思想是在生成器和判别器的输入中加入条件变量。

主要改进

  1. 条件输入:在生成器的输入噪声向量上添加类别标签或其他条件变量,这使得生成器可以根据给定条件生成带有特定特征的图像。
  2. 判别器的条件输入:在判别器中引入相同的条件信息,使判别器能够更准确地判断生成样本是否符合给定条件。

应用场景

  • 根据类别生成特定属性的图像,如不同表情、服装、场景、年龄等。适用于需要生成带有特定特征的图像任务。

优点

  • 生成的图像具有更高的可控性,适合生成带有明确标签或特征的图像。

五、GAN 的应用场景

GAN 具有强大的生成能力,在多个领域中得到了广泛应用:

  1. 图像生成

    • GAN 可用于生成高分辨率图像,应用于艺术创作、广告、电影制作等领域。
  2. 图像修复

    • GAN 可用于修复有缺陷或损坏的图像,例如老照片修复、面部填补。
  3. 图像超分辨率

    • GAN 可生成清晰的高分辨率图像,用于增强低分辨率图像的细节。
  4. 图像到图像的转换

    • 例如将草图转为真实照片,黑白图像上色等。CycleGAN 在无监督图像转换任务上表现优异。
  5. 数据增强

    • 在数据集有限的情况下,GAN 可以用于生成新样本,扩展训练数据,提升模型的泛化能力。
  6. 文本生成与文本到图像生成

    • GAN 已应用于文本生成、文本到图像生成等任务,使 AI 能够根据描述生成符合语义的图像。

六、GAN 的优势与挑战

优势:

  1. 生成能力强

    • GAN 的生成器可以生成逼真的样本,不仅限于简单的噪声分布。
  2. 无监督学习

    • GAN 不需要带标签的样本,能够通过未标注数据进行无监督训练。
  3. 适用范围广

    • GAN 的生成能力在图像、音频、文本等多种数据类型上都有广泛应用。

挑战:

  1. 训练不稳定

    • GAN 的训练是生成器和判别器的对抗过程,容易陷入梯度消失或爆炸,训练过程不稳定。
  2. 模式崩溃

    • GAN 可能生成重复样本而缺乏多样性(模式崩溃),即生成器只生成某些特定样本。
  3. 超参数敏感

    • GAN 的训练对学习率、批量大小等超参数较为敏感,调优成本高。
  4. 难以衡量生成质量

    • GAN 的生成样本质量较难定量评估,目前常用的 FID(Fréchet Inception Distance)等指标也无法完全反映样本质量。

总结

GAN 是一种强大的生成模型,通过生成器和判别器的对抗训练,使得生成器能够生成接近真实的数据。GAN 的多种变体和改进版本在图像生成、数据增强、风格转换等领域取得了显著成果。尽管 GAN 面临训练不稳定、模式崩溃等挑战,但它的生成能力为多个领域的研究和应用提供了新的可能性。未来的研究将继续优化 GAN 的稳定性和多样性,扩展其在不同场景的应用。

历史文章

机器学习

机器学习笔记------损失函数、代价函数和KL散度
机器学习笔记------特征工程、正则化、强化学习
机器学习笔记------30种常见机器学习算法简要汇总
机器学习笔记------感知机、多层感知机(MLP)、支持向量机(SVM)
机器学习笔记------KNN(K-Nearest Neighbors,K 近邻算法)
机器学习笔记------朴素贝叶斯算法
机器学习笔记------决策树
机器学习笔记------集成学习、Bagging(随机森林)、Boosting(AdaBoost、GBDT、XGBoost、LightGBM)、Stacking
机器学习笔记------Boosting中常用算法(GBDT、XGBoost、LightGBM)迭代路径
机器学习笔记------聚类算法(Kmeans、GMM-使用EM优化)
机器学习笔记------降维

深度学习

深度学习笔记------优化算法、激活函数
深度学习------归一化、正则化
深度学习------权重初始化、评估指标、梯度消失和梯度爆炸
深度学习笔记------前向传播与反向传播、神经网络(前馈神经网络与反馈神经网络)、常见算法概要汇总
深度学习笔记------卷积神经网络CNN
深度学习笔记------循环神经网络RNN、LSTM、GRU、Bi-RNN
深度学习笔记------Transformer
深度学习笔记------3种常见的Transformer位置编码
深度学习笔记------GPT、BERT、T5
深度学习笔记------ViT、ViLT
深度学习笔记------DiT(Diffusion Transformer)
深度学习笔记------多模态模型CLIP、BLIP
深度学习笔记------AE、VAE

相关推荐
风一样的树懒3 分钟前
Python使用pip安装Caused by SSLError:certificate verify failed
人工智能·python
9命怪猫7 分钟前
AI大模型-提示工程学习笔记5-零提示
人工智能·笔记·学习·ai·提示工程
cnbestec43 分钟前
GelSight Mini视触觉传感器凝胶触头升级:增加40%耐用性,拓展机器人与触觉AI 应用边界
人工智能·机器人
bohu831 小时前
ros2-4.2 用python实现人脸识别
人工智能·opencv·人脸识别·ros2·服务调用
Loving_enjoy1 小时前
ChatGPT 数据分析与处理使用详解
大数据·人工智能
whaosoft-1431 小时前
51c自动驾驶~合集45
人工智能
刘大猫262 小时前
《docker基础篇:1.Docker简介》,包括Docker是什么、容器与虚拟机比较、能干嘛、去哪下
人工智能·操作系统·团队管理
hfmeet2 小时前
行为分析:LSTM、3D CNN、SlowFast Networks。这三者的优缺点
人工智能·cnn·lstm
小灰灰__3 小时前
LLM大模型实践10-聊天机器人
人工智能·chatgpt·机器人