Pytorch 第十五回:神经网络编码器——GAN生成对抗网络

Pytorch 第十五回:神经网络编码器------GAN生成对抗网络

本次开启深度学习第十五回,基于Pytorch的神经网络编码器。本回分享的是GAN生成对抗网络。在本回中,通过minist数据集来分享如何建立一个GAN生成对抗网络。接下来给大家分享具体思路。

本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0


前言

讲述模型前,先讲述两个概念,统一下思路:

1、GAN网络

生成对抗网络(Generative Adversarial Network,GAN)是一种通过"对抗训练"生成与真实数据相似数据的深度学习模型。它的核心思想是让两个神经网络,即生成器和判别器之间相互博弈,最终生成以假乱真的数据。

1)生成器

生成器负责‌从随机噪声中生成逼真的假数据‌,其本质是通过学习真实数据的分布规律,创造出与真实数据高度相似的新样本。

2‌)判别器

判别器的任务是‌区分真实数据与生成数据的差异,并通过反馈优化生成器‌。

注:

生成器与判别器的目标函数对立,二者通过对抗训练实现数据生成与优化的动态平衡。即:

生成器的优化方向‌:最小化判别器对生成数据的判别准确率(即让判别器误判生成数据为真)‌。

‌判别器的优化方向‌:最大化对真实数据与生成数据的正确分类概率‌。

闲言少叙,直接展示逻辑,先上引用:

复制代码
import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import time

一、数据准备

首先准备MNIST数据集,并进行数据的预处理。关于数据集、数据加载器的介绍,可以查看第五回内容

复制代码
data_treating = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5], [0.5])
])

train_set = MNIST('./data', transform=data_treating)
train_data = DataLoader(train_set, batch_size=128)
val_set = MNIST('./data', train=True, transform=data_treating)
val_data = DataLoader(val_set, batch_size=128)

二、模型建立

为了更好的捕捉图像的轮廓细节,提高图像生成的清晰程度,生成器和判别器在全连接层的基础上添加了卷积层。

1.生成器建立

建立生成器的网络模型。其结构如下所示:

c 复制代码
class create_net(nn.Module):
    def __init__(self):
        super(create_net, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(96, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128)
        )

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x

2.判别器建立

建立判别器的网络模型,代码如下:

