深度学习--生成对抗网络GAN

GAN简介

让我们先来简单了解一下GAN

GAN的全称是Generative Adversarial Networks,中文称为"生成对抗网络",是一种在深度学习领域广泛使用的无监督学习方法。

GAN主要由两部分组成:生成器判别器。生成器的目标是尽可能地生成真实的样本数据,而判别器的目标是尽可能准确地辨别出生成样本与真实样本。这两个组件通过竞争和对抗的方式共同工作,以提升各自的能力。这种网络结构能够处理没有标注数据的问题,并且在图像处理、自然语言处理等多个领域都有广泛应用。

它通过对抗生成 来训练,目的是估测数据样本的潜在分布并生成新的数据样本。

GAN结构图

原理

生成器根据噪声,也就是随机值,来生成样本,而判别器判断哪些是真实数据,哪些是生成数据,然后将学习的经验反向传播给生成器,让生成器生成的样本不断向真实样本靠拢。

在训练过程中,生成器努力让生成的数据更加真实 ,而判别器努力的去判别数据的真假 ,二者·形成了对抗。最终两个网络形成了动态平衡,生成样本接近真实样本,而判别器也分辨不出来样本的真假,最终对给定图像预测为真的概率基本接近0.5,也就相当于随即猜测类别了。

公式

在公式中,

z代表输入G网络的噪声,

x代表真实图片

G(z)表示G网络生成的图片,

D(*)表示D网络判断图片是否真实的概率

2.GAN的算法流程和公式详解_哔哩哔哩_bilibili

在这个视频里有对这个公式的详解,这里就不详细说了/

我们经过简单了解之后,就要开始搭建GAN网络了,这里我们以手写字体识别数据集为例。

构建GAN网络的步骤

GAN生成对抗网络,步骤:

首先编写生成器和判别器

然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1

接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近1为目的优化我们的生成器

生成器的代码(针对手写字体识别)

预备知识

transforms.Normalize

transforms.Normalize()函数用于对图像数据进行【标准化】处理 。在深度学习中,数据标准化是一个常见的预处理步骤,它有助于模型更快地收敛,并提高模型的性能

作用

数据标准化:如上所述,transforms.Normalize()函数可以对图像数据进行标准化处理,使数据分布符合标准正态分布。这有助于模型更快地收敛,并提高模型的性能。

提高模型泛化能力:通过对数据进行标准化,我们可以减少模型对特定数据集的过拟合,从而提高模型在未见过的数据上的泛化能力。

加速模型训练:标准化的数据可以使模型在训练过程中更快地学习到数据的特征,从而加速模型的训练速度。

参数
  • mean:(list)长度与输入的通道数相同,代表每个通道上所有数值的平均值
  • std:(list)长度与输入的通道数相同,代表每个通道上所有数值的标准差

Datadoder

参数

dataset(数据集):需要提取数据的数据集,Dataset对象

batch_size(批大小):每一次装载样本的个数,int型

shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型

num_workers:是否多进程读取机制

drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

LeakyReLu函数

图像及参数

我们可以与ReLu函数对比,看一下区别:

主要区别就是在小于0的部分了

代码

导库

python 复制代码
import matplotlib.pyplot as plt
import matplotlib 
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np

数据集处理

python 复制代码
transform = transforms.Compose([
transforms.ToTensor(), 
transforms.Normalize(0.5, 0.5)
])
traindata = torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集', train=True, download=True,
                                       transform=transform)  # 训练集60,000张用于训练

在加载数据集时,我们要将数据进行归一化,在GAN中,我们就需要将数据归一化到(-1,1)之间,这是为什么呢?原因是我们在下面会用到Tanh激活函数,而Tanh函数的范围是在-1到1之间的,见下图:

在我们既然知道了为什么要这样,下面就要学会如何做到了

ToTensor 中,我们是将数据的范围限制在了**(0,1)** 之间,而后面的Normalize 是将数据限制在**(-1,1)** 之间,计算公式为**(x-均值)/方差**

生成器

python 复制代码
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = torch.nn.Sequential(
            torch.nn.Linear(100, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 28 * 28),
            torch.nn.Tanh()
        )

    def forward(self, x):
        img = self.main(x)
        img = img.reshape(-1, 28, 28)
        return img

在这里,我们需要知道,生成器的输入和输出是什么,输入时我们的噪音,而输出一张图片。

