[DL_Net从入门到入土] 生成对抗网络 GAN

DL_Net从入门到入土 生成对抗网络 GAN

📢个人导航

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

文章目录

  • [DL_Net从入门到入土 生成对抗网络 GAN](#[DL_Net从入门到入土] 生成对抗网络 GAN)
    • 📢个人导航
    • 📖参考资料
    • 🌱背景
    • ⚙️架构(公式)
        • [1. Generator 生成器](#1. Generator 生成器)
        • [2. Discriminator 判别器](#2. Discriminator 判别器)
        • [3. 目标函数](#3. 目标函数)
        • [4. 训练流程](#4. 训练流程)
    • 👍优点/创新点
        • [1. 生成效果清晰](#1. 生成效果清晰)
        • [2. 不需要显式建模数据分布](#2. 不需要显式建模数据分布)
    • 👎缺点
        • [1. 训练不稳定](#1. 训练不稳定)
        • [2. 模式崩溃 Mode Collapse](#2. 模式崩溃 Mode Collapse)
        • [3. 评价困难](#3. 评价困难)
        • [4. 对超参数敏感](#4. 对超参数敏感)
        • [5. 容易梯度消失](#5. 容易梯度消失)
    • 💻模型代码

📖参考资料

Generative Adversarial Networks.

🌱背景

VAE的目标,是让模型学会数据分布,然后从这个分布中采样生成新数据

但是 VAE 有一个常见问题: 生成结果容易偏模糊

因为 VAE 通常使用像素级重建损失,例如 MSE 或 BCE

-> 这类损失会鼓励模型生成"平均结果"

比如真实数据里有很多种写法的数字 8

模型可能学出一种比较平均、比较保守的 8

-> 于是生成图片就容易糊

GAN: Generative Adversarial Network: 生成对抗网络

GAN不直接要求生成图片和某张真实图片逐像素接近,而是引入两个网络:

  • Generator:生成器
  • Discriminator:判别器

生成器 G:我要造假图,骗过判别器

判别器 D:我要分辨图片是真图还是假图

->

生成器 G 生成的图片越来越真实

判别器 D 越来越难判断图片真假

⚙️架构(公式)

复制代码
随机噪声 z → Generator G → 生成样本 G(z)
真实样本 x → Discriminator D → 判断为真的概率 D(x)
生成样本 G(z) → Discriminator D → 判断为真的概率 D(G(z))

随机噪声 z
生成器 G
生成样本 G(z)
真实样本 x
判别器 D
输出真假概率

1. Generator 生成器

生成器的输入通常是一个随机噪声 z, 服从标准正态分布
z ∼ N ( 0 , I ) z \sim N(0, I) z∼N(0,I)

生成器把这个噪声映射成一个样本 G ( z ) G(z) G(z)

生成器的目标是:生成尽可能真实的样本,让判别器误以为它是真的

2. Discriminator 判别器

判别器的输入是一张图片,输出是一个概率 (这张图片是真实图片的概率)
D ( x ) ∈ 0 , 1 D(x) ∈ 0, 1 D(x)∈0,1

对于真实图片 x,判别器希望: D ( x ) → 1 D(x) → 1 D(x)→1

对于生成图片 G(z),判别器希望: D ( G ( z ) ) → 0 D(G(z)) → 0 D(G(z))→0

3. 目标函数

经典目标函数:
min ⁡ G max ⁡ D V ( D , G ) \min_G \max_D V(D, G) GminDmaxV(D,G)

极小极大博弈:

D 想最大化识别真假能力

G 想最小化 D 的识别能力

其中:
V ( D , G ) = E x ∼ p d a t a ( x ) log ⁡ D ( x ) + E z ∼ p z ( z ) log ⁡ ( 1 − D ( G ( z ) ) ) V(D, G) = \mathbb{E}{x \sim p{data}(x)}\\log D(x) + \mathbb{E}_{z \sim p_z(z)}\\log(1 - D(G(z))) V(D,G)=Ex∼pdata(x)logD(x)+Ez∼pz(z)log(1−D(G(z)))

  • 判别器 D 希望
    • 最大化 l o g D ( x ) logD(x) logD(x)
    • 最大化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
  • 生成器 G 希望
    • 最小化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
4. 训练流程

第一步: 训练判别器D (固定G)

判别器损失可以写成:

L D = − E x ∼ p d a t a ( x ) log ⁡ D ( x ) − E z ∼ p z ( z ) log ⁡ ( 1 − D ( G ( z ) ) ) L_D =-\mathbb{E}{x \sim p{data}(x)}\\log D(x)- \mathbb{E}_{z \sim p_z(z)}\\log(1 - D(G(z))) LD=−Ex∼pdata(x)logD(x)−Ez∼pz(z)log(1−D(G(z)))

第二步: 训练生成器G (固定D)

生成器损失可以写成:
L G = − E z ∼ p z ( z ) log ⁡ D ( G ( z ) ) L_G = -\mathbb{E}_{z \sim p_z(z)}\\log D(G(z)) LG=−Ez∼pz(z)logD(G(z))

👍优点/创新点

1. 生成效果清晰

相比 VAE,GAN 不直接依赖像素级重建损失。

所以 GAN 生成的图像通常更加锐利、清晰

2. 不需要显式建模数据分布

很多传统 生成模型需要显式 建模概率分布 p ( x ) p(x) p(x)

但是真实世界的数据分布往往非常复杂!

GAN 不直接写出真实数据分布的具体形式,而是通过判别器提供学习信号

也就是说:GAN 可以隐式地学习复杂数据分布

👎缺点

1. 训练不稳定

GAN 是两个网络互相博弈,不是普通的单目标优化

所以训练时可能出现:

  • 判别器太强 → 生成器学不到东西
  • 生成器太强 → 判别器失去作用
  • 两者震荡 → 损失不收敛

但是 GAN 的 loss 不一定能直接反映生成质量:

  • 有时候 loss 看起来正常,图片很烂
  • 有时候 loss 看起来奇怪,图片反而不错
2. 模式崩溃 Mode Collapse

模式崩溃:生成器只学会生成少数几种样本,而不是覆盖整个真实数据分布

3. 评价困难

对于分类任务,我们可以用准确率

对于回归任务,我们可以用 MSE

但是对于生成模型,评价就比较麻烦!

GAN 常用评价指标:

  • IS: Inception Score
  • FID: Fréchet Inception Distance

但是这些指标也不是完美的

4. 对超参数敏感
5. 容易梯度消失

当判别器 D 太强时,它可以轻松判断:

  • 真实图片 → 1
  • 生成图片 → 0

此时对于生成器来说,梯度可能变得非常弱:

生成器知道自己错了, 但是不知道该怎么改

-> 这会导致生成器训练困难

💻模型代码

py 复制代码
import torch
import torch.nn as nn


class Generator(nn.Module):
    """
    生成器 Generator

    输入:
        z: 随机噪声向量
        shape = [batch_size, latent_dim]

    输出:
        img: 生成图片
        shape = [batch_size, 1, 28, 28]
    """

    def __init__(self, latent_dim=100, img_dim=28 * 28):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


class Discriminator(nn.Module):
    """
    判别器 Discriminator

    输入:
        img: 图片
        shape = [batch_size, 1, 28, 28]

    输出:
        validity: 图片为真实图片的概率
        shape = [batch_size, 1]
    """

    def __init__(self, img_dim=28 * 28):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


if __name__ == "__main__":
    latent_dim = 100
    batch_size = 8

    generator = Generator(latent_dim=latent_dim)
    discriminator = Discriminator()

    z = torch.randn(batch_size, latent_dim)

    fake_imgs = generator(z)
    outputs = discriminator(fake_imgs)

    print("随机噪声 z 的形状:", z.shape)
    print("生成图片 fake_imgs 的形状:", fake_imgs.shape)
    print("判别器输出 outputs 的形状:", outputs.shape)
相关推荐
大C聊AI6 分钟前
通用大模型纷纷收费,垂直场景AI工具的价值正在被重估
大数据·人工智能·机器学习·办公效率·ai 工具·智标领航·ai 辅助办公
苏州邦恩精密10 分钟前
2026江苏GOM三维扫描仪定制厂家找哪家?企业数字化转型视角
人工智能·机器学习·3d·自动化·制造
python-码博士11 分钟前
PyTorch 从零实现 Flow Matching:训练、采样、画图一条龙
人工智能·pytorch·python
砍光二叉树14 分钟前
一文打通 AI 认知:LLM、Agent、MCP、Skill 完整体系
人工智能·llm·agent·skill·mcp
努力写A题的小菜鸡22 分钟前
PyTorch 图像预处理 transforms 与 TensorBoard 可视化 (自己学习记录)
人工智能·pytorch·学习
测试仪器廖生1359025638526 分钟前
罗德与施瓦茨 FSP13频谱分析仪FSP30
网络·人工智能·算法
未来和明天26 分钟前
领嵌iLeadE-588边缘计算盒子16路AI视频分析、4路AHD、4路千兆网接口
人工智能·边缘计算
上海锝秉工控30 分钟前
省线型增量编码器:用“减法思维“重构工业控制的未来
网络·人工智能·重构
蓝星空200031 分钟前
怎么使用 Image 2 高效生成商业级 AI 图像(GPT-Image-2 全流程实操教程)
人工智能·gpt·ai作画
沉下去,苦磨练!33 分钟前
张量的形状操作以及拼接
人工智能