[DL_Net从入门到入土] 生成对抗网络 GAN

DL_Net从入门到入土 生成对抗网络 GAN

📢个人导航

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

文章目录

  • [DL_Net从入门到入土 生成对抗网络 GAN](#[DL_Net从入门到入土] 生成对抗网络 GAN)
    • 📢个人导航
    • 📖参考资料
    • 🌱背景
    • ⚙️架构(公式)
        • [1. Generator 生成器](#1. Generator 生成器)
        • [2. Discriminator 判别器](#2. Discriminator 判别器)
        • [3. 目标函数](#3. 目标函数)
        • [4. 训练流程](#4. 训练流程)
    • 👍优点/创新点
        • [1. 生成效果清晰](#1. 生成效果清晰)
        • [2. 不需要显式建模数据分布](#2. 不需要显式建模数据分布)
    • 👎缺点
        • [1. 训练不稳定](#1. 训练不稳定)
        • [2. 模式崩溃 Mode Collapse](#2. 模式崩溃 Mode Collapse)
        • [3. 评价困难](#3. 评价困难)
        • [4. 对超参数敏感](#4. 对超参数敏感)
        • [5. 容易梯度消失](#5. 容易梯度消失)
    • 💻模型代码

📖参考资料

Generative Adversarial Networks.

🌱背景

VAE的目标,是让模型学会数据分布,然后从这个分布中采样生成新数据

但是 VAE 有一个常见问题: 生成结果容易偏模糊

因为 VAE 通常使用像素级重建损失,例如 MSE 或 BCE

-> 这类损失会鼓励模型生成"平均结果"

比如真实数据里有很多种写法的数字 8

模型可能学出一种比较平均、比较保守的 8

-> 于是生成图片就容易糊

GAN: Generative Adversarial Network: 生成对抗网络

GAN不直接要求生成图片和某张真实图片逐像素接近,而是引入两个网络:

  • Generator:生成器
  • Discriminator:判别器

生成器 G:我要造假图,骗过判别器

判别器 D:我要分辨图片是真图还是假图

->

生成器 G 生成的图片越来越真实

判别器 D 越来越难判断图片真假

⚙️架构(公式)

复制代码
随机噪声 z → Generator G → 生成样本 G(z)
真实样本 x → Discriminator D → 判断为真的概率 D(x)
生成样本 G(z) → Discriminator D → 判断为真的概率 D(G(z))

随机噪声 z
生成器 G
生成样本 G(z)
真实样本 x
判别器 D
输出真假概率

1. Generator 生成器

生成器的输入通常是一个随机噪声 z, 服从标准正态分布
z ∼ N ( 0 , I ) z \sim N(0, I) z∼N(0,I)

生成器把这个噪声映射成一个样本 G ( z ) G(z) G(z)

生成器的目标是:生成尽可能真实的样本,让判别器误以为它是真的

2. Discriminator 判别器

判别器的输入是一张图片,输出是一个概率 (这张图片是真实图片的概率)
D ( x ) ∈ 0 , 1 D(x) ∈ 0, 1 D(x)∈0,1

对于真实图片 x,判别器希望: D ( x ) → 1 D(x) → 1 D(x)→1

对于生成图片 G(z),判别器希望: D ( G ( z ) ) → 0 D(G(z)) → 0 D(G(z))→0

3. 目标函数

经典目标函数:
min ⁡ G max ⁡ D V ( D , G ) \min_G \max_D V(D, G) GminDmaxV(D,G)

极小极大博弈:

D 想最大化识别真假能力

G 想最小化 D 的识别能力

其中:
V ( D , G ) = E x ∼ p d a t a ( x ) log ⁡ D ( x ) + E z ∼ p z ( z ) log ⁡ ( 1 − D ( G ( z ) ) ) V(D, G) = \mathbb{E}{x \sim p{data}(x)}\\log D(x) + \mathbb{E}_{z \sim p_z(z)}\\log(1 - D(G(z))) V(D,G)=Ex∼pdata(x)logD(x)+Ez∼pz(z)log(1−D(G(z)))

  • 判别器 D 希望
    • 最大化 l o g D ( x ) logD(x) logD(x)
    • 最大化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
  • 生成器 G 希望
    • 最小化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
4. 训练流程

第一步: 训练判别器D (固定G)

判别器损失可以写成:

L D = − E x ∼ p d a t a ( x ) log ⁡ D ( x ) − E z ∼ p z ( z ) log ⁡ ( 1 − D ( G ( z ) ) ) L_D =-\mathbb{E}{x \sim p{data}(x)}\\log D(x)- \mathbb{E}_{z \sim p_z(z)}\\log(1 - D(G(z))) LD=−Ex∼pdata(x)logD(x)−Ez∼pz(z)log(1−D(G(z)))

第二步: 训练生成器G (固定D)

生成器损失可以写成:
L G = − E z ∼ p z ( z ) log ⁡ D ( G ( z ) ) L_G = -\mathbb{E}_{z \sim p_z(z)}\\log D(G(z)) LG=−Ez∼pz(z)logD(G(z))

👍优点/创新点

1. 生成效果清晰

相比 VAE,GAN 不直接依赖像素级重建损失。

所以 GAN 生成的图像通常更加锐利、清晰

2. 不需要显式建模数据分布

很多传统 生成模型需要显式 建模概率分布 p ( x ) p(x) p(x)

但是真实世界的数据分布往往非常复杂!

GAN 不直接写出真实数据分布的具体形式,而是通过判别器提供学习信号

也就是说:GAN 可以隐式地学习复杂数据分布

👎缺点

1. 训练不稳定

GAN 是两个网络互相博弈,不是普通的单目标优化

所以训练时可能出现:

  • 判别器太强 → 生成器学不到东西
  • 生成器太强 → 判别器失去作用
  • 两者震荡 → 损失不收敛

但是 GAN 的 loss 不一定能直接反映生成质量:

  • 有时候 loss 看起来正常,图片很烂
  • 有时候 loss 看起来奇怪,图片反而不错
2. 模式崩溃 Mode Collapse

模式崩溃:生成器只学会生成少数几种样本,而不是覆盖整个真实数据分布

3. 评价困难

对于分类任务,我们可以用准确率

对于回归任务,我们可以用 MSE

但是对于生成模型,评价就比较麻烦!

GAN 常用评价指标:

  • IS: Inception Score
  • FID: Fréchet Inception Distance

但是这些指标也不是完美的

4. 对超参数敏感
5. 容易梯度消失

当判别器 D 太强时,它可以轻松判断:

  • 真实图片 → 1
  • 生成图片 → 0

此时对于生成器来说,梯度可能变得非常弱:

生成器知道自己错了, 但是不知道该怎么改

-> 这会导致生成器训练困难

💻模型代码

py 复制代码
import torch
import torch.nn as nn


class Generator(nn.Module):
    """
    生成器 Generator

    输入:
        z: 随机噪声向量
        shape = [batch_size, latent_dim]

    输出:
        img: 生成图片
        shape = [batch_size, 1, 28, 28]
    """

    def __init__(self, latent_dim=100, img_dim=28 * 28):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


class Discriminator(nn.Module):
    """
    判别器 Discriminator

    输入:
        img: 图片
        shape = [batch_size, 1, 28, 28]

    输出:
        validity: 图片为真实图片的概率
        shape = [batch_size, 1]
    """

    def __init__(self, img_dim=28 * 28):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


if __name__ == "__main__":
    latent_dim = 100
    batch_size = 8

    generator = Generator(latent_dim=latent_dim)
    discriminator = Discriminator()

    z = torch.randn(batch_size, latent_dim)

    fake_imgs = generator(z)
    outputs = discriminator(fake_imgs)

    print("随机噪声 z 的形状:", z.shape)
    print("生成图片 fake_imgs 的形状:", fake_imgs.shape)
    print("判别器输出 outputs 的形状:", outputs.shape)
相关推荐
Kyrie67843 分钟前
SkillOpt:把 Agent 的技能文件当作可训练参数
人工智能
冬奇Lab1 小时前
Workflow 系列(07):工程化与版本管理——Workflow 的 CI/CD
人工智能·工作流引擎
两万五千个小时1 小时前
Claude Code 上下文管理(一):为什么 Agent 会"失忆"?
人工智能·架构·开源
两万五千个小时1 小时前
Claude Code 上下文管理(二):零 Token 消耗的压缩三板斧
人工智能·程序员·开源
冬奇Lab1 小时前
每日一个开源项目(第150篇):caveman - 为什么用很多 token,少 token 也行——给 AI Agent 装上穴居人嘴巴
人工智能·开源·资讯
贵慜_Derek1 小时前
MAI-04|干净数据在工程上意味着什么:MAI 预训练数据治理
人工智能·算法·llm
feelmylife591 小时前
Agent 记忆设计架构 — 分层记忆:什么时候该记住,什么时候该忘记
人工智能
阿黎梨梨2 小时前
揭秘大语言模型的底层逻辑:从文本分词到高维向量的计算之旅
javascript·人工智能
moMo2 小时前
AI工程化 03:给模型喂上下文
人工智能