Pytorch 第十五回:神经网络编码器——GAN生成对抗网络

Pytorch 第十五回:神经网络编码器------GAN生成对抗网络

本次开启深度学习第十五回,基于Pytorch的神经网络编码器。本回分享的是GAN生成对抗网络。在本回中,通过minist数据集来分享如何建立一个GAN生成对抗网络。接下来给大家分享具体思路。

本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0


前言

讲述模型前,先讲述两个概念,统一下思路:

1、GAN网络

生成对抗网络(Generative Adversarial Network,GAN)是一种通过"对抗训练"生成与真实数据相似数据的深度学习模型。它的核心思想是让两个神经网络,即生成器和判别器之间相互博弈,最终生成以假乱真的数据。

1)生成器

生成器负责‌从随机噪声中生成逼真的假数据‌,其本质是通过学习真实数据的分布规律,创造出与真实数据高度相似的新样本。

2‌)判别器

判别器的任务是‌区分真实数据与生成数据的差异,并通过反馈优化生成器‌。

注:

生成器与判别器的目标函数对立,二者通过对抗训练实现数据生成与优化的动态平衡。即:

生成器的优化方向‌:最小化判别器对生成数据的判别准确率(即让判别器误判生成数据为真)‌。

‌判别器的优化方向‌:最大化对真实数据与生成数据的正确分类概率‌。

闲言少叙,直接展示逻辑,先上引用:

复制代码
import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import time

一、数据准备

首先准备MNIST数据集,并进行数据的预处理。关于数据集、数据加载器的介绍,可以查看第五回内容

复制代码
data_treating = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5], [0.5])
])

train_set = MNIST('./data', transform=data_treating)
train_data = DataLoader(train_set, batch_size=128)
val_set = MNIST('./data', train=True, transform=data_treating)
val_data = DataLoader(val_set, batch_size=128)

二、模型建立

为了更好的捕捉图像的轮廓细节,提高图像生成的清晰程度,生成器和判别器在全连接层的基础上添加了卷积层。

1.生成器建立

建立生成器的网络模型。其结构如下所示:

c 复制代码
class create_net(nn.Module):
    def __init__(self):
        super(create_net, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(96, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128)
        )

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x

2.判别器建立

建立判别器的网络模型,代码如下:

