人工智能之核心技术 深度学习 第六章 生成对抗网络(GAN)

人工智能之核心技术 深度学习

第六章 生成对抗网络(GAN)


文章目录

  • [人工智能之核心技术 深度学习](#人工智能之核心技术 深度学习)
  • [前言:生成对抗网络(GAN)------ 生成模型核心](#前言:生成对抗网络(GAN)—— 生成模型核心)
    • [一、GAN 基础原理](#一、GAN 基础原理)
      • [1.1 核心思想:对抗训练](#1.1 核心思想:对抗训练)
      • [1.2 数学形式:极小极大博弈](#1.2 数学形式:极小极大博弈)
      • [1.3 训练流程(交替更新)](#1.3 训练流程(交替更新))
    • [二、经典 GAN 变体](#二、经典 GAN 变体)
      • [2.1 DCGAN(Deep Convolutional GAN, 2015)](#2.1 DCGAN(Deep Convolutional GAN, 2015))
      • [2.2 WGAN(Wasserstein GAN, 2017)](#2.2 WGAN(Wasserstein GAN, 2017))
      • [2.3 StyleGAN(2019, NVIDIA)](#2.3 StyleGAN(2019, NVIDIA))
    • [三、GAN 应用场景](#三、GAN 应用场景)
    • [四、配套代码实现(DCGAN for MNIST)](#四、配套代码实现(DCGAN for MNIST))
    • [五、GAN 的挑战与未来](#五、GAN 的挑战与未来)
    • 六、总结对比
  • 资料关注

前言:生成对抗网络(GAN)------ 生成模型核心

生成对抗网络(Generative Adversarial Network, GAN)由 Ian Goodfellow 等人在 2014 年提出,被誉为"深度学习中最酷的想法之一 "。它通过让两个神经网络相互博弈的方式,学会生成高度逼真的数据(如人脸、风景、艺术画等),开启了生成式 AI 的新纪元。


一、GAN 基础原理

1.1 核心思想:对抗训练

想象一个"造假者"和一个"鉴伪专家"的博弈:

  • 生成器(Generator, G):试图伪造逼真的假币(或图像)
  • 判别器(Discriminator, D):试图区分真币和假币

🎯 目标

  • G 要骗过 D(让 D 认为假图是真图)
  • D 要准确识别真假(不被 G 欺骗)

随着训练进行,G 越来越擅长生成逼真样本,D 也越来越难分辨真假------最终达到纳什均衡
反馈
反馈
随机噪声 z ~ N(0,1)
生成器 G
假图像 G(z)
真实图像 x
判别器 D
输出:真假概率


1.2 数学形式:极小极大博弈

GAN 的损失函数是一个极小极大优化问题

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]

  • 判别器 D 的目标 :最大化 V ( D , G ) V(D, G) V(D,G)
    → 对真实样本输出高概率,对假样本输出低概率
  • 生成器 G 的目标 :最小化 V ( D , G ) V(D, G) V(D,G)
    → 让 D ( G ( z ) ) D(G(z)) D(G(z)) 尽可能大(即骗过 D)

⚠️ 训练技巧

实际中常将 G 的损失改为 min ⁡ G − E [ log ⁡ D ( G ( z ) ) ] \min_G -\mathbb{E}[\log D(G(z))] minG−E[logD(G(z))],避免训练初期梯度消失。


1.3 训练流程(交替更新)

python 复制代码
for epoch in epochs:
    # 1. 固定 G,更新 D
    real_loss = BCE(D(real_img), 1)          # 真图标签=1
    fake_loss = BCE(D(G(noise)), 0)          # 假图标签=0
    d_loss = real_loss + fake_loss
    update D

    # 2. 固定 D,更新 G
    g_loss = BCE(D(G(noise)), 1)             # 假图标签=1(骗 D)
    update G

🔁 关键 :D 和 G 交替训练,不能同时更新!


二、经典 GAN 变体

2.1 DCGAN(Deep Convolutional GAN, 2015)

首次成功将 CNN 引入 GAN,大幅提升图像生成质量。

创新设计:
组件 设计原则
生成器 使用转置卷积(Transposed Conv) 上采样(替代全连接层)
判别器 标准 CNN(无池化,用 stride 卷积降采样)
其他 - 批归一化(BatchNorm)- ReLU(G)、LeakyReLU(D)

效果:可生成 64×64 清晰人脸、卧室等图像


2.2 WGAN(Wasserstein GAN, 2017)

解决原始 GAN 的两大痛点:

  1. 训练不稳定
  2. 模式崩溃(Mode Collapse):G 只生成少数几种样本
核心改进:
  • 使用 Wasserstein 距离(Earth Mover's Distance) 替代 JS 散度
  • 判别器称为 Critic ,输出实数值(非概率)
  • 对 Critic 参数强制 Lipschitz 连续 → 采用 权重裁剪梯度惩罚(WGAN-GP)

WGAN-GP 损失

L D = E [ D ( G ( z ) ) ] − E [ D ( x ) ] + λ E x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] L G = − E [ D ( G ( z ) ) ] \begin{aligned} \mathcal{L}D &= \mathbb{E}[D(G(z))] - \mathbb{E}[D(x)] + \lambda \mathbb{E}{\hat{x}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] \\ \mathcal{L}_G &= -\mathbb{E}[D(G(z))] \end{aligned} LDLG=E[D(G(z))]−E[D(x)]+λEx^[(∥∇x^D(x^)∥2−1)2]=−E[D(G(z))]

优势:损失值与生成质量正相关,可作为调参指标!


2.3 StyleGAN(2019, NVIDIA)

实现高分辨率、高保真、可控风格的人脸生成。

核心创新:
  1. 风格迁移架构:将生成过程分解为"风格"和"内容"
  2. 自适应实例归一化(AdaIN):在每一层注入风格向量
  3. 渐进式增长:从 4×4 开始,逐步增加分辨率至 1024×1024
  4. 映射网络(Mapping Network) :将随机噪声 z z z 映射到中间空间 w w w,提升解耦性

z ~ N(0,1)
MLP: z → w
w
AdaIN1
AdaIN2
AdaIN3
常量张量
Conv1
Conv2
...
Output

🌟 效果:可独立控制人脸姿态、肤色、发型等属性!


三、GAN 应用场景

应用 说明 代表工作
图像生成 生成逼真人脸、动物、艺术画 StyleGAN, BigGAN
图像修复(Inpainting) 填补图像缺失区域 Context Encoder + GAN
超分辨率(SRGAN) 低清 → 高清 SRGAN, ESRGAN
风格迁移 将艺术风格迁移到照片 CycleGAN, Neural Style Transfer
文本到图像生成 根据描述生成图像 AttnGAN, DALL·E(早期)
数据增强 生成稀缺样本(如医学图像) Medical GAN

💡 注意 :GAN 在文本生成上效果有限(离散 token 难优化),现多被 Transformer 取代。


四、配套代码实现(DCGAN for MNIST)

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 1. 生成器(使用转置卷积)
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64):
        super().__init__()
        self.main = nn.Sequential(
            # 输入: [batch, nz, 1, 1]
            nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # 输出: [batch, ngf*4, 4, 4]
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            # 输出: [batch, ngf*2, 8, 8]
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 输出: [batch, ngf, 16, 16]
            nn.ConvTranspose2d(ngf, 1, 4, 2, 1, bias=False),
            nn.Tanh()  # 输出 [-1, 1]
            # 最终: [batch, 1, 28, 28]
        )

    def forward(self, z):
        return self.main(z)

# 2. 判别器(标准 CNN)
class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            # 输入: [batch, 1, 28, 28]
            nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出: [batch, ndf, 14, 14]
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出: [batch, ndf*2, 7, 7]
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出: [batch, ndf*4, 3, 3]
            nn.Conv2d(ndf*4, 1, 3, 1, 0, bias=False),
            nn.Sigmoid()  # 输出 [0,1]
        )

    def forward(self, x):
        return self.main(x).view(-1, 1).squeeze(1)

# 3. 训练设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nz = 100  # 噪声维度
netG = Generator(nz).to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 4. 训练循环(简化版)
for epoch in range(100):
    for i, (real_imgs, _) in enumerate(dataloader):
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # --- 更新判别器 ---
        optimizerD.zero_grad()
        # 真图损失
        output = netD(real_imgs)
        loss_real = criterion(output, real_labels)
        # 假图损失
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_imgs = netG(noise)
        output = netD(fake_imgs.detach())  # detach 阻止梯度回传到 G
        loss_fake = criterion(output, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizerD.step()

        # --- 更新生成器 ---
        optimizerG.zero_grad()
        output = netD(fake_imgs)
        loss_G = criterion(output, real_labels)  # 假图标签=1
        loss_G.backward()
        optimizerG.step()

    print(f"Epoch {epoch}, Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

📌 提示

  • 使用 Tanh() 生成器输出需将真实图像归一化到 [-1, 1]
  • detach() 防止判别器梯度影响生成器
  • Adam 优化器参数 betas=(0.5, 0.999) 是 GAN 常用设置

五、GAN 的挑战与未来

当前挑战:

问题 描述
训练不稳定 G 和 D 难以平衡
评估困难 缺乏可靠指标(常用 FID、IS)
模式崩溃 G 生成多样性不足
计算成本高 高分辨率生成需大量 GPU 资源

未来方向:

  • 与扩散模型融合(如 GAN + Diffusion)
  • 3D 生成(NeRF + GAN)
  • 可控生成(结合 CLIP 等多模态模型)

六、总结对比

模型 核心贡献 适用场景
原始 GAN 对抗训练框架 理论奠基
DCGAN CNN 架构标准化 中小图像生成
WGAN/WGAN-GP 稳定训练 + 解决模式崩溃 通用改进
StyleGAN 高质量 + 属性解耦 人脸/肖像生成

原始 GAN
DCGAN

CNN 架构
WGAN

稳定训练
StyleGAN

高质量可控生成
StyleGAN2/3

细节优化

实践建议

  • 入门实验 → DCGAN + MNIST/CIFAR10
  • 高质量人脸 → StyleGAN2-ADA
  • 稳定训练 → 优先尝试 WGAN-GP

资料关注

公众号:咚咚王

gitee:https://gitee.com/wy18585051844/ai_learning

《Python编程:从入门到实践》

《利用Python进行数据分析》

《算法导论中文第三版》

《概率论与数理统计(第四版) (盛骤) 》

《程序员的数学》

《线性代数应该这样学第3版》

《微积分和数学分析引论》

《(西瓜书)周志华-机器学习》

《TensorFlow机器学习实战指南》

《Sklearn与TensorFlow机器学习实用指南》

《模式识别(第四版)》

《深度学习 deep learning》伊恩·古德费洛著 花书

《Python深度学习第二版(中文版)【纯文本】 (登封大数据 (Francois Choliet)) (Z-Library)》

《深入浅出神经网络与深度学习+(迈克尔·尼尔森(Michael+Nielsen)》

《自然语言处理综论 第2版》

《Natural-Language-Processing-with-PyTorch》

《计算机视觉-算法与应用(中文版)》

《Learning OpenCV 4》

《AIGC:智能创作时代》杜雨+&+张孜铭

《AIGC原理与实践:零基础学大语言模型、扩散模型和多模态模型》

《从零构建大语言模型(中文版)》

《实战AI大模型》

《AI 3.0》

相关推荐
子夜江寒4 小时前
基于dlib与OpenCV的人脸检测与特征点标定技术实践
人工智能·opencv·计算机视觉
IRevers4 小时前
RF-DETR:第一个在COCO上突破60AP的DETR(含检测和分割推理)
图像处理·人工智能·python·深度学习·目标检测·计算机视觉
昨夜见军贴06164 小时前
合规性管理的现代化实践:IACheck的AI审核如何系统提升生产型检测报告的合规水平
大数据·运维·人工智能
自可乐5 小时前
AutoGen(多智能体AI框架)全面学习教程
人工智能·python·学习·ai
人工智能AI技术5 小时前
手搓一个AI搜索引擎:基于百度DeepSearch框架的实战开发笔记
人工智能·百度
郝学胜-神的一滴5 小时前
机器学习中的特征提取:PCA与LDA详解及sklearn实践
人工智能·python·程序人生·算法·机器学习·sklearn
是小蟹呀^5 小时前
卷积神经网络(CNN):池化操作
人工智能·深度学习·神经网络·cnn
草莓熊Lotso5 小时前
远程控制软件实测!2026年1月远程软件从“夯”到“拉”全功能横评
运维·服务器·数据库·人工智能
发哥来了5 小时前
主流AI视频生成模型商用化能力评测:三大核心维度对比分析
大数据·人工智能·音视频