人工智能之核心技术 深度学习 第六章 生成对抗网络(GAN)

人工智能之核心技术 深度学习

第六章 生成对抗网络(GAN)


文章目录

  • [人工智能之核心技术 深度学习](#人工智能之核心技术 深度学习)
  • [前言:生成对抗网络(GAN)------ 生成模型核心](#前言:生成对抗网络(GAN)—— 生成模型核心)
    • [一、GAN 基础原理](#一、GAN 基础原理)
      • [1.1 核心思想:对抗训练](#1.1 核心思想:对抗训练)
      • [1.2 数学形式:极小极大博弈](#1.2 数学形式:极小极大博弈)
      • [1.3 训练流程(交替更新)](#1.3 训练流程(交替更新))
    • [二、经典 GAN 变体](#二、经典 GAN 变体)
      • [2.1 DCGAN(Deep Convolutional GAN, 2015)](#2.1 DCGAN(Deep Convolutional GAN, 2015))
      • [2.2 WGAN(Wasserstein GAN, 2017)](#2.2 WGAN(Wasserstein GAN, 2017))
      • [2.3 StyleGAN(2019, NVIDIA)](#2.3 StyleGAN(2019, NVIDIA))
    • [三、GAN 应用场景](#三、GAN 应用场景)
    • [四、配套代码实现(DCGAN for MNIST)](#四、配套代码实现(DCGAN for MNIST))
    • [五、GAN 的挑战与未来](#五、GAN 的挑战与未来)
    • 六、总结对比
  • 资料关注

前言:生成对抗网络(GAN)------ 生成模型核心

生成对抗网络(Generative Adversarial Network, GAN)由 Ian Goodfellow 等人在 2014 年提出,被誉为"深度学习中最酷的想法之一 "。它通过让两个神经网络相互博弈的方式,学会生成高度逼真的数据(如人脸、风景、艺术画等),开启了生成式 AI 的新纪元。


一、GAN 基础原理

1.1 核心思想:对抗训练

想象一个"造假者"和一个"鉴伪专家"的博弈:

  • 生成器(Generator, G):试图伪造逼真的假币(或图像)
  • 判别器(Discriminator, D):试图区分真币和假币

🎯 目标

  • G 要骗过 D(让 D 认为假图是真图)
  • D 要准确识别真假(不被 G 欺骗)

随着训练进行,G 越来越擅长生成逼真样本,D 也越来越难分辨真假------最终达到纳什均衡
反馈
反馈
随机噪声 z ~ N(0,1)
生成器 G
假图像 G(z)
真实图像 x
判别器 D
输出:真假概率


1.2 数学形式:极小极大博弈

GAN 的损失函数是一个极小极大优化问题

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]

  • 判别器 D 的目标 :最大化 V ( D , G ) V(D, G) V(D,G)
    → 对真实样本输出高概率,对假样本输出低概率
  • 生成器 G 的目标 :最小化 V ( D , G ) V(D, G) V(D,G)
    → 让 D ( G ( z ) ) D(G(z)) D(G(z)) 尽可能大(即骗过 D)

⚠️ 训练技巧

实际中常将 G 的损失改为 min ⁡ G − E [ log ⁡ D ( G ( z ) ) ] \min_G -\mathbb{E}[\log D(G(z))] minG−E[logD(G(z))],避免训练初期梯度消失。


1.3 训练流程(交替更新)

python 复制代码
for epoch in epochs:
    # 1. 固定 G,更新 D
    real_loss = BCE(D(real_img), 1)          # 真图标签=1
    fake_loss = BCE(D(G(noise)), 0)          # 假图标签=0
    d_loss = real_loss + fake_loss
    update D

    # 2. 固定 D,更新 G
    g_loss = BCE(D(G(noise)), 1)             # 假图标签=1(骗 D)
    update G

🔁 关键 :D 和 G 交替训练,不能同时更新!


二、经典 GAN 变体

2.1 DCGAN(Deep Convolutional GAN, 2015)

首次成功将 CNN 引入 GAN,大幅提升图像生成质量。

创新设计:
组件 设计原则
生成器 使用转置卷积(Transposed Conv) 上采样(替代全连接层)
判别器 标准 CNN(无池化,用 stride 卷积降采样)
其他 - 批归一化(BatchNorm)- ReLU(G)、LeakyReLU(D)

效果:可生成 64×64 清晰人脸、卧室等图像


2.2 WGAN(Wasserstein GAN, 2017)

解决原始 GAN 的两大痛点:

  1. 训练不稳定
  2. 模式崩溃(Mode Collapse):G 只生成少数几种样本
核心改进:
  • 使用 Wasserstein 距离(Earth Mover's Distance) 替代 JS 散度
  • 判别器称为 Critic ,输出实数值(非概率)
  • 对 Critic 参数强制 Lipschitz 连续 → 采用 权重裁剪梯度惩罚(WGAN-GP)

WGAN-GP 损失

L D = E [ D ( G ( z ) ) ] − E [ D ( x ) ] + λ E x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] L G = − E [ D ( G ( z ) ) ] \begin{aligned} \mathcal{L}D &= \mathbb{E}[D(G(z))] - \mathbb{E}[D(x)] + \lambda \mathbb{E}{\hat{x}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] \\ \mathcal{L}_G &= -\mathbb{E}[D(G(z))] \end{aligned} LDLG=E[D(G(z))]−E[D(x)]+λEx^[(∥∇x^D(x^)∥2−1)2]=−E[D(G(z))]

优势:损失值与生成质量正相关,可作为调参指标!


2.3 StyleGAN(2019, NVIDIA)

实现高分辨率、高保真、可控风格的人脸生成。

核心创新:
  1. 风格迁移架构:将生成过程分解为"风格"和"内容"
  2. 自适应实例归一化(AdaIN):在每一层注入风格向量
  3. 渐进式增长:从 4×4 开始,逐步增加分辨率至 1024×1024
  4. 映射网络(Mapping Network) :将随机噪声 z z z 映射到中间空间 w w w,提升解耦性

z ~ N(0,1)
MLP: z → w
w
AdaIN1
AdaIN2
AdaIN3
常量张量
Conv1
Conv2
...
Output

🌟 效果:可独立控制人脸姿态、肤色、发型等属性!


三、GAN 应用场景

应用 说明 代表工作
图像生成 生成逼真人脸、动物、艺术画 StyleGAN, BigGAN
图像修复(Inpainting) 填补图像缺失区域 Context Encoder + GAN
超分辨率(SRGAN) 低清 → 高清 SRGAN, ESRGAN
风格迁移 将艺术风格迁移到照片 CycleGAN, Neural Style Transfer
文本到图像生成 根据描述生成图像 AttnGAN, DALL·E(早期)
数据增强 生成稀缺样本(如医学图像) Medical GAN

💡 注意 :GAN 在文本生成上效果有限(离散 token 难优化),现多被 Transformer 取代。


四、配套代码实现(DCGAN for MNIST)

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 1. 生成器(使用转置卷积)
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64):
        super().__init__()
        self.main = nn.Sequential(
            # 输入: [batch, nz, 1, 1]
            nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # 输出: [batch, ngf*4, 4, 4]
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            # 输出: [batch, ngf*2, 8, 8]
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 输出: [batch, ngf, 16, 16]
            nn.ConvTranspose2d(ngf, 1, 4, 2, 1, bias=False),
            nn.Tanh()  # 输出 [-1, 1]
            # 最终: [batch, 1, 28, 28]
        )

    def forward(self, z):
        return self.main(z)

# 2. 判别器(标准 CNN)
class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            # 输入: [batch, 1, 28, 28]
            nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出: [batch, ndf, 14, 14]
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出: [batch, ndf*2, 7, 7]
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出: [batch, ndf*4, 3, 3]
            nn.Conv2d(ndf*4, 1, 3, 1, 0, bias=False),
            nn.Sigmoid()  # 输出 [0,1]
        )

    def forward(self, x):
        return self.main(x).view(-1, 1).squeeze(1)

# 3. 训练设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nz = 100  # 噪声维度
netG = Generator(nz).to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 4. 训练循环(简化版)
for epoch in range(100):
    for i, (real_imgs, _) in enumerate(dataloader):
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # --- 更新判别器 ---
        optimizerD.zero_grad()
        # 真图损失
        output = netD(real_imgs)
        loss_real = criterion(output, real_labels)
        # 假图损失
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_imgs = netG(noise)
        output = netD(fake_imgs.detach())  # detach 阻止梯度回传到 G
        loss_fake = criterion(output, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizerD.step()

        # --- 更新生成器 ---
        optimizerG.zero_grad()
        output = netD(fake_imgs)
        loss_G = criterion(output, real_labels)  # 假图标签=1
        loss_G.backward()
        optimizerG.step()

    print(f"Epoch {epoch}, Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

📌 提示

  • 使用 Tanh() 生成器输出需将真实图像归一化到 [-1, 1]
  • detach() 防止判别器梯度影响生成器
  • Adam 优化器参数 betas=(0.5, 0.999) 是 GAN 常用设置

五、GAN 的挑战与未来

当前挑战:

问题 描述
训练不稳定 G 和 D 难以平衡
评估困难 缺乏可靠指标(常用 FID、IS)
模式崩溃 G 生成多样性不足
计算成本高 高分辨率生成需大量 GPU 资源

未来方向:

  • 与扩散模型融合(如 GAN + Diffusion)
  • 3D 生成(NeRF + GAN)
  • 可控生成(结合 CLIP 等多模态模型)

六、总结对比

模型 核心贡献 适用场景
原始 GAN 对抗训练框架 理论奠基
DCGAN CNN 架构标准化 中小图像生成
WGAN/WGAN-GP 稳定训练 + 解决模式崩溃 通用改进
StyleGAN 高质量 + 属性解耦 人脸/肖像生成

原始 GAN
DCGAN

CNN 架构
WGAN

稳定训练
StyleGAN

高质量可控生成
StyleGAN2/3

细节优化

实践建议

  • 入门实验 → DCGAN + MNIST/CIFAR10
  • 高质量人脸 → StyleGAN2-ADA
  • 稳定训练 → 优先尝试 WGAN-GP

资料关注

公众号:咚咚王

gitee:https://gitee.com/wy18585051844/ai_learning

《Python编程:从入门到实践》

《利用Python进行数据分析》

《算法导论中文第三版》

《概率论与数理统计(第四版) (盛骤) 》

《程序员的数学》

《线性代数应该这样学第3版》

《微积分和数学分析引论》

《(西瓜书)周志华-机器学习》

《TensorFlow机器学习实战指南》

《Sklearn与TensorFlow机器学习实用指南》

《模式识别(第四版)》

《深度学习 deep learning》伊恩·古德费洛著 花书

《Python深度学习第二版(中文版)【纯文本】 (登封大数据 (Francois Choliet)) (Z-Library)》

《深入浅出神经网络与深度学习+(迈克尔·尼尔森(Michael+Nielsen)》

《自然语言处理综论 第2版》

《Natural-Language-Processing-with-PyTorch》

《计算机视觉-算法与应用(中文版)》

《Learning OpenCV 4》

《AIGC:智能创作时代》杜雨+&+张孜铭

《AIGC原理与实践:零基础学大语言模型、扩散模型和多模态模型》

《从零构建大语言模型(中文版)》

《实战AI大模型》

《AI 3.0》

相关推荐
lijianhua_97126 小时前
国内某顶级大学内部用的ai自动生成论文的提示词
人工智能
EDPJ6 小时前
当图像与文本 “各说各话” —— CLIP 中的模态鸿沟与对象偏向
深度学习·计算机视觉
蔡俊锋6 小时前
用AI实现乐高式大型可插拔系统的技术方案
人工智能·ai工程·ai原子能力·ai乐高工程
自然语6 小时前
人工智能之数字生命 认知架构白皮书 第7章
人工智能·架构
大熊背6 小时前
利用ISP离线模式进行分块LSC校正的方法
人工智能·算法·机器学习
eastyuxiao7 小时前
如何在不同的机器上运行多个OpenClaw实例?
人工智能·git·架构·github·php
诸葛务农7 小时前
AGI 主要技术路径及核心技术:归一融合及未来之路5
大数据·人工智能
光影少年7 小时前
AI Agent智能体开发
人工智能·aigc·ai编程
ai生成式引擎优化技术7 小时前
TSPR-WEB-LLM-HIC (TWLH四元结构)AI生成式引擎(GEO)技术白皮书
人工智能
帐篷Li7 小时前
9Router:开源AI路由网关的架构设计与技术实现深度解析
人工智能