第40周——GAN入门

目录

目录

目录

前言

一、定义超参数

二、下载数据

三、配置数据

四、定义鉴别器

五、训练模型并保存

总结


前言


一、定义超参数

python 复制代码
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
 
## 创建文件夹
os.makedirs("./images/", exist_ok=True)         # 记录训练过程的图片效果
os.makedirs("./save/", exist_ok=True)           # 训练完成时模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True)  # 下载数据集存放的位置
 
## 超参数配置
n_epochs  = 50
batch_size= 64
lr        = 0.0002
b1        = 0.5
b2        = 0.999
n_cpu     = 2
latent_dim= 100
img_size  = 28
channels  = 1
sample_interval=500
 
# 图像的尺寸:(1, 28, 28),  和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)
 
# 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

二、下载数据

python 复制代码
# mnist数据集下载
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True, transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]), 
)

三、配置数据

python 复制代码
# 配置数据到加载器
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

四、定义鉴别器

python 复制代码
# 将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
# 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),         # 输入特征数为784,输出为512
            nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射
            nn.Linear(512, 256),              # 输入特征数为512,输出为256
            nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射
            nn.Linear(256, 1),                # 输入特征数为256,输出为1
            nn.Sigmoid(),                     # sigmoid是一个激活函数,二分类问题中可将实数映射到[0, 1],作为概率值, 多分类用softmax函数
        )
 
    def forward(self, img):
        img_flat = img.view(img.size(0), -1) # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)
        validity = self.model(img_flat)      # 通过鉴别器网络
        return validity       

五、训练模型并保存

python 复制代码
## 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()
 
## 首先需要定义loss的度量方式  (二分类的交叉熵)
criterion = torch.nn.BCELoss()
 
## 其次定义 优化函数,优化函数的学习率为0.0003
## betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
 
## 如果有显卡,都在cuda模式中运行
if torch.cuda.is_available():
    generator     = generator.cuda()
    discriminator = discriminator.cuda()
    criterion     = criterion.cuda()
 
## 进行多个epoch的训练
for epoch in range(n_epochs):                   # epoch:50
    for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)
        
        ## =============================训练判别器==================
        ## view(): 相当于numpy中的reshape,重新定义矩阵的形状, 相当于reshape(128,784)  原来是(128, 1, 28, 28)
        imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)
        real_img = Variable(imgs).cuda()      # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定义真实的图片label为1
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定义假的图片的label为0
 
        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)            # 将真实图片放入判别器中
        loss_real_D = criterion(real_out, real_label) # 得到真实图片的loss
        real_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 随机生成一些噪声, 大小为(128, 100)
        fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。 
        fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片
        loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的loss
        fake_scores = fake_out                                              ## 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好
        ## 损失函数和优化
        loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失
        optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0
        loss_D.backward()                   # 将误差反向传播
        optimizer_D.step()                  # 更新参数
 
        ## -----------------
        ##  Train Generator
        ## 原理:目的是希望生成的假的图片被判别器判断为真的图片,
        ## 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
        ## 反向传播更新的参数是生成网络里面的参数,
        ## 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的
        ## -----------------
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到随机噪声
        fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片
        output = discriminator(fake_img)                                    ## 经过判别器得到的结果
        ## 损失函数和优化
        loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的loss
        optimizer_G.zero_grad()                                             ## 梯度归0
        loss_G.backward()                                                   ## 进行反向传播
        optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数
 
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
 
## 保存模型
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

总结:

本项目中,我们实现了一个基于 MNIST 数据集的生成对抗网络(GAN),主要流程从参数配置、数据准备,到模型构建与训练,最后再到结果保存,形成了一个完整的生成式模型训练管线。

首先,在超参数设置上,我们采用了经典 GAN 论文中推荐的组合:学习率设为 0.0002,Adam 优化器的 \beta_1 和 \beta_2 分别为 0.5 和 0.999,训练 50 个周期,批量大小为 64。这些参数能在保证稳定训练的同时,加快收敛速度。

数据部分选用了MNIST 手写数字集,先将像素归一化到 [-1, 1],再通过 DataLoader 按批次读取并打乱顺序。这一处理不仅保证了输入分布的稳定性,也提升了训练效率。

在模型结构方面,判别器(D)是一个多层全连接网络,将 28×28 图像展平为 784 维向量后输入,激活函数使用 LeakyReLU,最后通过 Sigmoid 得到真假概率;生成器(G)则以 100 维随机噪声为输入,经过多层全连接与 ReLU/Tanh 激活,输出与真实图像同尺寸的 28×28 结果。这样的结构简单直观,适合入门实验。

训练时,判别器与生成器交替优化:

  • 判别器的目标是最大化对真实图像的判真概率、对生成图像的判假概率;

  • 生成器的目标则是让判别器将其生成的图像判为真。

    损失函数统一采用二分类交叉熵(BCELoss),并为 D、G 分别设置优化器以避免梯度更新冲突。训练过程中会定期输出损失值并保存生成样本,方便对生成效果进行直观评估。

GAN 的核心思想,是让生成器与判别器在对抗博弈中共同提升能力:G 学会捕捉数据分布特征,D 学会分辨真实与伪造,两者在动态平衡中逼近真实分布。这种机制使得 GAN 特别适合用于图像生成、数据增强和风格迁移等任务。

从实验体验来看,这套代码的优点在于结构清晰、可视化直观、参数稳定,非常适合作为学习 GAN 的起点。同时,经过适当修改,还能扩展到更复杂的生成任务,为后续的研究打下基础。

相关推荐
老艾的AI世界6 分钟前
AI去、穿、换装软件下载,无内容限制,偷偷收藏
图像处理·人工智能·深度学习·神经网络·目标检测·机器学习·ai·换装·虚拟试衣·ai换装·一键换装
Navicat中国11 分钟前
Navicat 询问 AI | 如何转换 SQL 为另一种数据库类型
数据库·人工智能·sql·数据库开发·navicat
javgo.cn13 分钟前
Spring AI Alibaba - 聊天机器人快速上手
人工智能·ai·机器人
OpenC++40 分钟前
【机器学习】核心分类及详细介绍
人工智能·机器学习·分类
大千AI助手1 小时前
艾伦·图灵:计算理论与人工智能的奠基人
人工智能·密码学·图灵·turing·人工智能之父·计算机科学之父·图灵机
软件测试-阿涛1 小时前
【AI绘画】Stable Diffusion webUI 常用功能使用技巧
人工智能·深度学习·计算机视觉·ai作画·stable diffusion
轻流AI1 小时前
线索转化率翻3倍?AI重构CRM
大数据·人工智能·低代码·重构
skywalk81631 小时前
LLaMA Factory 是一个简单易用且高效的大型语言模型(Large Language Model)训练与微调平台。
人工智能·语言模型·自然语言处理
2301_821919922 小时前
机器学习概述(一)
人工智能·机器学习
果粒橙_LGC2 小时前
自学大语言模型之Transformer的Tokenizer
人工智能·语言模型·transformer