从GAN到WGAN-GP:生成对抗网络的进化之路与实战详解

在深度学习的生成模型领域,GAN (Generative Adversarial Networks) 无疑是最耀眼的明星之一。从2014年 Ian Goodfellow 提出 GAN 至今,它已经经历了无数次的迭代和进化。其中,WGAN 和 WGAN-GP 是两次里程碑式的改进,它们从数学原理上解决了原始 GAN 训练不稳定、模式崩塌等"顽疾"。
本文将深入浅出地梳理从 GAN 到 WGAN 再到 WGAN-GP 的演进逻辑,分析它们背后的数学直觉,并提供核心代码实现。
一、GAN:天才的博弈
1.1 基本原理
GAN 的灵感来源于博弈论。它由两个网络组成:
- 生成器 (Generator, G):负责制造"假钞"(生成数据)。它的目标是生成尽可能逼真的数据,以骗过判别器。
- 判别器 (Discriminator, D):负责充当"验钞机"。它的目标是尽可能准确地分辨出输入数据是真实的(来自数据集)还是假的(由 G 生成)。
两者的目标函数是一个 Min-Max 博弈:
minGmaxDV(D,G)=Ex∼Pdata(x)[logD(x)]+Ez∼Pz(z)[log(1−D(G(z)))] \min_G \max_D V(D, G) = \mathbb{E}{x \sim P{data}(x)} [\log D(x)] + \mathbb{E}{z \sim P{z}(z)} [\log (1 - D(G(z)))] GminDmaxV(D,G)=Ex∼Pdata(x)[logD(x)]+Ez∼Pz(z)[log(1−D(G(z)))]
1.2 GAN 的阿喀琉斯之踵
虽然 GAN 的思想非常精妙,但在实际训练中,研究者们发现它非常难训练,主要面临以下问题:
- 训练不稳定:G 和 D 需要小心翼翼地平衡。如果 D 太强,G 的梯度会消失;如果 D 太弱,G 又学不到东西。
- 模式崩塌 (Mode Collapse):G 发现生成某一种特定的样本特别容易骗过 D,于是它就只生成这一种样本,失去了多样性。
- 无法指示训练进程:GAN 的 Loss 值通常震荡剧烈,无法像监督学习那样通过 Loss 下降来判断模型是否变好。
根本原因 :原始 GAN 等价于在该小 JS 散度 (Jensen-Shannon Divergence)。当真实分布 PrP_rPr 和生成分布 PgP_gPg 重叠很少甚至不重叠时(在高维空间中这很常见),JS 散度是常数,导致梯度消失,G 无法获得有效的更新方向。
二、WGAN:推土机距离的救赎
为了解决 GAN 的问题,2017年 Arjovsky 等人提出了 Wasserstein GAN (WGAN)。
2.1 核心思想:Wasserstein 距离
WGAN 引入了 Wasserstein 距离(也称 Earth-Mover Distance,EM 距离,推土机距离)。
简单来说,如果把两个分布看作是两堆土,EM 距离就是把一堆土搬到另一堆土的位置所消耗的最小"功"(质量 imes 距离)。
优势:即使两个分布完全不重叠,Wasserstein 距离仍然能提供平滑的梯度,指引 G 慢慢向真实分布靠拢。这彻底解决了梯度消失的问题。
2.2 WGAN 的改进点
为了近似计算 Wasserstein 距离,WGAN 做了以下改动:
- 判别器变身"评论家" (Critic):D 的最后一层去掉 Sigmoid,不再输出概率,而是输出一个实数值(评分)。
- Loss 改变 :
- LD=Ex~∼Pg[D(x~)]−Ex∼Pr[D(x)]L_D = \mathbb{E}{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}{x \sim P_r}[D(x)]LD=Ex~∼Pg[D(x~)]−Ex∼Pr[D(x)]
- LG=−Ex~∼Pg[D(x~)]L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]LG=−Ex~∼Pg[D(x~)]
- 注意:不再取 log。
- 权重剪枝 (Weight Clipping) :为了满足 Wasserstein 距离成立的数学条件(1-Lipschitz 连续性 ),WGAN 强制将 Critic 的所有参数限制在 [−c,c][-c, c][−c,c] 之间(例如 c=0.01)。
2.3 WGAN 的局限
虽然 WGAN 解决了训练稳定性问题,并且 Loss 值终于可以代表图像质量了,但 Weight Clipping 过于简单粗暴:
- 它限制了 Critic 的表达能力。
- 容易导致参数集中在截断边界(-c 和 c),不仅浪费了神经网络的拟合能力,还可能引发梯度爆炸或消失。
三、WGAN-GP:梯度惩罚的优雅
为了解决 Weight Clipping 的副作用,Gulrajani 等人提出了 WGAN-GP (WGAN with Gradient Penalty)。
3.1 核心改进:梯度惩罚
WGAN-GP 依然沿用了 WGAN 的架构和 Loss,但改变了实现 1-Lipschitz 约束的方法。
它不再直接剪裁参数,而是在 Critic 的 Loss 函数中增加了一个 梯度惩罚项 (Gradient Penalty)。根据数学推导,如果一个函数是 1-Lipschitz 的,那么它的梯度范数应该在处处不超过 1。WGAN-GP 鼓励梯度的范数接近 1。
3.2 Loss 函数详解
新的 Critic Loss 如下:
L=Ex~∼Pg[D(x~)]−Ex∼Pr[D(x)]⏟原始 WGAN Loss+λEx^∼Px^[(∣∣∇x^D(x^)∣∣2−1)2]⏟Gradient Penalty L = \underbrace{\mathbb{E}{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}{x \sim P_r}[D(x)]}{\text{原始 WGAN Loss}} + \lambda \underbrace{\mathbb{E}{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}} D(\hat{x})||2 - 1)^2]}{\text{Gradient Penalty}} L=原始 WGAN Loss Ex~∼Pg[D(x~)]−Ex∼Pr[D(x)]+λGradient Penalty Ex^∼Px^[(∣∣∇x^D(x^)∣∣2−1)2]
其中:
- λ\lambdaλ 是惩罚系数(通常取 10)。
- x^\hat{x}x^ 是采样点。我们在真实样本 xxx 和生成样本 x~\tilde{x}x~ 之间随机插值采样:x^=ϵx+(1−ϵ)x~\hat{x} = \epsilon x + (1-\epsilon) \tilde{x}x^=ϵx+(1−ϵ)x~,ϵ∼U[0,1]\epsilon \sim U[0, 1]ϵ∼U[0,1]。我们约束这些插值点上的梯度范数接近 1。
四、PyTorch 核心代码实现
下面是 WGAN-GP 核心部分的 PyTorch 实现代码。
4.1 梯度惩罚计算函数
python
import torch
import torch.nn as nn
import torch.autograd as autograd
def compute_gradient_penalty(D, real_samples, fake_samples, device):
"""
计算 WGAN-GP 的梯度惩罚项
"""
# 1. 随机权重插值
alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device)
# 假设输入是图片 (N, C, H, W),需要根据维度调整 alpha 的形状
interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
# 2. 将插值样本输入判别器
d_interpolates = D(interpolates)
# 3. 计算判别器输出相对于插值样本的梯度
fake = torch.ones(real_samples.shape[0], 1).to(device)
gradients = autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
# 4. 计算梯度范数
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
4.2 训练循环示例
python
# 超参数
lambda_gp = 10
n_critic = 5 # 每训练 1 次 Generator,训练 5 次 Critic
# ... 初始化 DataLoader, Generator(G), Critic(D), Optimizers ...
for i, (imgs, _) in enumerate(dataloader):
real_imgs = imgs.to(device)
batch_size = real_imgs.shape[0]
# ---------------------
# 训练 Critic (D)
# ---------------------
optimizer_D.zero_grad()
# 生成假样本
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = G(z)
# 计算 WGAN Loss
# 注意:为了使用 min 优化器,我们将最大化问题转化为最小化负数
real_validity = D(real_imgs)
fake_validity = D(fake_imgs)
# WGAN Loss: -E[D(x)] + E[D(G(z))]
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
# 计算梯度惩罚
gradient_penalty = compute_gradient_penalty(D, real_imgs, fake_imgs.detach(), device)
# 总 Loss
d_loss = d_loss + lambda_gp * gradient_penalty
d_loss.backward()
optimizer_D.step()
# ---------------------
# 训练 Generator (G)
# ---------------------
# 每 n_critic 步训练一次 G
if i % n_critic == 0:
optimizer_G.zero_grad()
# 重新生成假样本 (可选,也可以复用上面的,但为了计算图清晰通常重新生成)
gen_imgs = G(z)
# G 的目标是让 D 给出的分数越高越好 (即最小化 -D(G(z)))
g_loss = -torch.mean(D(gen_imgs))
g_loss.backward()
optimizer_G.step()
if i % 100 == 0:
print(f"[Epoch {epoch}/{opt.n_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
五、总结与建议
| 特性 | GAN | WGAN | WGAN-GP |
|---|---|---|---|
| 判别器输出 | 概率 [0, 1] (Sigmoid) | 实数评分 (无 Sigmoid) | 实数评分 (无 Sigmoid) |
| Loss 函数 | JS 散度 (Log Loss) | Wasserstein 距离 (均值差) | Wasserstein 距离 + 梯度惩罚 |
| 约束方法 | 无 | Weight Clipping (参数截断) | Gradient Penalty (梯度惩罚) |
| 训练稳定性 | 差 (易模式崩塌) | 较好 | 极好 |
| 收敛速度 | 快 (但不稳定) | 较慢 | 适中 |
实际使用建议 :
如果你要处理新的生成任务,首选 WGAN-GP。它几乎不需要繁琐的超参数调试就能稳定训练,而且 Loss 曲线能真实反映图像质量的提升。虽然计算梯度惩罚会稍微增加一点训练时间,但相比于原始 GAN 调参的痛苦,这是非常值得的投入。