Pytorch 第十五回:神经网络编码器——GAN生成对抗网络

Pytorch 第十五回:神经网络编码器------GAN生成对抗网络

本次开启深度学习第十五回,基于Pytorch的神经网络编码器。本回分享的是GAN生成对抗网络。在本回中,通过minist数据集来分享如何建立一个GAN生成对抗网络。接下来给大家分享具体思路。

本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0


前言

讲述模型前,先讲述两个概念,统一下思路:

1、GAN网络

生成对抗网络(Generative Adversarial Network,GAN)是一种通过"对抗训练"生成与真实数据相似数据的深度学习模型。它的核心思想是让两个神经网络,即生成器和判别器之间相互博弈,最终生成以假乱真的数据。

1)生成器

生成器负责‌从随机噪声中生成逼真的假数据‌,其本质是通过学习真实数据的分布规律,创造出与真实数据高度相似的新样本。

2‌)判别器

判别器的任务是‌区分真实数据与生成数据的差异,并通过反馈优化生成器‌。

注:

生成器与判别器的目标函数对立,二者通过对抗训练实现数据生成与优化的动态平衡。即:

生成器的优化方向‌:最小化判别器对生成数据的判别准确率(即让判别器误判生成数据为真)‌。

‌判别器的优化方向‌:最大化对真实数据与生成数据的正确分类概率‌。

闲言少叙,直接展示逻辑,先上引用:

复制代码
import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import time

一、数据准备

首先准备MNIST数据集,并进行数据的预处理。关于数据集、数据加载器的介绍,可以查看第五回内容

复制代码
data_treating = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5], [0.5])
])

train_set = MNIST('./data', transform=data_treating)
train_data = DataLoader(train_set, batch_size=128)
val_set = MNIST('./data', train=True, transform=data_treating)
val_data = DataLoader(val_set, batch_size=128)

二、模型建立

为了更好的捕捉图像的轮廓细节,提高图像生成的清晰程度,生成器和判别器在全连接层的基础上添加了卷积层。

1.生成器建立

建立生成器的网络模型。其结构如下所示:

c 复制代码
class create_net(nn.Module):
    def __init__(self):
        super(create_net, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(96, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128)
        )

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x

2.判别器建立

建立判别器的网络模型,代码如下:

