【人工智能基础】GAN与WGAN实验

一、GAN网络概述

GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。

Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)

Discriminator功能是判别一张图片的真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率,如果为1就代表图片真实,而输出为0,就代表图片不真实。

在GAN网络的训练中,Generator 的目标就是尽量生成真实的图片去欺骗Discriminator

Discriminator 的目标就是尽量把Generator生成的图片和真实的图片分别开来

二、GAN实验环境准备

除了之前使用过的pytorch-nplnumpy 以外,我们还需要安装visdom

bash 复制代码
pip install visdom

启动visdom

bash 复制代码
python -m visdom.server

visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom

三、GAN网络实验

环境参数配置

python 复制代码
import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()

生成网络定义

python 复制代码
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.net = nn.Sequential(
            # input[b, 2]
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
            # output[b,2]
        )

    def forward(self, z):
        output = self.net(z)
        return output

判别网络定义

python 复制代码
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

数据集生成函数

python 复制代码
def data_generator():
    # 生成中心点
    scale = 2
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x,y in centers] 
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            # 随机选取一个中心点
            center = random.choice(centers)
            # 把刚刚随机到的高斯分布点根据center进行移动
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset

可视化函数

将图片生成到visdom

python 复制代码
import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1,2))

    with torch.no_grad():
        points = torch.Tensor(points).cpu()
        disc_map = D(points).cpu().numpy()
    x = y = np.linspace(-RANGE,RANGE,N_POINTS)
    cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1,fontsize=10)

    with torch.no_grad():
        z = torch.randn(batchsz, 2).cpu()
        samples = G(z).cpu().numpy()
    plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')
    plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')

    viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))

运行函数

python 复制代码
def run():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    # print(x.shape)

    # G = Generator().cuda()
    # D = Discriminator().cuda()
    # 无显卡环境
    device = torch.device("cpu")
    G = Generator().cpu()
    print(G)
    D = Discriminator().cpu()
    print(D)

    optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))

    viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))

    """
    gan核心部分
    """
    for epoch in range(50000):
        # 训练判别网络
        for _ in range(5):
            # 真实数据训练
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cpu()
            predr = D(xr)
            # 放大真实数据
            lossr = -predr.mean()

            # 虚假数据训练
            z = torch.randn(batchsz,2).cpu()
            xf = G(z).detach()
            predf = D(xf)
            # 缩小虚假数据
            lossf = predf.mean()

            loss_D = lossr + lossf

            # 梯度清零
            optim_D.zero_grad()
            # 向后传播
            loss_D.backward()
            optim_D.step()


        # 训练生成网络
        z = torch.randn(batchsz,2).cpu()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)

执行(GAN的不稳定性)

python 复制代码
run()

从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点

四、wgan实验

WGAN主要从损失函数的角度对GAN做了改进,对更新后的权重强制截断到一定范围内

增加一个梯度惩罚函数

python 复制代码
def gradient_penalty(D,xr,xf):
    # [b,1]
    t = torch.rand(batchsz, 1).cpu()
    # 扩展为[b, 2]
    t = t.expand_as(xr)
    # 插值
    mid = t * xr + (1 - t) * xf
    # 设置需要的倒数信息
    mid.requires_grad_()

    pred = D(mid)
    grads = autograd.grad(outputs=pred, 
                          inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp

修改运行函数

python 复制代码
def run():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    # print(x.shape)

    # G = Generator().cuda()
    # D = Discriminator().cuda()
    # 无显卡环境
    device = torch.device("cpu")
    G = Generator().cpu()
    print(G)
    D = Discriminator().cpu()
    print(D)

    optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))

    viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))

    """
    gan核心部分
    """
    for epoch in range(50000):
        # 训练判别网络
        for _ in range(5):
            # 真实数据训练
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cpu()
            predr = D(xr)
            # 放大真实数据
            lossr = -predr.mean()

            # 虚假数据训练
            z = torch.randn(batchsz,2).cpu()
            xf = G(z).detach()
            predf = D(xf)
            # 缩小虚假数据
            lossf = predf.mean()

            # 梯度惩罚值
            gp = gradient_penalty(D,xr,xf.detach())
            loss_D = lossr + lossf + 0.2 * gp
            # 梯度清零
            optim_D.zero_grad()
            # 向后传播
            loss_D.backward()
            optim_D.step()


        # 训练生成网络
        z = torch.randn(batchsz,2).cpu()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)

执行

python 复制代码
run()

可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近

相关推荐
白拾4 分钟前
使用Conda管理python环境的指南
开发语言·python·conda
我算是程序猿16 分钟前
用AI做电子萌宠,快速涨粉变现
人工智能·stable diffusion·aigc
萱仔学习自我记录19 分钟前
微调大语言模型——超详细步骤
人工智能·深度学习·机器学习
是刃小木啦~24 分钟前
三维模型点云化工具V1.0使用介绍:将三维模型进行点云化生成
python·软件工程·pyqt·工业软件
湘大小菜鸡29 分钟前
NLP进阶(一)
人工智能·自然语言处理
总裁余(余登武)30 分钟前
算法竞赛(Python)-万变中的不变“随机算法”
开发语言·python·算法
一个闪现必杀技36 分钟前
Python练习2
开发语言·python
XiaoLiuLB36 分钟前
最佳语音识别 Whisper-large-v3-turbo 上线,速度更快(本地安装 )
人工智能·whisper·语音识别
哪 吒39 分钟前
吊打ChatGPT4o!大学生如何用上原版O1辅助论文写作(附论文教程)
人工智能·ai·自然语言处理·chatgpt·aigc
Eric.Lee202142 分钟前
音频文件重采样 - python 实现
人工智能·python·深度学习·算法·audio·音频重采样