学习笔记:Pytorch利用MNIST数据集训练生成对抗网络(GAN)

2023.8.27

在进行深度学习的进阶的时候,我发了生成对抗网络是一个很神奇的东西,为什么它可以"将一堆随机噪声经过生成器变成一张图片",特此记录一下学习心得。

一、生成对抗网络百科

2014年,还在蒙特利尔读博士的Ian Goodfellow发表了论 文《Generative Adversarial Networks》(网址: https://arxiv.org/abs/1406.2661),将生成对抗网络引入 深度学习领域。2016年,GAN热潮席卷AI领域顶级会议, 从ICLR到NIPS,大量高质量论文被发表和探讨。Yann LeCun曾评价GAN是"20年来机器学习领域最酷的想法"。

机器学习的模型可大体分为两类,生成模型( Generative Model)和判别模型(Discriminative Model)。判别模型需要输入变量 ,通过某种模型来 预测 。生成模型是给定某种隐含信息,来随机产生观 测数据。

GAN百科:

GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)_打灰人的博客-CSDN博客

二、GAN代码

训练代码:

epoch=1000时的效果就不错啦

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


class Generator(nn.Module):  # 生成器
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


class Discriminator(nn.Module):  # 判别器
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img = img.view(img.size(0), -1)
        validity = self.model(img)
        return validity


def gen_img_plot(model, test_input):
    pred = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((pred[i] + 1) / 2)
        plt.axis('off')
    plt.show(block=False)
    plt.pause(3)  # 停留0.5s
    plt.close()


# 调用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 超参数设置
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 1000

# 数据集载入和数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 训练数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 测试数据 torch.randn()函数的作用是生成一组均值为0,方差为1(即标准正态分布)的随机数
# test_data = torch.randn(batch_size, latent_dim).to(device)
test_data = torch.FloatTensor(batch_size, latent_dim).to(device)

# 实例化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 开始训练模型
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        batch_size = imgs.shape[0]
        real_imgs = imgs.to(device)

        # 训练判别器
        z = torch.FloatTensor(batch_size, latent_dim).to(device)
        z.data.normal_(0, 1)
        fake_imgs = generator(z)  # 生成器生成假的图片

        real_labels = torch.full((batch_size, 1), 1.0).to(device)
        fake_labels = torch.full((batch_size, 1), 0.0).to(device)

        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        z.data.normal_(0, 1)
        fake_imgs = generator(z)

        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        torch.save(generator.state_dict(), "Generator_mnist.pth")

    print(f"Epoch [{epoch}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")

# gen_img_plot(Generator, test_data)
gen_img_plot(generator, test_data)

测试代码:

python 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')


class Generator(nn.Module):  # 生成器
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


# test_data = torch.FloatTensor(128, 100).to(device)
test_data = torch.randn(128, 100).to(device)  # 随机噪声

model = Generator(100).to(device)
model.load_state_dict(torch.load('Generator_mnist.pth'))
model.eval()

pred = np.squeeze(model(test_data).detach().cpu().numpy())

for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow((pred[i] + 1) / 2)
    plt.axis('off')
plt.savefig(fname='image.png', figsize=[5, 5])
plt.show()

三、结果

在超参数设置 epoch=1000,batch_size=128,lr=0.0001,latent_dim = 100 时,gan生成的权重测的结果如图所示

四,GAN的损失函数曲线

一开始训练时,我的gan的损失函数的曲线是类似这样的,就是知乎这文章里一样,生成器损失函数的曲线一直发散。首先,这个loss的曲线一看就是网络崩了,一般正常的情况,d_loss的值会一直下降然后收敛,而g_loss的曲线会先增大后减少,最后同样也会收敛。其次,网络拿到手以后先不要训练太多次,容易出现过拟合的情况。

生成对抗网络的损失函数图像如下合理吗? - 知乎

这是训练了10轮的生成器和鉴别器的损失函数值变化吧:

效果如图所示:

相关推荐
算力云9 分钟前
深度剖析!GPT-image-1 API 开放对 AI 绘画技术生态的冲击!
人工智能·openai图像生成模型·gpt-image-1
孤寂码农_defector13 分钟前
AI 人工智能模型:从理论到实践的深度解析⚡YQW · Studio ⚡【Deepseek】【Chat GPT】
人工智能
北上ing21 分钟前
从FP32到BF16,再到混合精度的全景解析
人工智能·pytorch·深度学习·计算机视觉·stable diffusion
小奕同学A27 分钟前
数字化技术的五个环节:大数据、云计算、人工智能、区块链、移动互联网
大数据·人工智能·云计算
Eric.Lee202130 分钟前
数据集-目标检测系列- F35 战斗机 检测数据集 F35 plane >> DataBall
人工智能·算法·yolo·目标检测·计算机视觉
白熊18832 分钟前
【计算机视觉】CV实践- 基于PaddleSeg的遥感建筑变化检测全解析:从U-Net 3+原理到工程实践
人工智能·计算机视觉
蔗理苦37 分钟前
2025-04-24 Python&深度学习4—— 计算图与动态图机制
开发语言·pytorch·python·深度学习·计算图
Gsen281941 分钟前
AI大模型从0到1记录学习 数据结构和算法 day20
数据结构·学习·算法·生成对抗网络·目标跟踪·语言模型·知识图谱
cmoaciopm2 小时前
Obsidian和Ollama大语言模型的交互过程
人工智能·语言模型
努力进修2 小时前
【金仓数据库征文】-金仓数据库性能调优 “快准稳” 攻略:实战优化,让数据处理飞起来
数据库·人工智能·金仓数据库 2025 征文·数据库平替用金仓