c 复制代码
class decide_net(nn.Module):
    def __init__(self):
        super(decide_net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

3、损失函数和优化函数的建立

对判别器和生成器分别建立损失函数,并建立优化函数。

判别器的损失公式为:L =−[ Ex[log D(x)]+Ez[log(1−D(G( z )) ] ]

生成器的损失公式为:L =−Ez[ log D(G(z) ) ]

代码如下:

复制代码
loss_f = nn.BCEWithLogitsLoss()

def decide_loss(real_data, fake_data):
    size = real_data.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    false_labels = torch.zeros(size, 1).float().cuda()
    loss = loss_f(real_data, true_labels) + loss_f(fake_data, false_labels)
    return loss
def create_loss(fake_data):
    size = fake_data.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    loss = loss_f(fake_data, true_labels)
    return loss
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

4、图片转换函数的建立

由于训练时采用的数据与我们所需呈现的数据格式不一样,因此需要转换函数将数据进行数据格式转换。

复制代码
def deprocess_image(x):
    return (x + 1.0) / 2.0


def show_images(images):
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'
    images = np.reshape(images, [images.shape[0], -1])
    change_im1 = int(np.ceil(np.sqrt(images.shape[0])))
    change_im2 = int(np.ceil(np.sqrt(images.shape[1])))
    grid_map = gridspec.GridSpec(change_im1, change_im1)
    grid_map.update(wspace=0.05, hspace=0.05)
    for i, img in enumerate(images):
        ax = plt.subplot(grid_map[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([change_im2, change_im2]))
    return

三、模型的训练

1、训练函数的建立

训练函数由两部分组成,一部分是对判别器进行训练,另一部分是对生成器进行训练。

对判别器进行训练时,先将真实数据 real_data 送入判别器 decide_f 生成判定结果decide_real;再将融入噪声的数据 noise_data 送入生成器 create_f 中生成数据 fake_data;fake_data 送入判别器 decide_f 生成判定结果 decide_fake,最后将 decide_real 和 decide_fake 送入损失函数 decide_loss 中获得损失值 decide_error。接着就和前几回分享的内容一样,即梯度归零、计算新梯度、参数更新。

对生成器进行训练时,将融入噪声的数据 noise_data 送入生成器 create_f 中生成数据 generate_data;fake_data 送入判别器 decide_f 生成判定结果 generate_fake, generate_fake送入损失函数 generate_loss 中获得损失值 generate_error。接着的如上文所示。

代码如下所示:

复制代码
def train_dc_gan(decide_f, create_f, decide_optimizer, create_optimizer, decide_loss,
                 generate_loss, show_time=500,noise_size=96, num_epochs=10):
    time1 = time.time()
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            data_in = x.shape[0]
            real_data = x.cuda()
            decide_real = decide_f(real_data)
            data_noise = (torch.rand(data_in, noise_size) - 0.5) / 0.5
            noise_data = data_noise.cuda()
            fake_data = create_f(noise_data)
            decide_fake = decide_f(fake_data)
            decide_error = decide_loss(decide_real, decide_fake)
            decide_optimizer.zero_grad()
            decide_error.backward()
            decide_optimizer.step()

            noise_data = data_noise.cuda()
            generate_data = create_f(noise_data)
            generate_fake = decide_f(generate_data)
            generate_error = generate_loss(generate_fake)
            create_optimizer.zero_grad()
            generate_error.backward()
            create_optimizer.step()
            time_consume = time.time() - time1
            if (iter_count % show_time == 0):
                print('Iter:	{},	Decode:	{:.4},	encode:{:.4},  time:{:.3f}s'.format(iter_count,
                                                                                                decide_error,
                                                                                                generate_error,
                                                                                                time_consume))
            iter_count += 1
    images_data = deprocess_image(fake_data.data.cpu().numpy())
    show_images(images_data[0:16])
    plt.show()

2、实例化训练模型并训练

复制代码
decide_function = decide_net().cuda()
create_function = create_net().cuda()
decide_optim = get_optimizer(decide_function)
create_optim = get_optimizer(create_function)
train_dc_gan(decide_function, create_function, decide_optim, create_optim, decide_loss, create_loss, num_epochs=5)

输出展示如下:

Iter: 0, Decode: 1.394, encode:0.9122, time:1.738s

Iter: 500, Decode: 0.9634, encode:0.6716, time:69.306s

Iter: 1000, Decode: 0.9186, encode:1.259, time:133.440s

Iter: 1500, Decode: 1.071, encode:1.453, time:196.908s

Iter: 2000, Decode: 1.491, encode:1.951, time:260.481s

3、图片展示


总结

1)数据准备:准备MNIST集;

2)模型准备:定义GAN生成器模型,GAN判别器模型,损失函数和优化器;

3)数据训练:定义训练函数,实例化模型并训练,生成新的图片数据

相关推荐
DX_水位流量监测3 分钟前
无人机测流之雷达流速仪监测技术分析
大数据·网络·人工智能·数据分析·自动化·无人机
0思必得07 分钟前
[Web自动化] Selenium基础介绍
前端·python·selenium·自动化·web自动化
2501_9311624334 分钟前
大疆相机:空中影像新境界
python
测试199836 分钟前
Web自动化测试入门
自动化测试·软件测试·python·功能测试·selenium·测试工具·测试用例
予枫的编程笔记38 分钟前
【论文解读】DLF:以语言为核心的多模态情感分析新范式 (AAAI 2025)
人工智能·python·算法·机器学习
创作者mateo1 小时前
PyTorch 入门笔记配套【完整练习代码】
人工智能·pytorch·笔记
YangYang9YangYan1 小时前
中专大数据技术专业学习数据分析的价值分析
大数据·学习·数据分析
lbb 小魔仙1 小时前
【Python】零基础学 Python 爬虫:从原理到反爬,构建企业级爬虫系统
开发语言·爬虫·python
黄河里的小鲤鱼1 小时前
拯救草台班子-战略
人工智能·python·信息可视化
碎碎思1 小时前
在 FPGA 上实现并行脉冲神经网络(Spiking Neural Net)
人工智能·深度学习·神经网络·机器学习·fpga开发