在后向传播中,我们最后再将图片进行展平。

判别器

python 复制代码
class Discraiminator(torch.nn.Module):
    def __init__(self):
        super(Discraiminator, self).__init__()
        self.mainf = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 512),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(256, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.mainf(x)
        return x

我们同样需要了解判别器的输入和输出,输入是一张(1,28,28)图片,输出为二分类的概率值。

在判别器中,我们如果使用ReLu函数,在小于0的部分就会出现梯度消失的问题,这时候我们就可以用到LeadkyReLu了,它能够优化GAN的训练。

最后的Sigmoid激活函数,将输出压缩到0到1之间,这通常用于二分类问题,但在这里,它用于表示输入是真实数据的概率。

而在后向传播中,我们需要先对图片进行展平。

定义损失函数,优化函数和优化器

python 复制代码
# 定义损失函数和优化函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discraiminator().to(device)
# 定义优化器
gen_opt = torch.optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = torch.optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()  # 损失函数

在这里,我们选择使用BCELoss,交叉熵损失函数,这是因为在GAN中,判别器通常被视为一个二分类器 ,它试图区分输入是真实样本还是由生成器生成的假样本,而BCELoss就是用来做二分类的损失函数,正好对应。

在优化器部分,它们分别对生成器和判别器的参数进行优化。

图像显示

python 复制代码
def gen_img_plot(model, testdata):
    pre = np.squeeze(model(testdata).detach().cpu().numpy())
    # tensor.detach()
    # 返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
    # 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
    # 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播
    plt.figure()
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(pre[i])
    plt.show()

因为我们最终要得到要得到的是处理数据输出的数组,所以我们要用squeeze将额外的单维度删除。

detach是单独开辟空间来保存数据,从而保证数据的稳定性。

plt.figure用来生成一个新画布。

使用subplot函数在一个4x4的网格中定位每个子图。i + 1是因为子图的索引是从1开始的,而不是从0开始。

imshow是在子图中显示图像。

最后的show来显示整体的图像。

后向传播与训练模型

python 复制代码
dis_loss = []  # 判别器损失值记录
gen_loss = []  # 生成器损失值记录
lun = []  # 轮数
for epoch in range(60):
    d_epoch_loss = 0
    g_epoch_loss = 0
    cout = len(trainload)  # 938批次
    for step, (img, _) in enumerate(trainload):
        img = img.to(device)  # 图像数据
        # print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])
        size = img.size(0)  # 一批次的图片数量64
        # 随机生成一批次的100维向量样本,或者说100个像素点
        random_noise = torch.randn(size, 100, device=device)

        # 判断器的后向传播
        dis_opt.zero_grad()
        real_output = dis(img)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实数据的损失函数值
        d_real_loss.backward()

        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 人造的数据的损失函数值
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        dis_opt.step()

        # 生成器的后向传播
        gen_opt.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        gen_opt.step()
        d_epoch_loss += d_loss
        g_epoch_loss += g_loss
    dis_loss.append(float(d_epoch_loss))
    gen_loss.append(float(g_epoch_loss))
    print(f'第{epoch + 1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')
    lun.append(epoch + 1)

使用enumerate遍历训练数据集trainload,其中img是图像数据,但_表示我们在这里不使用标签(因为GAN是无监督的)。

step()用来更新判别器的模型参数。

在生成器的后向传播部分,

我们先进行梯度清零,然后通过生成器生成假图像,然后进行前向传播。

我们期望判别器对假图像的评分接近1(真实),因此我们将目标标签设置为与fake_output形状相同的全1张量torch.ones_like(fake_output)

在这里,d_loss和g_loss是一张图像中的损失值,而d_epoch_loss和g_epoch_loss是每一轮损失值的累加,用于最后图像的绘制。

生成图像

python 复制代码
matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
plt.figure()
plt.plot(lun, dis_loss, 'r', label='判别器损失值')
plt.plot(lun, gen_loss, 'b', label='生成器损失值')
plt.xlabel('训练轮数', fontsize=12)
plt.ylabel('损失值', fontsize=12)
plt.title('损失值随着训练轮数得变化情况:',  fontsize=18)
plt.legend()
plt.show()
random_noise = torch.randn(16, 100, device=device)
gen_img_plot(gen, random_noise)

随机生成的噪声有16个样本,100个维度

全部代码

python 复制代码
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np

# 导入数据集并且进行数据处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])
traindata = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                       transform=transform)  # 训练集60,000张用于训练
# 利用DataLoader加载数据集
trainload = DataLoader(dataset=traindata, shuffle=True, batch_size=64)


