[DL_Net从入门到入土] 生成对抗网络 GAN
📢个人导航
知乎:https://www.zhihu.com/people/byzh_rc
CSDN:https://blog.csdn.net/qq_54636039
注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码
参考文章:各方资料
文章目录
- [[DL_Net从入门到入土] 生成对抗网络 GAN](#[DL_Net从入门到入土] 生成对抗网络 GAN)
-
- 📢个人导航
- 📖参考资料
- 🌱背景
- ⚙️架构(公式)
-
-
- [1. Generator 生成器](#1. Generator 生成器)
- [2. Discriminator 判别器](#2. Discriminator 判别器)
- [3. 目标函数](#3. 目标函数)
- [4. 训练流程](#4. 训练流程)
-
- 👍优点/创新点
-
-
- [1. 生成效果清晰](#1. 生成效果清晰)
- [2. 不需要显式建模数据分布](#2. 不需要显式建模数据分布)
-
- 👎缺点
-
-
- [1. 训练不稳定](#1. 训练不稳定)
- [2. 模式崩溃 Mode Collapse](#2. 模式崩溃 Mode Collapse)
- [3. 评价困难](#3. 评价困难)
- [4. 对超参数敏感](#4. 对超参数敏感)
- [5. 容易梯度消失](#5. 容易梯度消失)
-
- 💻模型代码
📖参考资料
Generative Adversarial Networks.
🌱背景
VAE的目标,是让模型学会数据分布,然后从这个分布中采样生成新数据
但是 VAE 有一个常见问题: 生成结果容易偏模糊
因为 VAE 通常使用像素级重建损失,例如 MSE 或 BCE
-> 这类损失会鼓励模型生成"平均结果"
比如真实数据里有很多种写法的数字 8
模型可能学出一种比较平均、比较保守的 8
-> 于是生成图片就容易糊
GAN: Generative Adversarial Network: 生成对抗网络
GAN不直接要求生成图片和某张真实图片逐像素接近,而是引入两个网络:
- Generator:生成器
- Discriminator:判别器
生成器 G:我要造假图,骗过判别器
判别器 D:我要分辨图片是真图还是假图
->
生成器 G 生成的图片越来越真实
判别器 D 越来越难判断图片真假
⚙️架构(公式)
随机噪声 z → Generator G → 生成样本 G(z)
真实样本 x → Discriminator D → 判断为真的概率 D(x)
生成样本 G(z) → Discriminator D → 判断为真的概率 D(G(z))
随机噪声 z
生成器 G
生成样本 G(z)
真实样本 x
判别器 D
输出真假概率
1. Generator 生成器
生成器的输入通常是一个随机噪声 z, 服从标准正态分布
z ∼ N ( 0 , I ) z \sim N(0, I) z∼N(0,I)
生成器把这个噪声映射成一个样本 G ( z ) G(z) G(z)
生成器的目标是:生成尽可能真实的样本,让判别器误以为它是真的
2. Discriminator 判别器
判别器的输入是一张图片,输出是一个概率 (这张图片是真实图片的概率)
D ( x ) ∈ [ 0 , 1 ] D(x) ∈ [0, 1] D(x)∈[0,1]
对于真实图片 x,判别器希望: D ( x ) → 1 D(x) → 1 D(x)→1
对于生成图片 G(z),判别器希望: D ( G ( z ) ) → 0 D(G(z)) → 0 D(G(z))→0
3. 目标函数
经典目标函数:
min G max D V ( D , G ) \min_G \max_D V(D, G) GminDmaxV(D,G)
极小极大博弈:
D 想最大化识别真假能力
G 想最小化 D 的识别能力
其中:
V ( D , G ) = E x ∼ p d a t a ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] V(D, G) = \mathbb{E}{x \sim p{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] V(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
- 判别器 D 希望
- 最大化 l o g D ( x ) logD(x) logD(x)
- 最大化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
- 生成器 G 希望
- 最小化 l o g ( 1 − D ( G ( z ) ) ) log(1−D(G(z))) log(1−D(G(z)))
4. 训练流程
第一步: 训练判别器D (固定G)
判别器损失可以写成:
L D = − E x ∼ p d a t a ( x ) [ log D ( x ) ] − E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] L_D =-\mathbb{E}{x \sim p{data}(x)}[\log D(x)]- \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] LD=−Ex∼pdata(x)[logD(x)]−Ez∼pz(z)[log(1−D(G(z)))]
第二步: 训练生成器G (固定D)
生成器损失可以写成:
L G = − E z ∼ p z ( z ) [ log D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))] LG=−Ez∼pz(z)[logD(G(z))]
👍优点/创新点
1. 生成效果清晰
相比 VAE,GAN 不直接依赖像素级重建损失。
所以 GAN 生成的图像通常更加锐利、清晰
2. 不需要显式建模数据分布
很多传统 生成模型需要显式 建模概率分布 p ( x ) p(x) p(x)
但是真实世界的数据分布往往非常复杂!
GAN 不直接写出真实数据分布的具体形式,而是通过判别器提供学习信号
也就是说:GAN 可以隐式地学习复杂数据分布
👎缺点
1. 训练不稳定
GAN 是两个网络互相博弈,不是普通的单目标优化
所以训练时可能出现:
- 判别器太强 → 生成器学不到东西
- 生成器太强 → 判别器失去作用
- 两者震荡 → 损失不收敛
但是 GAN 的 loss 不一定能直接反映生成质量:
- 有时候 loss 看起来正常,图片很烂
- 有时候 loss 看起来奇怪,图片反而不错
2. 模式崩溃 Mode Collapse
模式崩溃:生成器只学会生成少数几种样本,而不是覆盖整个真实数据分布
3. 评价困难
对于分类任务,我们可以用准确率
对于回归任务,我们可以用 MSE
但是对于生成模型,评价就比较麻烦!
GAN 常用评价指标:
- IS: Inception Score
- FID: Fréchet Inception Distance
但是这些指标也不是完美的
4. 对超参数敏感
5. 容易梯度消失
当判别器 D 太强时,它可以轻松判断:
- 真实图片 → 1
- 生成图片 → 0
此时对于生成器来说,梯度可能变得非常弱:
生成器知道自己错了, 但是不知道该怎么改
-> 这会导致生成器训练困难
💻模型代码
py
import torch
import torch.nn as nn
class Generator(nn.Module):
"""
生成器 Generator
输入:
z: 随机噪声向量
shape = [batch_size, latent_dim]
输出:
img: 生成图片
shape = [batch_size, 1, 28, 28]
"""
def __init__(self, latent_dim=100, img_dim=28 * 28):
super().__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, img_dim),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, 28, 28)
return img
class Discriminator(nn.Module):
"""
判别器 Discriminator
输入:
img: 图片
shape = [batch_size, 1, 28, 28]
输出:
validity: 图片为真实图片的概率
shape = [batch_size, 1]
"""
def __init__(self, img_dim=28 * 28):
super().__init__()
self.model = nn.Sequential(
nn.Linear(img_dim, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
if __name__ == "__main__":
latent_dim = 100
batch_size = 8
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
outputs = discriminator(fake_imgs)
print("随机噪声 z 的形状:", z.shape)
print("生成图片 fake_imgs 的形状:", fake_imgs.shape)
print("判别器输出 outputs 的形状:", outputs.shape)