【深度学习】计算机视觉(CV)-图像生成-生成对抗网络(GANs, Generative Adversarial Networks)

生成对抗网络(GANs)Ian Goodfellow 在 2014 年提出的一种深度生成模型 ,主要用于生成逼真的数据,如图像、音乐、文本等。GANs 采用博弈论思想 ,让两个神经网络(生成器 G判别器 D)相互对抗,在不断竞争中提高数据的生成质量。


1. GANs 的核心思想

GANs 由 两个神经网络 组成:

  • 生成器(Generator, G)

    • 输入随机噪声 z,生成与真实数据类似的样本 G(z)
    • 目标:欺骗判别器,让它认为生成的样本是真实数据
  • 判别器(Discriminator, D)

    • 输入数据(真实数据 x 或生成数据 G(z))
    • 目标:判断输入数据是真实的(1)还是生成的(0)

两者进行博弈(Adversarial Training)

  • G 尽力欺骗 D
  • D 试图正确分类

最终,G 生成的数据会越来越逼真,D 也会变得更强。


2. GANs 的训练过程

Step 1: 随机生成噪声

  • 生成器 G 以 随机噪声 z(通常是正态分布) 作为输入,生成假数据:

    G(z)→生成假数据G(z)

Step 2: 判别器判定真假

  • 判别器 D 接收 真实数据 x 和假数据 G(z) ,并计算它们的真假概率:
    D(x)→1(真实数据)
    D(G(z))→0(假数据)

Step 3: 计算损失

  • 判别器损失(Binary Cross Entropy Loss)

    目的是让 D(x) 预测 1,D(G(z)) 预测 0。

  • 生成器损失(让 G 欺骗 D)

    目的是让 D(G(z)) 预测 1(以为是假数据)。

Step 4: 交替优化

  1. 更新 D:固定 G,训练 D,使其能正确分类真假数据。
  2. 更新 G:固定 D,训练 G,使 D 无法区分真假数据。

最终,G 生成的数据将会越来越接近真实数据。


3. GANs 的数学原理

GANs 本质上是在优化一个极小极大问题(Minimax Game)

其中目标函数为:

  • 判别器 D:最大化 V(D,G) 以最准确区分真假数据。
  • 生成器 G:最小化 V(D,G) 以欺骗 D,使得假数据尽可能像真实数据。

4. 经典 GANs 变体

DCGAN(深度卷积 GAN)

  • 替换全连接层使用 CNN 提高图像生成质量。
  • 采用 Leaky ReLU 代替 ReLU,提高梯度流动。

WGAN(Wasserstein GAN)

  • 解决**模式崩溃(Mode Collapse)**问题。
  • 采用 Wasserstein 距离 代替 JS 散度,提高训练稳定性。

WGAN-GP(带梯度惩罚的 WGAN)

  • 解决 WGAN 训练中的梯度消失问题,提高稳定性。

Conditional GAN(cGAN)

  • 让 GAN 按类别生成特定类型的图片(例如手写数字、动漫头像)。
  • 在输入中添加 类别标签,使 G 生成特定类别数据。

CycleGAN(循环一致性 GAN)

  • 无需成对数据,即可进行风格转换(如 照片风格转换黑白图像上色)。

5. 代码实现(PyTorch)

生成手写数字(DCGAN)

python 复制代码
import torch.nn as nn
import torch.optim as optim


# 生成器类定义
# 该类用于生成图像,继承自nn.Module
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 生成器模型定义,使用反卷积(转置卷积)层逐步上采样,最终生成与真实图像大小相同的输出
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),  # 输入维度为100,输出维度为512,卷积核大小为4,步长为1,padding为0
            nn.BatchNorm2d(512),  # 应用Batch Normalization
            nn.ReLU(True),  # 应用ReLU激活函数
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh()  # 应用Tanh激活函数
            # 输出维度为1,表示生成的图像像素值范围在-1到1之间
        )

    # 前向传播函数
    def forward(self, x):
        return self.model(x)


# 判别器类定义
# 该类用于判别图像真假,继承自nn.Module
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # 判别器模型定义,使用卷积层逐步下采样,最终输出一个标量概率值
        self.model = nn.Sequential(
            nn.Conv2d(1, 128, 4, 2, 1, bias=False),  # 输入维度为1,输出维度为128,卷积核大小为4,步长为2,padding为1
            nn.LeakyReLU(0.2, inplace=True),  # 应用Leaky ReLU激活函数
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),  # 输入维度为128,输出维度为256,卷积核大小为4,步长为2,padding为1
            nn.BatchNorm2d(256),  # 应用Batch Normalization
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # 应用Sigmoid激活函数
        )

    # 前向传播函数
    def forward(self, x):
        return self.model(x)


# 初始化生成器和判别器模型
G = Generator()
D = Discriminator()

# 初始化生成器和判别器的优化器,使用Adam优化算法
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 生成器优化器
optimizer_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 判别器优化器

# 打印生成器模型结构
print(G)

运行结果

python 复制代码
Generator(
  (model): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)

使用 DCGAN 训练后,可以生成逼真的手写数字!


6. GANs 的应用

  • 图像生成:动漫头像、3D 人脸建模
  • 风格转换:黑白照片上色、照片 → 画风转换(如 CycleGAN)
  • 医学影像:生成 MRI、CT 图像,提高医疗影像质量
  • 文本生成:ChatGPT、文本补全
  • 数据增强:生成样本数据,提高模型鲁棒性

7. 总结

GANs 通过 G 和 D 互相对抗,提高生成数据的质量
训练 GANs 可能存在模式崩溃、梯度消失等问题
多个 GAN 变体(DCGAN、WGAN、cGAN)解决不同任务需求
广泛应用于图像生成、风格转换、数据增强等领域

GANs 仍然是 AI 生成模型的重要技术之一,未来可能结合 Transformer 进行更多创新!

相关推荐
9命怪猫9 分钟前
DeepSeek底层揭秘——微调
人工智能·深度学习·神经网络·ai·大模型
Jackilina_Stone2 小时前
【论文阅读笔记】浅谈深度学习中的知识蒸馏 | 关系知识蒸馏 | CVPR 2019 | RKD
论文阅读·深度学习·蒸馏·rkd
倒霉蛋小马3 小时前
【YOLOv8】损失函数
深度学习·yolo·机器学习
Fansv5873 小时前
深度学习-2.机械学习基础
人工智能·经验分享·python·深度学习·算法·机器学习
AI技术控4 小时前
计算机视觉算法实战——表面缺陷检测(主页有源码)
计算机视觉
Erekys4 小时前
视觉分析之边缘检测算法
人工智能·计算机视觉·音视频
唔皇万睡万万睡5 小时前
数字水印嵌入及提取系统——基于小波变换GUI
人工智能·计算机视觉
IT古董5 小时前
【深度学习】计算机视觉(CV)-目标检测-DETR(DEtection TRansformer)—— 基于 Transformer 的端到端目标检测
深度学习·目标检测·计算机视觉
LensonYuan5 小时前
视觉目标检测之小目标检测技术调研与实验
目标检测·计算机视觉·目标跟踪
Jackilina_Stone5 小时前
【论文阅读笔记】知识蒸馏:一项调查 | CVPR 2021 | 近万字翻译+解释
论文阅读·人工智能·深度学习·蒸馏