使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。

第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:

原理

首先简单讲一下GAN的工作原理,如下为前向传播的过程:

GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。

而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。

下图为discriminator的反向传播的过程:

discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:

下图为generator的反向传播的过程:

而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:

代码

下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。

model.py

python 复制代码
from torch import nn
import torch

def get_generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
        )
    def forward(self, noise):
        return self.gen(noise)
    def get_gen(self):
        return self.gen

def get_discriminator_block(input_dim, output_dim):
    return nn.Sequential(
         nn.Linear(input_dim, output_dim), #Layer 1
         nn.LeakyReLU(0.2, inplace=True)
    )

class Discriminator(nn.Module):
    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self, image):
        return self.disc(image)
    def get_disc(self):
        return self.disc

train.py

python 复制代码
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import Discriminator, Generator
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples,z_dim,device=device)

criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cuda'

dataloader = DataLoader(
    MNIST('./', download=True, transform=transforms.ToTensor()),  # 已经下载过的可以改为False跳过下载
    batch_size=batch_size,
    shuffle=True)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    return disc_loss

def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake)
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return gen_loss
    
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
gen_loss = False
error = False
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

        # Update gradients
        disc_loss.backward(retain_graph=True)

        # Update optimizer
        disc_opt.step()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(
                f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

运行结果

运行后每隔500个epoch画出fake和real,刚开始的fake和real是这样的:

到后面的fake逐渐变成这样:

代码解释

网络模型

model.py里面存放了generator和discriminator的网络模型,神经元使用的是简单的全连接层,后面的文章再使用卷积。

生成器输出为784 = 28 * 28,因为使用的是MINIST手写字体数据集,每张图的shape是28 * 28 * 1(黑白图单通道),所以输出的假数据要与真实数据的shape一致,这样输入鉴别器才不会出错。

生成的图片(或真实数据)直接输入鉴别器,所以鉴别器的输入也是28*28,而输出为1,即输出判别结果为真或假。

每个优化器仅优化一个模型的参数,所以一个模型构建一个优化器。

图像显示

首先将图像的tensor转到cpu上,因为PyTorch中的大部分图像处理和显示函数都是在CPU上执行的,包括我们使用的imshow。

detach() 方法将张量从计算图中分离出来,但是仍指向原变量的存放位置,不同之处只是requirse_grad为false,得到的这个tensor永远不需要计算器梯度,不具有grad,这样做的目的是避免梯度计算的影响,因为在展示图像时通常不需要计算梯度。

Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系,类似这样:

一个网络模型就是一个计算图,在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,求导过程就如上图这样。

make_grid 函数用于将多个图像组成一个网格,方便显示。

然后每500个batch显示一次当前模型性能所能生成的图片以及当前batch的真实图片(虽然一个batch设置了128张,但是我们只展示25张),以及print出生成器和鉴别器的loss。

损失函数

损失函数的原理在上面的"原理"中有讲解,这里不再赘述。

在计算鉴别器的loss里,disc_fake_pred = disc(fake.detach())是对生成图片的判别,这里也使用 .detach() 的目的是将生成器产生的假数据与生成器的参数分离,使得在计算 disc_fake_pred 时不会对生成器的梯度进行传播。这是因为在训练鉴别器的阶段,我们只希望更新鉴别器的参数,而不希望更新生成器的参数(就如上面说的生成器的训练和鉴别器的应该要隔开分别训练、交替训练)。

反向传播


retain_graph=True参数是用来指示 PyTorch 在反向传播时保留计算图。这个参数的作用是为了在一次反向传播之后保留计算图的状态,以便后续再次调用 backward() 函数时能够继续使用这个计算图进行梯度计算。

Pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:

python 复制代码
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

而在GAN中一次的迭代需要先更新鉴别器的参数,然后再更新生成器的参数;在更新生成器的参数时,我们仍然需要使用鉴别器来鉴别real or fake,只要使用到鉴别器就需要他的计算图。因此,我们需要在调用 disc_loss.backward() 后保留计算图,以便后续调用 gen_loss.backward() 时能够继续使用相同的计算图进行梯度计算。而对于生成器的梯度更新 gen_loss.backward(),不需要显式指定 retain_graph=True。

所以,在同一个计算图上多次调用 backward() 函数时才需要使用它。

下一篇构建DCGAN

相关推荐
矢量赛奇19 分钟前
比ChatGPT更酷的AI工具
人工智能·ai·ai写作·视频
KuaFuAI28 分钟前
微软推出的AI无代码编程微应用平台GitHub Spark和国产AI原生无代码工具CodeFlying比到底咋样?
人工智能·github·aigc·ai编程·codeflying·github spark·自然语言开发软件
Make_magic37 分钟前
Git学习教程(更新中)
大数据·人工智能·git·elasticsearch·计算机视觉
shelly聊AI41 分钟前
语音识别原理:AI 是如何听懂人类声音的
人工智能·语音识别
源于花海44 分钟前
论文学习(四) | 基于数据驱动的锂离子电池健康状态估计和剩余使用寿命预测
论文阅读·人工智能·学习·论文笔记
雷龙发展:Leah1 小时前
离线语音识别自定义功能怎么用?
人工智能·音频·语音识别·信号处理·模块测试
4v1d1 小时前
边缘计算的学习
人工智能·学习·边缘计算
风之馨技术录1 小时前
智谱AI清影升级:引领AI视频进入音效新时代
人工智能·音视频
L Jiawen1 小时前
【Python · PyTorch】卷积神经网络(基础概念)
pytorch·python·cnn
sniper_fandc1 小时前
深度学习基础—Seq2Seq模型
人工智能·深度学习