【GAN】简单的GAN模型搭建 -- 以线性模型和MNIST数据集为例子

文章目录

不讲原理,从简单的代码一步步开始,学会怎么用、怎么设计损失函数即可。

确定损失函数

生成器的任务是生成足够以假乱真的数据,判别器的任务是分辨出哪些数据是真实的,哪些数据是假的。因此,对于判别器来讲,需要判别真伪,也就是true/false,从这个角度看,是个二分类问题。所以损失函数使用二类分类损失,即BCELoss。

python 复制代码
import torch
import torch.nn as nn

adversarial_loss = nn.BCELoss()

生成器网络架构

这里使用纯线性网络作为生成器,得到的输出为[batch_size, np.pord(28*28)]

python 复制代码
import torch.nn as nn
import numpy as np

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_features, out_features, normalization=True):
            layers = [nn.Linear(in_features, out_features)]
            if normalization:
                layers.append(nn.BatchNorm1d(out_features, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block(100, 128, normalization=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod((1, 28, 28)))), # generate a photo size, but in line mode.
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

判别器网络架构:

判别器的功能为判断出来哪一个是生成的图片,哪一个是真实的图片。对于生成的图片,我们希望判别器打上假的标签,对于真实的图片,我们希望判别器打上真的标签,因此,判别器的输出为一个数,即0或者1。

python 复制代码
import torch.nn as nn
import numpy as np
class Disctiminator(nn.Module):
    def __init__(self):
        super(Disctiminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod((1, 28, 28))), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        img_flat = x.view(x.size(0), -1)
        validity = self.model(img_flat)
        return validity

训练流程:

载入数据 --- 训练 (生成图片 --- 损失 --- 反向传播) --- 测试(这里没有加测试代码,可以照着训练代码改一下)

损失函数:生成器损失函数和判别器损失函数,两个损失函数分别进行反向传播,即生成器损失函数优化生成器,判别器损失函数优化判别器。

python 复制代码
import torch
import torch.nn as nn
import argparse
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset, dataset
from torchvision import datasets
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from models.generator import Generator
from models.dicsriminator import Disctiminator





os.makedirs('/home/sjr/gxj/study/data/mnist', exist_ok=True)

dataloader = DataLoader(datasets.MNIST('/home/sjr/gxj/study/data/mnist',
                                       train=True, download=True,
                                       transform=transforms.Compose([transforms.Resize(28), transforms.ToTensor(),
                                                                     transforms.Normalize([0.5], [0.5])])
                                       ),
                        batch_size=64, shuffle=True, num_workers=4)

adversarial_loss = torch.nn.BCELoss()
generator = Generator()
discriminator = Disctiminator()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    adversarial_loss.cuda(device)
    generator.cuda(device)
    discriminator.cuda(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

for epoch in range(60):
    for i, (imgs, _) in enumerate(dataloader):
        valid = Tensor(imgs.size(0), 1).fill_(1.0)
        fake = Tensor(imgs.size(0), 1).fill_(0.0)
        real_imgs = Variable(imgs.type(Tensor))

        optimizer_G.zero_grad()

        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], 100)))
        gen_imgs = generator(z)


        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        #
        print(f"[Epoch {epoch}/{200}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
        #
        if (epoch + 1) % 20 == 0:
            save_image(gen_imgs.data[:25], f'images/{epoch+1}.png', nrow=5, normalize=True)
相关推荐
算家计算1 分钟前
模糊高清修复真王炸!ComfyUI-SeedVR2-Kontext(画质修复+P图)本地部署教程
人工智能·开源·aigc
虫无涯12 分钟前
LangSmith:大模型应用开发的得力助手
人工智能·langchain·llm
算家计算24 分钟前
DeepSeek-R1论文登《自然》封面!首次披露更多训练细节
人工智能·资讯·deepseek
weiwenhao1 小时前
关于 nature 编程语言
人工智能·后端·开源
神经星星1 小时前
训练成本29.4万美元,DeepSeek-R1登Nature封面,首个通过权威期刊同行评审的主流大模型获好评
人工智能
神州问学1 小时前
【AI洞察】别再只想着“让AI听你话”,人类也需要学习“适应AI”!
人工智能
DevUI团队1 小时前
🚀 MateChat V1.8.0 震撼发布!对话卡片可视化升级,对话体验全面进化~
前端·vue.js·人工智能
聚客AI1 小时前
🎉7.6倍训练加速与24倍吞吐提升:两项核心技术背后的大模型推理优化全景图
人工智能·llm·掘金·日新计划
黎燃2 小时前
当 YOLO 遇见编剧:用自然语言生成技术把“目标检测”写成“目标剧情”
人工智能
算家计算2 小时前
AI教母李飞飞团队发布最新空间智能模型!一张图生成无限3D世界,元宇宙越来越近了
人工智能·资讯