【零基础学AI】第30讲:生成对抗网络(GAN)实战 - 手写数字生成

本节课你将学到

  • GAN的基本原理和工作机制
  • 使用PyTorch构建生成器和判别器
  • DCGAN架构实现技巧
  • 训练GAN模型的实用技巧

开始之前

环境要求

  • Python 3.8+

  • 需要安装的包:

    bash 复制代码
    pip install torch torchvision matplotlib numpy
  • GPU推荐(可大幅加速训练)

前置知识

  • 第21讲TensorFlow基础
  • 第23讲神经网络原理
  • 基本PyTorch使用经验

核心概念

什么是GAN?

GAN就像艺术品鉴定师与伪造者的博弈:

  1. 生成器(Generator):伪造者

    • 试图创作逼真的假画作
    • 从随机噪声开始,逐渐改进技术
  2. 判别器(Discriminator):鉴定师

    • 试图区分真品和赝品
    • 随着伪造者技术进步,鉴定能力也提升

GAN训练过程

随机噪声 更新生成器 假图像 真实图像 更新判别器 真/假判断

DCGAN架构特点

  • 生成器:使用转置卷积上采样
  • 判别器:使用卷积下采样
  • 去除全连接层
  • 使用Batch Normalization
  • LeakyReLU激活函数

代码实战

1. 导入库与配置

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 超参数
lr = 0.0002
batch_size = 64
image_size = 64
channels_img = 1
z_dim = 100  # 噪声向量维度
num_epochs = 20

2. 准备数据集(MNIST)

python 复制代码
# 数据预处理
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # 归一化到[-1, 1]
])

# 下载并加载MNIST数据集
dataset = datasets.MNIST(
    root="dataset/", 
    train=True, 
    transform=transform, 
    download=True
)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

3. 构建生成器

python 复制代码
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # 输入: z_dim x 1 x 1
            self._block(z_dim, features_g * 16, 4, 1, 0),  # 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # 32x32
            nn.ConvTranspose2d(features_g * 2, channels_img, 4, 2, 1),  # 64x64
            nn.Tanh(),  # 输出值在[-1,1]之间
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.gen(x)

4. 构建判别器

python 复制代码
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # 输入: channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, 4, 2, 1),  # 32x32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d * 2, 4, 2, 1),  # 16x16
            self._block(features_d * 2, features_d * 4, 4, 2, 1),  # 8x8
            self._block(features_d * 4, features_d * 8, 4, 2, 1),  # 4x4
            nn.Conv2d(features_d * 8, 1, 4, 2, 0),  # 1x1
            nn.Sigmoid(),  # 输出概率值
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

5. 初始化模型与优化器

python 复制代码
# 初始化模型
gen = Generator(z_dim, channels_img, 64).to(device)
disc = Discriminator(channels_img, 64).to(device)

# 初始化权重
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)

gen.apply(weights_init)
disc.apply(weights_init)

# 优化器
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

# 损失函数
criterion = nn.BCELoss()

6. 训练循环

python 复制代码
# 固定噪声用于可视化训练进展
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        batch_size = real.shape[0]
        
        # 训练判别器
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
        fake = gen(noise)
        
        # 判别器对真实图像的预测
        disc_real = disc(real).view(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        
        # 判别器对生成图像的预测
        disc_fake = disc(fake.detach()).view(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        
        # 判别器总损失
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()
        
        # 训练生成器
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        
        # 每100个batch打印一次
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )
    
    # 每轮保存生成图像示例
    with torch.no_grad():
        fake = gen(fixed_noise)
        img_grid = torchvision.utils.make_grid(fake[:32], normalize=True)
        plt.imshow(np.transpose(img_grid.cpu(), (1, 2, 0)))
        plt.axis("off")
        plt.savefig(f"output/epoch_{epoch}.png")
        plt.close()

完整项目

项目结构:

复制代码
dcgan_mnist/
├── train.py              # 训练脚本
├── models.py             # 模型定义
├── utils.py              # 辅助函数
├── requirements.txt
├── output/               # 输出目录
└── README.md

requirements.txt内容:

复制代码
torch>=2.0.0
torchvision>=0.15.0
matplotlib>=3.7.0
numpy>=1.24.0

运行效果

训练过程可视化

复制代码
Epoch [0/20] Batch 0/938 Loss D: 0.6931, loss G: 0.6978
Epoch [0/20] Batch 100/938 Loss D: 0.0123, loss G: 5.4321
...
Epoch [20/20] Batch 800/938 Loss D: 0.5123, loss G: 1.2314

常见问题

Q1: 模式崩溃(Mode Collapse)怎么办?

解决方案:

  • 增加批次大小(batch_size)
  • 使用Wasserstein GAN(WGAN)
  • 尝试不同的学习率
  • 添加多样性惩罚项

Q2: 判别器太强/太弱

平衡技巧:

  • 太强:减少判别器更新频率
  • 太弱:增加判别器容量
  • 使用TTUR(Two Time-scale Update Rule)

Q3: 训练不稳定

稳定训练方法:

  • 使用梯度惩罚(Gradient Penalty)
  • 尝试不同的优化器参数
  • 使用谱归一化(Spectral Normalization)
  • 添加标签平滑(Label Smoothing)

课后练习

  • 尝试生成彩色图像(如CIFAR-10)
  • 实现条件GAN(生成指定数字)
  • 尝试不同的GAN架构(如WGAN-GP)
  • 使用GAN生成人脸(CelebA数据集)

扩展阅读

相关推荐
天翼云开发者社区10 分钟前
五项满分,天翼云息壤智算一体机斩获佳绩!
人工智能·ai训练一体机
程序员果子12 分钟前
macOS Python 安装
python·macos
F_D_Z15 分钟前
【感知机】感知机(perceptron)学习算法的对偶形式
人工智能·学习·算法·支持向量机
POLOAPI19 分钟前
Claude Opus:从智能升级到场景落地的旗舰模型进阶之路
人工智能·ai编程·claude
Gyoku Mint21 分钟前
自然语言处理×第四卷:文本特征与数据——她开始准备:每一次输入,都是为了更像你地说话
人工智能·pytorch·神经网络·语言模型·自然语言处理·数据分析·nlp
LetsonH26 分钟前
⭐CVPR2025 RoboBrain:机器人操作的统一大脑模型[特殊字符]
人工智能·python·深度学习·计算机视觉·机器人
会思考的石头1 小时前
看完了 GPT5 发布会,有一些新的思考
人工智能
后端小肥肠1 小时前
扣子 (Coze) 实战:输入一个主题,对标博主风格神还原,小红书爆款图文一键直出
人工智能·aigc·coze
站大爷IP2 小时前
Django缓存机制详解:从配置到实战应用
python
叫我:松哥2 小时前
基于Python的实习僧招聘数据采集与可视化分析,使用matplotlib进行可视化
开发语言·数据库·python·课程设计·matplotlib·文本挖掘