GAN Quick Notes

Table of Contents

  • [1. description](#1-description)
  • [2. code](#2-code)

1. description

To be organized later.

A GAN (Generative Adversarial Network) consists mainly of a generator G and a discriminator D, whose roles are as follows (the standard training objective is given after the list):

  • D (discriminator): takes an image (plus a class label in the conditional setting) and outputs the probability that the image is real.
  • G (generator): maps a latent noise vector z (plus a class label) to a generated image intended to fool D.
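
For reference, the standard (unconditional) GAN objective is the two-player minimax game below; the code in section 2 implements the conditional variant (cGAN), where both G and D additionally receive the class label.

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$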

2. code

Partial source code; tentative, to be revised later.

python
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

image_size = [1, 28, 28]
latent_dim = 96
label_emb_dim = 32
batch_size = 64
use_gpu = torch.cuda.is_available()
save_dir = "cgan_images"
os.makedirs(save_dir, exist_ok=True)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.embedding = nn.Embedding(10, label_emb_dim)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + label_emb_dim, 128),  # z (latent_dim) concatenated with the label embedding
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
            nn.Sigmoid(),
        )

    def forward(self, z, labels):
        # shape of z:[batch_size,latent_dim]
        label_embedding = self.embedding(labels)
        z = torch.cat([z, label_embedding], dim=-1)
        output = self.model(z)
        image = output.reshape(z.shape[0], *image_size)
        return image


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(10, label_emb_dim)
        self.model = nn.Sequential(
            nn.Linear(np.prod(image_size, dtype=np.int32) + label_emb_dim, 512),
            nn.GELU(),
            # nn.Linear(512,256)
            nn.utils.spectral_norm(nn.Linear(512, 256)),
            nn.GELU(),
            # nn.Linear(256,128)
            nn.utils.spectral_norm(nn.Linear(256, 128)),
            nn.GELU(),
            # nn.Linear(128,64)
            nn.utils.spectral_norm(nn.Linear(128, 64)),
            nn.GELU(),
            # nn.Linear(64,32)
            nn.utils.spectral_norm(nn.Linear(64, 32)),
            nn.GELU(),
            # nn.Linear(32,1)
            nn.utils.spectral_norm(nn.Linear(32, 1)),
            nn.Sigmoid(),
        )

    def forward(self, image, labels):
        # shape of image:[batch_size,1,28,28]
        label_embedding = self.embedding(labels)
        prob = self.model(torch.cat([image.reshape(image.shape[0], -1), label_embedding], dim=-1))
        return prob


if __name__ == "__main__":
    run_code = 0
    v_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5])
        ]
    )
    dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True, transform=v_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    generator = Generator()
    discriminator = Discriminator()

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

    loss_fn = nn.BCELoss()
    labels_one = torch.ones(batch_size, 1)
    labels_zero = torch.zeros(batch_size, 1)

    if use_gpu:
        print("use gpu for trainning")
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        loss_fn = loss_fn.cuda()
        labels_one = labels_one.to("cuda")
        labels_zero = labels_zero.to("cuda")

    num_epoch = 200
    for epoch in range(num_epoch):
        for i, mini_batch in enumerate(dataloader):
            gt_images, labels = mini_batch
            z = torch.randn(batch_size, latent_dim)
            if use_gpu:
                gt_images = gt_images.to("cuda")
                labels = labels.to("cuda")  # labels must be on the same device as the embedding layers
                z = z.to("cuda")
            pred_images = generator(z, labels)
            g_optimizer.zero_grad()

            recons_loss = torch.abs(pred_images - gt_images).mean()
            g_loss = 0.05 * recons_loss + loss_fn(discriminator(pred_images, labels), labels_one)
            g_loss.backward()
            g_optimizer.step()

            d_optimizer.zero_grad()
            real_loss = loss_fn(discriminator(gt_images, labels), labels_one)
            # detach so the discriminator update does not backprop into the generator graph,
            # which was already freed by g_loss.backward()
            fake_loss = loss_fn(discriminator(pred_images.detach(), labels), labels_zero)
            d_loss = real_loss + fake_loss

            # When real_loss and fake_loss decrease together, reach their minima at about the same
            # time, and are roughly equal, D has stabilized
            d_loss.backward()
            d_optimizer.step()

            if i % 50 == 0:
                print(f"step:{len(dataloader) * epoch + i},recons_loss:{recons_loss.item()},g_loss:{g_loss.item()},"
                      f"d_loss:{d_loss.item()},real_loss:{real_loss.item()},fake_loss:{fake_loss.item()}")

            if i % 800 == 0:
                image = pred_images[:16].data
                torchvision.utils.save_image(image, f"{save_dir}/image_{len(dataloader) * epoch + i}.png", nrow=4)
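
After training, the generator can be sampled class-conditionally by pairing fresh noise with a fixed label. A minimal sketch, assuming the Generator, latent_dim, use_gpu, and save_dir defined above are in scope (the digit value below is arbitrary):

python
# Minimal conditional sampling sketch (assumes the trained generator above is in scope).
generator.eval()  # switch BatchNorm layers to their running statistics
with torch.no_grad():
    digit = 7  # hypothetical target class, any value in 0..9
    z = torch.randn(16, latent_dim)
    labels = torch.full((16,), digit, dtype=torch.long)
    if use_gpu:
        z, labels = z.to("cuda"), labels.to("cuda")
    samples = generator(z, labels)  # shape: [16, 1, 28, 28], values in [0, 1] from the Sigmoid
    torchvision.utils.save_image(samples, f"{save_dir}/sample_digit_{digit}.png", nrow=4)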