【第五章:计算机视觉-项目实战之生成对抗网络实战】2.基于SRGAN的图像超分辨率实战-(2)实战1:DCGAN模型搭建

第五章:计算机视觉(Computer Vision)- 项目实战之生成对抗网络实战

第二部分:基于SRGAN的图像超分辨率实战

第二节:实战1:DCGAN模型搭建

在前面的文章里,我们已经介绍了生成对抗网络(GAN)的基本思想。今天我们要进入 GAN 家族中最经典的改进版 ------ DCGAN(Deep Convolutional GAN)。它是最早把卷积引入到 GAN 体系中的模型,大幅提升了图像生成的质量和稳定性。

本文将从 原理 → 模型结构 → PyTorch 代码实现 → 训练要点 四个方面,带你快速上手 DCGAN,为后续 超分辨率任务(SRGAN) 打好基础。


1. DCGAN 简介

DCGAN 全称 Deep Convolutional Generative Adversarial Network,由 Radford 等人于 2015 年提出。它的目标依然是让生成器(Generator, G)和判别器(Discriminator, D)互相博弈,但在架构上有了关键改进:

  • 使用 卷积层 代替全连接层,更好地捕捉图像空间特征;

  • 使用 转置卷积(ConvTranspose2d) 在生成器中逐步上采样,输出更大分辨率的图像;

  • 引入 Batch Normalization,让训练更加稳定;

  • 激活函数上:生成器使用 ReLU,判别器使用 LeakyReLU。

这些改进让 DCGAN 成为最早能够生成较为清晰人脸和自然图像的 GAN 模型之一。


2. DCGAN 的模型结构

生成器(Generator)

输入:一个随机噪声向量 z(一般服从正态分布)。

输出:一张合成图像(例如 64×64 的 RGB 图像)。

结构:

  1. 全连接层将 z 投影到高维特征空间;

  2. 通过一系列 转置卷积层(反卷积),逐步将特征图上采样;

  3. 每一层都配合 BatchNorm + ReLU 稳定训练;

  4. 最后一层用 Tanh,把输出限制到 [-1, 1]。

判别器(Discriminator)

输入:真实图像 or 生成图像。

输出:真假概率(0 或 1)。

结构:

  1. 多层卷积逐步提取特征;

  2. 使用 LeakyReLU 激活防止梯度消失;

  3. 在最后一层使用 Sigmoid 输出真假概率。


3. PyTorch 代码实现

下面是 DCGAN 的核心代码,你可以直接运行在 CIFAR-10 或 CelebA 数据集上。

python 复制代码
import torch
import torch.nn as nn

# 生成器
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_channels=3, feature_g=64):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # 输入 Z → 4x4 特征图
            nn.ConvTranspose2d(z_dim, feature_g*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_g*8),
            nn.ReLU(True),
            
            # 4x4 → 8x8
            nn.ConvTranspose2d(feature_g*8, feature_g*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g*4),
            nn.ReLU(True),
            
            # 8x8 → 16x16
            nn.ConvTranspose2d(feature_g*4, feature_g*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g*2),
            nn.ReLU(True),
            
            # 16x16 → 32x32
            nn.ConvTranspose2d(feature_g*2, feature_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g),
            nn.ReLU(True),
            
            # 32x32 → 64x64
            nn.ConvTranspose2d(feature_g, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)


# 判别器
class Discriminator(nn.Module):
    def __init__(self, img_channels=3, feature_d=64):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels, feature_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_d, feature_d*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_d*2, feature_d*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_d*4, feature_d*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d*8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_d*8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).view(-1, 1).squeeze(1)

4. 训练流程

  1. 损失函数

    • 判别器:最大化真实图像判为真,最小化生成图像判为真;

    • 生成器:最小化生成图像被判为假的概率;

    • 损失函数一般使用 BCE Loss

  2. 优化器

    • Adam(学习率 0.0002,β1=0.5,β2=0.999)。
  3. 训练步骤

    • Step 1:用真实图像训练判别器 D(标签设为 1);

    • Step 2:用生成图像训练判别器 D(标签设为 0);

    • Step 3:更新生成器 G,使得判别器把生成图像判为真(标签设为 1);

    • Step 4:循环迭代。


5. 效果展示

当训练收敛后,DCGAN 可以生成较为清晰的 人脸图像手写数字,甚至是风格化图片。相比原始 GAN,它收敛更稳定、生成质量更高。

训练过程中的效果大概如下:

  • 初始阶段:图像接近随机噪声;

  • 中期阶段:逐渐出现轮廓和形状;

  • 后期阶段:细节逐渐清晰,可见真实图像结构。


6. 总结

本文我们完成了 DCGAN 的模型搭建与 PyTorch 实现,并理解了它在 GAN 家族中的重要地位:

  • 卷积 + 反卷积 结构,让 GAN 能够生成高分辨率图像;

  • BatchNorm 和 LeakyReLU 稳定了训练过程;

  • 提供了基础架构,成为 SRGAN、StyleGAN、CycleGAN 等后续 GAN 模型的基石。

相关推荐
yourkin6662 小时前
李宏毅-Generative AI-第一课
人工智能
大模型真好玩3 小时前
大模型Agent开发框架哪家强?12项Agent开发框架入门与选型
人工智能·agent·mcp
常州晟凯电子科技3 小时前
君正T32开发笔记之IVSP版本环境搭建和编译
人工智能·笔记·物联网
Francek Chen3 小时前
【深度学习计算机视觉】09:语义分割和数据集
人工智能·pytorch·深度学习·计算机视觉·数据集·语义分割
sealaugh323 小时前
AI(学习笔记第九课) 使用langchain的MultiQueryRetriever和indexing
人工智能·笔记·学习
OopsOutOfMemory3 小时前
LangChain源码分析(一)- LLM大语言模型
人工智能·语言模型·langchain·aigc
ASIAZXO3 小时前
机器学习——SVM支持向量机详解
人工智能·机器学习·支持向量机
Prettybritany4 小时前
文本引导的图像融合方法
论文阅读·图像处理·人工智能·深度学习·计算机视觉
weixin_456904274 小时前
OpenCV 摄像头参数控制详解
人工智能·opencv·计算机视觉