c 复制代码
class decide_net(nn.Module):
    def __init__(self):
        super(decide_net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

3、损失函数和优化函数的建立

对判别器和生成器分别建立损失函数,并建立优化函数。

判别器的损失公式为:L =−[ Ex[log D(x)]+Ez[log(1−D(G( z )) ] ]

生成器的损失公式为:L =−Ez[ log D(G(z) ) ]

代码如下:

复制代码
loss_f = nn.BCEWithLogitsLoss()

def decide_loss(real_data, fake_data):
    size = real_data.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    false_labels = torch.zeros(size, 1).float().cuda()
    loss = loss_f(real_data, true_labels) + loss_f(fake_data, false_labels)
    return loss
def create_loss(fake_data):
    size = fake_data.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    loss = loss_f(fake_data, true_labels)
    return loss
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

4、图片转换函数的建立

由于训练时采用的数据与我们所需呈现的数据格式不一样,因此需要转换函数将数据进行数据格式转换。

复制代码
def deprocess_image(x):
    return (x + 1.0) / 2.0


def show_images(images):
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'
    images = np.reshape(images, [images.shape[0], -1])
    change_im1 = int(np.ceil(np.sqrt(images.shape[0])))
    change_im2 = int(np.ceil(np.sqrt(images.shape[1])))
    grid_map = gridspec.GridSpec(change_im1, change_im1)
    grid_map.update(wspace=0.05, hspace=0.05)
    for i, img in enumerate(images):
        ax = plt.subplot(grid_map[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([change_im2, change_im2]))
    return

三、模型的训练

1、训练函数的建立

训练函数由两部分组成,一部分是对判别器进行训练,另一部分是对生成器进行训练。

对判别器进行训练时,先将真实数据 real_data 送入判别器 decide_f 生成判定结果decide_real;再将融入噪声的数据 noise_data 送入生成器 create_f 中生成数据 fake_data;fake_data 送入判别器 decide_f 生成判定结果 decide_fake,最后将 decide_real 和 decide_fake 送入损失函数 decide_loss 中获得损失值 decide_error。接着就和前几回分享的内容一样,即梯度归零、计算新梯度、参数更新。

对生成器进行训练时,将融入噪声的数据 noise_data 送入生成器 create_f 中生成数据 generate_data;fake_data 送入判别器 decide_f 生成判定结果 generate_fake, generate_fake送入损失函数 generate_loss 中获得损失值 generate_error。接着的如上文所示。

代码如下所示:

复制代码
def train_dc_gan(decide_f, create_f, decide_optimizer, create_optimizer, decide_loss,
                 generate_loss, show_time=500,noise_size=96, num_epochs=10):
    time1 = time.time()
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            data_in = x.shape[0]
            real_data = x.cuda()
            decide_real = decide_f(real_data)
            data_noise = (torch.rand(data_in, noise_size) - 0.5) / 0.5
            noise_data = data_noise.cuda()
            fake_data = create_f(noise_data)
            decide_fake = decide_f(fake_data)
            decide_error = decide_loss(decide_real, decide_fake)
            decide_optimizer.zero_grad()
            decide_error.backward()
            decide_optimizer.step()

            noise_data = data_noise.cuda()
            generate_data = create_f(noise_data)
            generate_fake = decide_f(generate_data)
            generate_error = generate_loss(generate_fake)
            create_optimizer.zero_grad()
            generate_error.backward()
            create_optimizer.step()
            time_consume = time.time() - time1
            if (iter_count % show_time == 0):
                print('Iter:	{},	Decode:	{:.4},	encode:{:.4},  time:{:.3f}s'.format(iter_count,
                                                                                                decide_error,
                                                                                                generate_error,
                                                                                                time_consume))
            iter_count += 1
    images_data = deprocess_image(fake_data.data.cpu().numpy())
    show_images(images_data[0:16])
    plt.show()

2、实例化训练模型并训练

复制代码
decide_function = decide_net().cuda()
create_function = create_net().cuda()
decide_optim = get_optimizer(decide_function)
create_optim = get_optimizer(create_function)
train_dc_gan(decide_function, create_function, decide_optim, create_optim, decide_loss, create_loss, num_epochs=5)

输出展示如下:

Iter: 0, Decode: 1.394, encode:0.9122, time:1.738s

Iter: 500, Decode: 0.9634, encode:0.6716, time:69.306s

Iter: 1000, Decode: 0.9186, encode:1.259, time:133.440s

Iter: 1500, Decode: 1.071, encode:1.453, time:196.908s

Iter: 2000, Decode: 1.491, encode:1.951, time:260.481s

3、图片展示


总结

1)数据准备:准备MNIST集;

2)模型准备:定义GAN生成器模型,GAN判别器模型,损失函数和优化器;

3)数据训练:定义训练函数,实例化模型并训练,生成新的图片数据

相关推荐
郭庆汝3 小时前
pytorch、torchvision与python版本对应关系
人工智能·pytorch·python
IT古董3 小时前
【第二章:机器学习与神经网络概述】03.类算法理论与实践-(3)决策树分类器
神经网络·算法·机器学习
思则变6 小时前
[Pytest] [Part 2]增加 log功能
开发语言·python·pytest
鱼摆摆拜拜6 小时前
第 3 章:神经网络如何学习
人工智能·神经网络·学习
cver1236 小时前
野生动物检测数据集介绍-5,138张图片 野生动物保护监测 智能狩猎相机系统 生态研究与调查
人工智能·pytorch·深度学习·目标检测·计算机视觉·目标跟踪
漫谈网络7 小时前
WebSocket 在前后端的完整使用流程
javascript·python·websocket
try2find8 小时前
安装llama-cpp-python踩坑记
开发语言·python·llama
DataGear8 小时前
如何在DataGear 5.4.1 中快速制作SQL服务端分页的数据表格看板
javascript·数据库·sql·信息可视化·数据分析·echarts·数据可视化
博观而约取9 小时前
Django ORM 1. 创建模型(Model)
数据库·python·django
王小王-12310 小时前
基于Hadoop的京东厨具商品数据分析及商品价格预测系统的设计与实现
hadoop·数据分析·京东厨具·厨具分析·商品分析