[DL_Net从入门到入土] 生成对抗网络 GAN

[DL_Net从入门到入土] 生成对抗网络 GAN

📢个人导航

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

文章目录

  • [[DL_Net从入门到入土] 生成对抗网络 GAN](#[DL_Net从入门到入土] 生成对抗网络 GAN)
    • 📢个人导航
    • 📖参考资料
    • 🌱背景
    • ⚙️架构(公式)
        • [1. Generator 生成器](#1. Generator 生成器)
        • [2. Discriminator 判别器](#2. Discriminator 判别器)
        • [3. 目标函数](#3. 目标函数)
        • [4. 训练流程](#4. 训练流程)
    • 👍优点/创新点
        • [1. 生成效果清晰](#1. 生成效果清晰)
        • [2. 不需要显式建模数据分布](#2. 不需要显式建模数据分布)
    • 👎缺点
        • [1. 训练不稳定](#1. 训练不稳定)
        • [2. 模式崩溃 Mode Collapse](#2. 模式崩溃 Mode Collapse)
        • [3. 评价困难](#3. 评价困难)
        • [4. 对超参数敏感](#4. 对超参数敏感)
        • [5. 容易梯度消失](#5. 容易梯度消失)
    • 💻模型代码

📖参考资料

Generative Adversarial Networks.

🌱背景

VAE的目标,是让模型学会数据分布,然后从这个分布中采样生成新数据

但是 VAE 有一个常见问题: 生成结果容易偏模糊

因为 VAE 通常使用像素级重建损失,例如 MSE 或 BCE

-> 这类损失会鼓励模型生成"平均结果"

比如真实数据里有很多种写法的数字 8

模型可能学出一种比较平均、比较保守的 8

-> 于是生成图片就容易糊

GAN: Generative Adversarial Network: 生成对抗网络

GAN不直接要求生成图片和某张真实图片逐像素接近,而是引入两个网络:

  • Generator:生成器
  • Discriminator:判别器

生成器 G:我要造假图,骗过判别器

判别器 D:我要分辨图片是真图还是假图

->

生成器 G 生成的图片越来越真实

判别器 D 越来越难判断图片真假

⚙️架构(公式)

复制代码
随机噪声 z → Generator G → 生成样本 G(z)
真实样本 x → Discriminator D → 判断为真的概率 D(x)
生成样本 G(z) → Discriminator D → 判断为真的概率 D(G(z))

随机噪声 z
生成器 G
生成样本 G(z)
真实样本 x
判别器 D
输出真假概率

1. Generator 生成器

生成器的输入通常是一个随机噪声 z, 服从标准正态分布
z ∼ N ( 0 , I ) z \sim N(0, I) z∼N(0,I)

生成器把这个噪声映射成一个样本 G ( z ) G(z) G(z)

生成器的目标是:生成尽可能真实的样本,让判别器误以为它是真的

2. Discriminator 判别器

判别器的输入是一张图片,输出是一个概率 (这张图片是真实图片的概率)
D ( x ) ∈ [ 0 , 1 ] D(x) ∈ [0, 1] D(x)∈[0,1]

对于真实图片 x,判别器希望: D ( x ) → 1 D(x) → 1 D(x)→1

对于生成图片 G(z),判别器希望: D ( G ( z ) ) → 0 D(G(z)) → 0 D(G(z))→0

3. 目标函数

经典目标函数:
min ⁡ G max ⁡ D V ( D , G ) \min_G \max_D V(D, G) GminDmaxV(D,G)

极小极大博弈:

D 想最大化识别真假能力

G 想最小化 D 的识别能力

其中:
V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] V(D, G) = \mathbb{E}{x \sim p{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] V(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]

  • 判别器 D 希望
    • 最大化 l o g D ( x ) logD(x) logD(x)
    • 最大化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
  • 生成器 G 希望
    • 最小化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
4. 训练流程

第一步: 训练判别器D (固定G)

判别器损失可以写成:

L D = − E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] − E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_D =-\mathbb{E}{x \sim p{data}(x)}[\log D(x)]- \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] LD=−Ex∼pdata(x)[logD(x)]−Ez∼pz(z)[log(1−D(G(z)))]

第二步: 训练生成器G (固定D)

生成器损失可以写成:
L G = − E z ∼ p z ( z ) [ log ⁡ D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))] LG=−Ez∼pz(z)[logD(G(z))]

👍优点/创新点

1. 生成效果清晰

相比 VAE,GAN 不直接依赖像素级重建损失。

所以 GAN 生成的图像通常更加锐利、清晰

2. 不需要显式建模数据分布

很多传统 生成模型需要显式 建模概率分布 p ( x ) p(x) p(x)

但是真实世界的数据分布往往非常复杂!

GAN 不直接写出真实数据分布的具体形式,而是通过判别器提供学习信号

也就是说:GAN 可以隐式地学习复杂数据分布

👎缺点

1. 训练不稳定

GAN 是两个网络互相博弈,不是普通的单目标优化

所以训练时可能出现:

  • 判别器太强 → 生成器学不到东西
  • 生成器太强 → 判别器失去作用
  • 两者震荡 → 损失不收敛

但是 GAN 的 loss 不一定能直接反映生成质量:

  • 有时候 loss 看起来正常,图片很烂
  • 有时候 loss 看起来奇怪,图片反而不错
2. 模式崩溃 Mode Collapse

模式崩溃:生成器只学会生成少数几种样本,而不是覆盖整个真实数据分布

3. 评价困难

对于分类任务,我们可以用准确率

对于回归任务,我们可以用 MSE

但是对于生成模型,评价就比较麻烦!

GAN 常用评价指标:

  • IS: Inception Score
  • FID: Fréchet Inception Distance

但是这些指标也不是完美的

4. 对超参数敏感
5. 容易梯度消失

当判别器 D 太强时,它可以轻松判断:

  • 真实图片 → 1
  • 生成图片 → 0

此时对于生成器来说,梯度可能变得非常弱:

生成器知道自己错了, 但是不知道该怎么改

-> 这会导致生成器训练困难

💻模型代码

py 复制代码
import torch
import torch.nn as nn


class Generator(nn.Module):
    """
    生成器 Generator

    输入:
        z: 随机噪声向量
        shape = [batch_size, latent_dim]

    输出:
        img: 生成图片
        shape = [batch_size, 1, 28, 28]
    """

    def __init__(self, latent_dim=100, img_dim=28 * 28):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


class Discriminator(nn.Module):
    """
    判别器 Discriminator

    输入:
        img: 图片
        shape = [batch_size, 1, 28, 28]

    输出:
        validity: 图片为真实图片的概率
        shape = [batch_size, 1]
    """

    def __init__(self, img_dim=28 * 28):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


if __name__ == "__main__":
    latent_dim = 100
    batch_size = 8

    generator = Generator(latent_dim=latent_dim)
    discriminator = Discriminator()

    z = torch.randn(batch_size, latent_dim)

    fake_imgs = generator(z)
    outputs = discriminator(fake_imgs)

    print("随机噪声 z 的形状:", z.shape)
    print("生成图片 fake_imgs 的形状:", fake_imgs.shape)
    print("判别器输出 outputs 的形状:", outputs.shape)
相关推荐
猫头虎3 小时前
【Trea】Trea国内版|国际版|海外版下载|Mac版|Windows版|Linux下载配置教程
linux·人工智能·windows·macos·aigc·ai编程·agi
烟雨江南7853 小时前
从转写到智能体决策:基于“灵声智库”与本地大模型(LLM)的政务热线智能分析与 RAG 知识库融合架构
人工智能·科技·架构·语音识别·政务·ai质检
大可ai中文版镜像3 小时前
OpenAI Codex Desktop App 保姆级安装教程(Windows / Mac)
人工智能·macos·codex
YJlio3 小时前
ChatGPT 2023年5月更新解读:iOS App上线,从网页产品扩展到移动端
人工智能·openai·ai工具·ios app·移动端语音输入·whisper产品分析
不懒不懒3 小时前
Python+AI 大模型实现课堂教学质量智能分析|加权评分 + 自动诊断 + 改进建议
人工智能·python·深度学习·ai大模型·智慧教育·nlp算法
rosemary5123 小时前
AI Infra 后端开发工程师 — 学习路线
人工智能·学习
oy_mail3 小时前
当前主流大语言模型核心优势解析:Gemini、GPT与Claude的能力图谱
人工智能·媒体
极客老王说Agent3 小时前
【企业级Agent】制造业生产预算智能管控系统使用教程:2026企业数智化转型全实战
人工智能·ai·chatgpt
曾响铃3 小时前
堆卡时代终结:AI算力基础设施迎来“系统重构”时刻
人工智能·重构