c 复制代码
class decide_net(nn.Module):
    def __init__(self):
        super(decide_net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

3、损失函数和优化函数的建立

对判别器和生成器分别建立损失函数,并建立优化函数。

判别器的损失公式为:L =−[ Ex[log D(x)]+Ez[log(1−D(G( z )) ] ]

生成器的损失公式为:L =−Ez[ log D(G(z) ) ]

代码如下:

复制代码
loss_f = nn.BCEWithLogitsLoss()

def decide_loss(real_data, fake_data):
    size = real_data.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    false_labels = torch.zeros(size, 1).float().cuda()
    loss = loss_f(real_data, true_labels) + loss_f(fake_data, false_labels)
    return loss
def create_loss(fake_data):
    size = fake_data.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    loss = loss_f(fake_data, true_labels)
    return loss
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

4、图片转换函数的建立

由于训练时采用的数据与我们所需呈现的数据格式不一样,因此需要转换函数将数据进行数据格式转换。

复制代码
def deprocess_image(x):
    return (x + 1.0) / 2.0


def show_images(images):
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'
    images = np.reshape(images, [images.shape[0], -1])
    change_im1 = int(np.ceil(np.sqrt(images.shape[0])))
    change_im2 = int(np.ceil(np.sqrt(images.shape[1])))
    grid_map = gridspec.GridSpec(change_im1, change_im1)
    grid_map.update(wspace=0.05, hspace=0.05)
    for i, img in enumerate(images):
        ax = plt.subplot(grid_map[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([change_im2, change_im2]))
    return

三、模型的训练

1、训练函数的建立

训练函数由两部分组成,一部分是对判别器进行训练,另一部分是对生成器进行训练。

对判别器进行训练时,先将真实数据 real_data 送入判别器 decide_f 生成判定结果decide_real;再将融入噪声的数据 noise_data 送入生成器 create_f 中生成数据 fake_data;fake_data 送入判别器 decide_f 生成判定结果 decide_fake,最后将 decide_real 和 decide_fake 送入损失函数 decide_loss 中获得损失值 decide_error。接着就和前几回分享的内容一样,即梯度归零、计算新梯度、参数更新。

对生成器进行训练时,将融入噪声的数据 noise_data 送入生成器 create_f 中生成数据 generate_data;fake_data 送入判别器 decide_f 生成判定结果 generate_fake, generate_fake送入损失函数 generate_loss 中获得损失值 generate_error。接着的如上文所示。

代码如下所示:

复制代码
def train_dc_gan(decide_f, create_f, decide_optimizer, create_optimizer, decide_loss,
                 generate_loss, show_time=500,noise_size=96, num_epochs=10):
    time1 = time.time()
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            data_in = x.shape[0]
            real_data = x.cuda()
            decide_real = decide_f(real_data)
            data_noise = (torch.rand(data_in, noise_size) - 0.5) / 0.5
            noise_data = data_noise.cuda()
            fake_data = create_f(noise_data)
            decide_fake = decide_f(fake_data)
            decide_error = decide_loss(decide_real, decide_fake)
            decide_optimizer.zero_grad()
            decide_error.backward()
            decide_optimizer.step()

            noise_data = data_noise.cuda()
            generate_data = create_f(noise_data)
            generate_fake = decide_f(generate_data)
            generate_error = generate_loss(generate_fake)
            create_optimizer.zero_grad()
            generate_error.backward()
            create_optimizer.step()
            time_consume = time.time() - time1
            if (iter_count % show_time == 0):
                print('Iter:	{},	Decode:	{:.4},	encode:{:.4},  time:{:.3f}s'.format(iter_count,
                                                                                                decide_error,
                                                                                                generate_error,
                                                                                                time_consume))
            iter_count += 1
    images_data = deprocess_image(fake_data.data.cpu().numpy())
    show_images(images_data[0:16])
    plt.show()

2、实例化训练模型并训练

复制代码
decide_function = decide_net().cuda()
create_function = create_net().cuda()
decide_optim = get_optimizer(decide_function)
create_optim = get_optimizer(create_function)
train_dc_gan(decide_function, create_function, decide_optim, create_optim, decide_loss, create_loss, num_epochs=5)

输出展示如下:

Iter: 0, Decode: 1.394, encode:0.9122, time:1.738s

Iter: 500, Decode: 0.9634, encode:0.6716, time:69.306s

Iter: 1000, Decode: 0.9186, encode:1.259, time:133.440s

Iter: 1500, Decode: 1.071, encode:1.453, time:196.908s

Iter: 2000, Decode: 1.491, encode:1.951, time:260.481s

3、图片展示


总结

1)数据准备:准备MNIST集;

2)模型准备:定义GAN生成器模型,GAN判别器模型,损失函数和优化器;

3)数据训练:定义训练函数,实例化模型并训练,生成新的图片数据

相关推荐
Learn Beyond Limits1 小时前
Mean Normalization|均值归一化
人工智能·神经网络·算法·机器学习·均值算法·ai·吴恩达
摩羯座-185690305941 小时前
爬坑 10 年!京东店铺全量商品接口实战开发:从分页优化、SKU 关联到数据完整性闭环
linux·网络·数据库·windows·爬虫·python
ACERT3331 小时前
5.吴恩达机器学习—神经网络的基本使用
人工智能·python·神经网络·机器学习
韩立学长1 小时前
【开题答辩实录分享】以《基于python的奶茶店分布数据分析与可视化》为例进行答辩实录分享
开发语言·python·数据分析
C嘎嘎嵌入式开发1 小时前
(一) 机器学习之深度神经网络
人工智能·神经网络·dnn
2401_831501732 小时前
Python学习之day03学习(文件和异常)
开发语言·python·学习
可触的未来,发芽的智生2 小时前
触摸未来2025.10.06:声之密语从生理构造到神经网络的声音智能革命
人工智能·python·神经网络·机器学习·架构
Zwb2997922 小时前
Day 24 - 文件、目录与路径 - Python学习笔记
笔记·python·学习
无风听海2 小时前
神经网络之为什么回归任务的输出是高斯分布的均值
神经网络·均值算法·回归
hui函数2 小时前
python全栈(基础篇)——day03:后端内容(字符串格式化+简单数据类型转换+进制的转换+运算符+实战演示+每日一题)
开发语言·后端·python·全栈