# GAN生成对抗网络,步骤:
# 首先编写生成器和判别器
# 然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
# 接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近一为目的优化我们的生成器
# 生成器的代码(针对手写字体识别)
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = torch.nn.Sequential(
            torch.nn.Linear(100, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 28 * 28),
            torch.nn.Tanh()
        )

    def forward(self, x):
        img = self.main(x)
        img = img.reshape(-1, 28, 28)
        return img


# 判别器,最后判断0,1,这意味着最后可以是一个神经元或者两个神经元
class Discraiminator(torch.nn.Module):
    def __init__(self):
        super(Discraiminator, self).__init__()
        self.mainf = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 512),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(256, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.mainf(x)
        return x


# 定义损失函数和优化函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discraiminator().to(device)
# 定义优化器
gen_opt = torch.optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = torch.optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()  # 损失函数


def gen_img_plot(model, testdata):
    pre = np.squeeze(model(testdata).detach().cpu().numpy())
    # tensor.detach()
    # 返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
    # 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
    # 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播
    plt.figure()
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(pre[i])
    plt.show()


# 后向传播
dis_loss = []  # 判别器损失值记录
gen_loss = []  # 生成器损失值记录
lun = []  # 轮数
for epoch in range(60):
    d_epoch_loss = 0
    g_epoch_loss = 0
    cout = len(trainload)  # 938批次
    for step, (img, _) in enumerate(trainload):
        img = img.to(device)  # 图像数据
        # print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])
        size = img.size(0)  # 一批次的图片数量64
        # 随机生成一批次的100维向量样本,或者说100个像素点
        random_noise = torch.randn(size, 100, device=device)

        # 判断器的后向传播
        dis_opt.zero_grad()
        real_output = dis(img)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实数据的损失函数值
        d_real_loss.backward()

        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 人造的数据的损失函数值
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        dis_opt.step()

        # 生成器的后向传播
        gen_opt.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        gen_opt.step()
        d_epoch_loss += d_loss
        g_epoch_loss += g_loss
    dis_loss.append(float(d_epoch_loss))
    gen_loss.append(float(g_epoch_loss))
    print(f'第{epoch + 1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')
    lun.append(epoch + 1)

matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
plt.figure()
plt.plot(lun, dis_loss, 'r', label='判别器损失值')
plt.plot(lun, gen_loss, 'b', label='生成器损失值')
plt.xlabel('训练轮数', fontsize=12)
plt.ylabel('损失值', fontsize=12)
plt.title('损失值随着训练轮数得变化情况:',  fontsize=18)
plt.legend()
plt.show()
random_noise = torch.randn(16, 100, device=device)
gen_img_plot(gen, random_noise)

运行结果

python 复制代码
第1轮的生成器损失值:2226.86328125,判别器损失值461.5265808105469
第2轮的生成器损失值:2378.969970703125,判别器损失值459.3701477050781
第3轮的生成器损失值:2422.438232421875,判别器损失值355.0154113769531
第4轮的生成器损失值:3410.994873046875,判别器损失值172.3834686279297
第5轮的生成器损失值:3589.7734375,判别器损失值168.7844696044922
第6轮的生成器损失值:3944.258544921875,判别器损失值125.10688781738281
第7轮的生成器损失值:4293.7861328125,判别器损失值138.3419952392578
第8轮的生成器损失值:4436.89404296875,判别器损失值159.64407348632812
第9轮的生成器损失值:4485.7646484375,判别器损失值177.5517578125
第10轮的生成器损失值:4136.85986328125,判别器损失值210.64602661132812
第11轮的生成器损失值:4072.7958984375,判别器损失值246.29910278320312
第12轮的生成器损失值:4298.8623046875,判别器损失值183.00152587890625
第13轮的生成器损失值:4899.4794921875,判别器损失值171.33628845214844
第14轮的生成器损失值:4851.458984375,判别器损失值161.920654296875
第15轮的生成器损失值:4995.62646484375,判别器损失值155.28732299804688
第16轮的生成器损失值:4987.4140625,判别器损失值142.6618194580078
第17轮的生成器损失值:5511.90673828125,判别器损失值126.41560363769531
第18轮的生成器损失值:5509.65771484375,判别器损失值157.1754913330078
第19轮的生成器损失值:5164.8671875,判别器损失值143.5445556640625
第20轮的生成器损失值:5490.17236328125,判别器损失值156.86929321289062
第21轮的生成器损失值:5189.4921875,判别器损失值177.5731201171875
第22轮的生成器损失值:5293.32080078125,判别器损失值168.159912109375
第23轮的生成器损失值:4971.2646484375,判别器损失值189.78167724609375
第24轮的生成器损失值:4590.87158203125,判别器损失值211.07289123535156
第25轮的生成器损失值:4739.5732421875,判别器损失值214.7382354736328
第26轮的生成器损失值:4700.568359375,判别器损失值218.89926147460938
第27轮的生成器损失值:4146.5048828125,判别器损失值269.0607604980469
第28轮的生成器损失值:3846.898681640625,判别器损失值287.00604248046875
第29轮的生成器损失值:3559.870361328125,判别器损失值317.5647888183594
第30轮的生成器损失值:3378.71240234375,判别器损失值336.30572509765625
第31轮的生成器损失值:4269.37060546875,判别器损失值257.89910888671875
第32轮的生成器损失值:5209.896484375,判别器损失值191.99989318847656
第33轮的生成器损失值:4632.1728515625,判别器损失值261.9479064941406
第34轮的生成器损失值:2979.66015625,判别器损失值363.874267578125
第35轮的生成器损失值:2710.74462890625,判别器损失值405.0263671875
第36轮的生成器损失值:2661.800048828125,判别器损失值421.5466613769531
第37轮的生成器损失值:2625.377197265625,判别器损失值414.751708984375
第38轮的生成器损失值:2809.101318359375,判别器损失值399.09942626953125
第39轮的生成器损失值:3797.715087890625,判别器损失值314.6676025390625
第40轮的生成器损失值:6223.8974609375,判别器损失值151.0428924560547
第41轮的生成器损失值:3305.96533203125,判别器损失值355.9456481933594
第42轮的生成器损失值:2672.400634765625,判别器损失值395.23834228515625
第43轮的生成器损失值:2538.265625,判别器损失值425.629638671875
第44轮的生成器损失值:2496.415283203125,判别器损失值443.06085205078125
第45轮的生成器损失值:2451.716796875,判别器损失值449.18194580078125
第46轮的生成器损失值:2397.526123046875,判别器损失值467.0350341796875
第47轮的生成器损失值:2427.2900390625,判别器损失值459.0263977050781
第48轮的生成器损失值:2440.54736328125,判别器损失值469.6186218261719
第49轮的生成器损失值:2597.76953125,判别器损失值439.3223876953125
第50轮的生成器损失值:2724.003173828125,判别器损失值438.4668273925781
第51轮的生成器损失值:2539.636474609375,判别器损失值459.2343444824219
第52轮的生成器损失值:2288.4130859375,判别器损失值498.2747802734375
第53轮的生成器损失值:2244.51513671875,判别器损失值506.4640197753906
第54轮的生成器损失值:2242.865478515625,判别器损失值502.57275390625
第55轮的生成器损失值:2198.66552734375,判别器损失值506.5917053222656
第56轮的生成器损失值:2217.268310546875,判别器损失值502.77081298828125
第57轮的生成器损失值:2246.22802734375,判别器损失值502.93206787109375
第58轮的生成器损失值:2165.259033203125,判别器损失值516.4965209960938
第59轮的生成器损失值:2146.760009765625,判别器损失值519.462890625
第60轮的生成器损失值:2110.582763671875,判别器损失值528.8636474609375

进程已结束,退出代码为 0

我们得生成器损失值是波动的,判别器损失值也是,很难说他们的趋势走向(当然估计和我的训练轮数有关)

这是我们生成器生成的"伪造的图片",从这里可以看出来已经很不错了。

相关推荐
YSGZJJ31 分钟前
股指期货的套保策略如何精准选择和规避风险?
人工智能·区块链
无脑敲代码,bug漫天飞33 分钟前
COR 损失函数
人工智能·机器学习
HPC_fac130520678161 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
小陈phd4 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao5 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
ZHOU_WUYI9 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1239 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界9 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221519 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot25110 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台