人工智能之核心技术 深度学习
第六章 生成对抗网络(GAN)
文章目录
- [人工智能之核心技术 深度学习](#人工智能之核心技术 深度学习)
- [前言:生成对抗网络(GAN)------ 生成模型核心](#前言:生成对抗网络(GAN)—— 生成模型核心)
- [一、GAN 基础原理](#一、GAN 基础原理)
- [1.1 核心思想:对抗训练](#1.1 核心思想:对抗训练)
- [1.2 数学形式:极小极大博弈](#1.2 数学形式:极小极大博弈)
- [1.3 训练流程(交替更新)](#1.3 训练流程(交替更新))
- [二、经典 GAN 变体](#二、经典 GAN 变体)
- [三、GAN 应用场景](#三、GAN 应用场景)
- [四、配套代码实现(DCGAN for MNIST)](#四、配套代码实现(DCGAN for MNIST))
- [五、GAN 的挑战与未来](#五、GAN 的挑战与未来)
- 六、总结对比
- 资料关注
前言:生成对抗网络(GAN)------ 生成模型核心
生成对抗网络(Generative Adversarial Network, GAN)由 Ian Goodfellow 等人在 2014 年提出,被誉为"深度学习中最酷的想法之一 "。它通过让两个神经网络相互博弈的方式,学会生成高度逼真的数据(如人脸、风景、艺术画等),开启了生成式 AI 的新纪元。
一、GAN 基础原理
1.1 核心思想:对抗训练
想象一个"造假者"和一个"鉴伪专家"的博弈:
- 生成器(Generator, G):试图伪造逼真的假币(或图像)
- 判别器(Discriminator, D):试图区分真币和假币
🎯 目标:
- G 要骗过 D(让 D 认为假图是真图)
- D 要准确识别真假(不被 G 欺骗)
随着训练进行,G 越来越擅长生成逼真样本,D 也越来越难分辨真假------最终达到纳什均衡。
反馈
反馈
随机噪声 z ~ N(0,1)
生成器 G
假图像 G(z)
真实图像 x
判别器 D
输出:真假概率
1.2 数学形式:极小极大博弈
GAN 的损失函数是一个极小极大优化问题:
min G max D V ( D , G ) = E x ∼ p data ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
- 判别器 D 的目标 :最大化 V ( D , G ) V(D, G) V(D,G)
→ 对真实样本输出高概率,对假样本输出低概率 - 生成器 G 的目标 :最小化 V ( D , G ) V(D, G) V(D,G)
→ 让 D ( G ( z ) ) D(G(z)) D(G(z)) 尽可能大(即骗过 D)
⚠️ 训练技巧 :
实际中常将 G 的损失改为 min G − E [ log D ( G ( z ) ) ] \min_G -\mathbb{E}[\log D(G(z))] minG−E[logD(G(z))],避免训练初期梯度消失。
1.3 训练流程(交替更新)
python
for epoch in epochs:
# 1. 固定 G,更新 D
real_loss = BCE(D(real_img), 1) # 真图标签=1
fake_loss = BCE(D(G(noise)), 0) # 假图标签=0
d_loss = real_loss + fake_loss
update D
# 2. 固定 D,更新 G
g_loss = BCE(D(G(noise)), 1) # 假图标签=1(骗 D)
update G
🔁 关键 :D 和 G 交替训练,不能同时更新!
二、经典 GAN 变体
2.1 DCGAN(Deep Convolutional GAN, 2015)
首次成功将 CNN 引入 GAN,大幅提升图像生成质量。
创新设计:
| 组件 | 设计原则 |
|---|---|
| 生成器 | 使用转置卷积(Transposed Conv) 上采样(替代全连接层) |
| 判别器 | 标准 CNN(无池化,用 stride 卷积降采样) |
| 其他 | - 批归一化(BatchNorm)- ReLU(G)、LeakyReLU(D) |
✅ 效果:可生成 64×64 清晰人脸、卧室等图像
2.2 WGAN(Wasserstein GAN, 2017)
解决原始 GAN 的两大痛点:
- 训练不稳定
- 模式崩溃(Mode Collapse):G 只生成少数几种样本
核心改进:
- 使用 Wasserstein 距离(Earth Mover's Distance) 替代 JS 散度
- 判别器称为 Critic ,输出实数值(非概率)
- 对 Critic 参数强制 Lipschitz 连续 → 采用 权重裁剪 或 梯度惩罚(WGAN-GP)
WGAN-GP 损失:
L D = E [ D ( G ( z ) ) ] − E [ D ( x ) ] + λ E x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] L G = − E [ D ( G ( z ) ) ] \begin{aligned} \mathcal{L}D &= \mathbb{E}[D(G(z))] - \mathbb{E}[D(x)] + \lambda \mathbb{E}{\hat{x}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] \\ \mathcal{L}_G &= -\mathbb{E}[D(G(z))] \end{aligned} LDLG=E[D(G(z))]−E[D(x)]+λEx^[(∥∇x^D(x^)∥2−1)2]=−E[D(G(z))]
✅ 优势:损失值与生成质量正相关,可作为调参指标!
2.3 StyleGAN(2019, NVIDIA)
实现高分辨率、高保真、可控风格的人脸生成。
核心创新:
- 风格迁移架构:将生成过程分解为"风格"和"内容"
- 自适应实例归一化(AdaIN):在每一层注入风格向量
- 渐进式增长:从 4×4 开始,逐步增加分辨率至 1024×1024
- 映射网络(Mapping Network) :将随机噪声 z z z 映射到中间空间 w w w,提升解耦性
z ~ N(0,1)
MLP: z → w
w
AdaIN1
AdaIN2
AdaIN3
常量张量
Conv1
Conv2
...
Output
🌟 效果:可独立控制人脸姿态、肤色、发型等属性!
三、GAN 应用场景
| 应用 | 说明 | 代表工作 |
|---|---|---|
| 图像生成 | 生成逼真人脸、动物、艺术画 | StyleGAN, BigGAN |
| 图像修复(Inpainting) | 填补图像缺失区域 | Context Encoder + GAN |
| 超分辨率(SRGAN) | 低清 → 高清 | SRGAN, ESRGAN |
| 风格迁移 | 将艺术风格迁移到照片 | CycleGAN, Neural Style Transfer |
| 文本到图像生成 | 根据描述生成图像 | AttnGAN, DALL·E(早期) |
| 数据增强 | 生成稀缺样本(如医学图像) | Medical GAN |
💡 注意 :GAN 在文本生成上效果有限(离散 token 难优化),现多被 Transformer 取代。
四、配套代码实现(DCGAN for MNIST)
python
import torch
import torch.nn as nn
import torch.optim as optim
# 1. 生成器(使用转置卷积)
class Generator(nn.Module):
def __init__(self, nz=100, ngf=64):
super().__init__()
self.main = nn.Sequential(
# 输入: [batch, nz, 1, 1]
nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf*4),
nn.ReLU(True),
# 输出: [batch, ngf*4, 4, 4]
nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf*2),
nn.ReLU(True),
# 输出: [batch, ngf*2, 8, 8]
nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# 输出: [batch, ngf, 16, 16]
nn.ConvTranspose2d(ngf, 1, 4, 2, 1, bias=False),
nn.Tanh() # 输出 [-1, 1]
# 最终: [batch, 1, 28, 28]
)
def forward(self, z):
return self.main(z)
# 2. 判别器(标准 CNN)
class Discriminator(nn.Module):
def __init__(self, ndf=64):
super().__init__()
self.main = nn.Sequential(
# 输入: [batch, 1, 28, 28]
nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 输出: [batch, ndf, 14, 14]
nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf*2),
nn.LeakyReLU(0.2, inplace=True),
# 输出: [batch, ndf*2, 7, 7]
nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf*4),
nn.LeakyReLU(0.2, inplace=True),
# 输出: [batch, ndf*4, 3, 3]
nn.Conv2d(ndf*4, 1, 3, 1, 0, bias=False),
nn.Sigmoid() # 输出 [0,1]
)
def forward(self, x):
return self.main(x).view(-1, 1).squeeze(1)
# 3. 训练设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nz = 100 # 噪声维度
netG = Generator(nz).to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
# 4. 训练循环(简化版)
for epoch in range(100):
for i, (real_imgs, _) in enumerate(dataloader):
batch_size = real_imgs.size(0)
real_imgs = real_imgs.to(device)
real_labels = torch.ones(batch_size, device=device)
fake_labels = torch.zeros(batch_size, device=device)
# --- 更新判别器 ---
optimizerD.zero_grad()
# 真图损失
output = netD(real_imgs)
loss_real = criterion(output, real_labels)
# 假图损失
noise = torch.randn(batch_size, nz, 1, 1, device=device)
fake_imgs = netG(noise)
output = netD(fake_imgs.detach()) # detach 阻止梯度回传到 G
loss_fake = criterion(output, fake_labels)
loss_D = loss_real + loss_fake
loss_D.backward()
optimizerD.step()
# --- 更新生成器 ---
optimizerG.zero_grad()
output = netD(fake_imgs)
loss_G = criterion(output, real_labels) # 假图标签=1
loss_G.backward()
optimizerG.step()
print(f"Epoch {epoch}, Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")
📌 提示:
- 使用
Tanh()生成器输出需将真实图像归一化到 [-1, 1]detach()防止判别器梯度影响生成器- Adam 优化器参数
betas=(0.5, 0.999)是 GAN 常用设置
五、GAN 的挑战与未来
当前挑战:
| 问题 | 描述 |
|---|---|
| 训练不稳定 | G 和 D 难以平衡 |
| 评估困难 | 缺乏可靠指标(常用 FID、IS) |
| 模式崩溃 | G 生成多样性不足 |
| 计算成本高 | 高分辨率生成需大量 GPU 资源 |
未来方向:
- 与扩散模型融合(如 GAN + Diffusion)
- 3D 生成(NeRF + GAN)
- 可控生成(结合 CLIP 等多模态模型)
六、总结对比
| 模型 | 核心贡献 | 适用场景 |
|---|---|---|
| 原始 GAN | 对抗训练框架 | 理论奠基 |
| DCGAN | CNN 架构标准化 | 中小图像生成 |
| WGAN/WGAN-GP | 稳定训练 + 解决模式崩溃 | 通用改进 |
| StyleGAN | 高质量 + 属性解耦 | 人脸/肖像生成 |
原始 GAN
DCGAN
CNN 架构
WGAN
稳定训练
StyleGAN
高质量可控生成
StyleGAN2/3
细节优化
✅ 实践建议:
- 入门实验 → DCGAN + MNIST/CIFAR10
- 高质量人脸 → StyleGAN2-ADA
- 稳定训练 → 优先尝试 WGAN-GP
资料关注
公众号:咚咚王
gitee:https://gitee.com/wy18585051844/ai_learning
《Python编程:从入门到实践》
《利用Python进行数据分析》
《算法导论中文第三版》
《概率论与数理统计(第四版) (盛骤) 》
《程序员的数学》
《线性代数应该这样学第3版》
《微积分和数学分析引论》
《(西瓜书)周志华-机器学习》
《TensorFlow机器学习实战指南》
《Sklearn与TensorFlow机器学习实用指南》
《模式识别(第四版)》
《深度学习 deep learning》伊恩·古德费洛著 花书
《Python深度学习第二版(中文版)【纯文本】 (登封大数据 (Francois Choliet)) (Z-Library)》
《深入浅出神经网络与深度学习+(迈克尔·尼尔森(Michael+Nielsen)》
《自然语言处理综论 第2版》
《Natural-Language-Processing-with-PyTorch》
《计算机视觉-算法与应用(中文版)》
《Learning OpenCV 4》
《AIGC:智能创作时代》杜雨+&+张孜铭
《AIGC原理与实践:零基础学大语言模型、扩散模型和多模态模型》
《从零构建大语言模型(中文版)》
《实战AI大模型》
《AI 3.0》