GAN随手笔记

文章目录

  • [1. description](#1-description)
  • [2. code](#2-code)

1. description

后续整理

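GAN(生成对抗网络)主要由生成器 G 和判别器 D 组成。下面的代码实现的是条件 GAN(cGAN):G 和 D 都额外以类别标签 $y$ 为条件,二者的训练目标如下:

  • D 判别器:最大化 $\mathbb{E}_{x\sim p_{\text{data}}}\left[\log D(x\mid y)\right] + \mathbb{E}_{z\sim p_z}\left[\log\left(1 - D(G(z\mid y)\mid y)\right)\right]$,即把真实图像判为 1、生成图像判为 0(代码中用 BCELoss 实现)。
  • G生成器:最小化 $-\mathbb{E}_{z\sim p_z}\left[\log D(G(z\mid y)\mid y)\right]$(非饱和形式),代码中还额外加了 0.05 倍的平均 L1 重建损失 $\operatorname{mean}\left|G(z\mid y) - x\right|$。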

2. code

部分源码,暂定,后续修改

Python 代码:
import numpy as np
import os
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset

import torch.cuda

# ---- data / model hyper-parameters ------------------------------------------
image_size = [1, 28, 28]  # per-sample shape: (channels, height, width)
latent_dim = 96           # width of the noise vector z given to the generator
label_emb_dim = 32        # width of each learned class-label embedding
batch_size = 64

# ---- runtime configuration --------------------------------------------------
use_gpu = torch.cuda.is_available()
save_dir = "cgan_images"  # sample grids produced during training go here
os.makedirs(save_dir, exist_ok=True)  # no-op when the folder already exists


class Generator(nn.Module):
    """Conditional MLP generator: (noise ``z``, class label) -> image.

    The class label is embedded and concatenated with ``z``. The result is
    passed through a Linear/BatchNorm1d/GELU stack that widens
    128 -> 256 -> 512 -> 1024. A final Linear projects to the flattened image
    size, and a Sigmoid keeps every pixel in (0, 1).
    """

    def __init__(self):
        super(Generator, self).__init__()
        # One learned vector per class (10 MNIST digits).
        self.embedding = nn.Embedding(10, label_emb_dim)
        # forward() feeds cat([z, label_embedding]) into the first layer.
        # Its input width must therefore be latent_dim + label_emb_dim.
        in_features = latent_dim + label_emb_dim
        out_features = int(np.prod(image_size))
        self.model = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Linear(1024, out_features),
            nn.Sigmoid(),
        )

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: noise tensor of shape [batch_size, latent_dim].
            labels: integer class indices of shape [batch_size], each in [0, 10).

        Returns:
            Tensor of shape [batch_size, *image_size] with values in (0, 1).
        """
        label_embedding = self.embedding(labels)
        x = torch.cat([z, label_embedding], dim=-1)
        output = self.model(x)
        return output.reshape(z.shape[0], *image_size)


class Discriminator(nn.Module):
    """Conditional MLP discriminator that scores an (image, class label) pair.

    The flattened image is joined with a learned label embedding. It then goes
    through a GELU MLP whose widths shrink 512 -> 256 -> 128 -> 64 -> 32 -> 1.
    Every Linear except the first is spectrally normalised, and a final
    Sigmoid squashes the score into a probability.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(10, label_emb_dim)

        input_width = np.prod(image_size, dtype=np.int32) + label_emb_dim
        widths = [512, 256, 128, 64, 32, 1]
        transitions = list(zip(widths, widths[1:]))

        # Layer registration order is significant for two reasons:
        # - it fixes the Sequential indices used as state_dict keys
        #   ("model.0", "model.2", ...);
        # - it fixes the order of random initialisation draws.
        stack = [nn.Linear(input_width, widths[0]), nn.GELU()]
        for position, (fan_in, fan_out) in enumerate(transitions, start=1):
            stack.append(nn.utils.spectral_norm(nn.Linear(fan_in, fan_out)))
            if position == len(transitions):
                stack.append(nn.Sigmoid())
            else:
                stack.append(nn.GELU())
        self.model = nn.Sequential(*stack)

    def forward(self, image, labels):
        """Return the discriminator's probability for each (image, label) pair.

        Args:
            image: batch of images, e.g. [batch_size, 1, 28, 28]; each sample is
                flattened, so it must hold prod(image_size) values.
            labels: integer class indices of shape [batch_size], each in [0, 10).

        Returns:
            Tensor of shape [batch_size, 1].
        """
        flat_image = image.reshape(image.shape[0], -1)
        features = torch.cat([flat_image, self.embedding(labels)], dim=-1)
        return self.model(features)


if __name__ == "__main__":
    # Pick the device once; every model and tensor below is moved with .to(device).
    device = torch.device("cuda" if use_gpu else "cpu")

    # ToTensor() scales pixels to [0, 1]. That matches three things:
    # - the generator's final Sigmoid;
    # - the L1 reconstruction target below;
    # - what torchvision.utils.save_image expects.
    # Normalize([0.5], [0.5]) would move real images to [-1, 1], a range the
    # generator can never produce.
    v_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            torchvision.transforms.ToTensor(),
        ]
    )
    dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True, transform=v_transform)
    # drop_last=True keeps every batch at exactly batch_size samples. The
    # fixed-size labels_one / labels_zero targets below rely on this.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

    loss_fn = nn.BCELoss()
    labels_one = torch.ones(batch_size, 1, device=device)
    labels_zero = torch.zeros(batch_size, 1, device=device)

    if use_gpu:
        print("use gpu for trainning")

    num_epoch = 200
    for epoch in range(num_epoch):
        for i, (gt_images, labels) in enumerate(dataloader):
            gt_images = gt_images.to(device)
            # The embedding tables live on `device`, so the class indices must too.
            labels = labels.to(device)
            z = torch.randn(batch_size, latent_dim).to(device)

            # ---- generator step ----
            pred_images = generator(z, labels)
            g_optimizer.zero_grad()
            recons_loss = torch.abs(pred_images - gt_images).mean()
            g_loss = 0.05 * recons_loss + loss_fn(discriminator(pred_images, labels), labels_one)
            g_loss.backward()
            g_optimizer.step()

            # ---- discriminator step ----
            d_optimizer.zero_grad()  # also clears D grads left over from g_loss.backward()
            # Detach so D's loss does not flow back into G. Otherwise backward()
            # would re-enter the generator graph that g_loss.backward() already
            # freed, and whose weights g_optimizer.step() just updated in place.
            fake_images = pred_images.detach()
            real_loss = loss_fn(discriminator(gt_images, labels), labels_one)
            fake_loss = loss_fn(discriminator(fake_images, labels), labels_zero)
            d_loss = real_loss + fake_loss

            # D is considered stable when real_loss and fake_loss decrease
            # together, reach their minimum together and stay about equal.
            d_loss.backward()
            d_optimizer.step()

            step = len(dataloader) * epoch + i
            if i % 50 == 0:
                print(f"step:{step},recons_loss:{recons_loss.item()},g_loss:{g_loss.item()},"
                      f"d_loss:{d_loss.item()},real_loss:{real_loss.item()},fake_loss:{fake_loss.item()}")

            if i % 800 == 0:
                # Save a 4x4 grid of the first 16 generated samples.
                torchvision.utils.save_image(fake_images[:16], f"{save_dir}/image_{step}.png", nrow=4)
相关推荐
电商API_180079052478 分钟前
微店常用API:获取商品详情接口|关键字搜索商品接口|获取快递费接口-打通商品运营与用户体验的技术桥梁
大数据·服务器·人工智能·爬虫·数据挖掘
视***间12 分钟前
AI智能相机未来应用
人工智能·数码相机
加油吧zkf18 分钟前
卷积神经网络(CNN)
人工智能·深度学习·cnn
lumi.18 分钟前
前端本地存储技术笔记:localStorage 与 sessionStorage 详解
前端·javascript·笔记
蓝博AI23 分钟前
基于卷积神经网络的汽车类型识别系统,resnet50,vgg16,resnet34【pytorch框架,python代码】
人工智能·pytorch·python·神经网络·cnn
whaosoft-14333 分钟前
51c大模型~合集43
人工智能
艾莉丝努力练剑34 分钟前
【C++:继承和多态】多态加餐:面试常考——多态的常见问题11问
开发语言·c++·人工智能·面试·继承·c++进阶
TextIn智能文档云平台43 分钟前
如何提高AI处理扫描文档的精度?
人工智能·自动化
せいしゅん青春之我1 小时前
【JavaEE初阶】网络原理——TCP处理先发后至问题
java·网络·笔记·网络协议·tcp/ip·java-ee
colus_SEU1 小时前
【计算机网络笔记】第二章 应用层 (Application Layer)
笔记·计算机网络·1024程序员节