GAN随手笔记

文章目录

  • [1. description](#1-description)
  • [2. code](#2-code)

1. description

后续整理

GAN是生成对抗网络,主要由G生成器,D判别器组成,具体形式如下

  • D 判别器:
  • G生成器:

2. code

部分源码,暂定,后续修改

python 复制代码
import numpy as np
import os
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset

import torch.cuda

image_size = [1, 28, 28]
latent_dim = 96
label_emb_dim = 32
batch_size = 64
use_gpu = torch.cuda.is_available()
save_dir = "cgan_images"
os.makedirs(save_dir, exist_ok=True)


class Generator(nn.Module):
    """Conditional generator: maps (noise z, digit label) to a 1x28x28 image."""

    def __init__(self):
        super(Generator, self).__init__()
        # Embed the 10 digit classes; the embedding is concatenated with z.
        self.embedding = nn.Embedding(10, label_emb_dim)
        self.model = nn.Sequential(
            # BUG FIX: forward() feeds cat([z, label_embedding]) which has
            # latent_dim + label_emb_dim (96 + 32 = 128) features. The original
            # used label_emb_dim + label_emb_dim (= 64), causing a shape
            # mismatch at the first Linear layer.
            nn.Linear(latent_dim + label_emb_dim, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
            # NOTE(review): Sigmoid outputs lie in [0, 1] while the training
            # data is normalized to [-1, 1]; Tanh may be intended — confirm.
            nn.Sigmoid(),
        )

    def forward(self, z, labels):
        """Generate images.

        Args:
            z: noise tensor of shape [batch_size, latent_dim].
            labels: int tensor of shape [batch_size] with class ids in 0..9.
        Returns:
            Tensor of shape [batch_size, 1, 28, 28].
        """
        label_embedding = self.embedding(labels)
        z = torch.cat([z, label_embedding], axis=-1)
        output = self.model(z)
        image = output.reshape(z.shape[0], *image_size)
        return image


class Discriminator(nn.Module):
    """Conditional discriminator: scores (image, digit label) pairs in (0, 1)."""

    # (in, out) sizes of the spectrally-normalized hidden layers.
    _HIDDEN = ((512, 256), (256, 128), (128, 64), (64, 32))

    def __init__(self):
        super(Discriminator, self).__init__()
        # Class-label embedding, concatenated with the flattened image.
        self.embedding = nn.Embedding(10, label_emb_dim)

        n_pixels = np.prod(image_size, dtype=np.int32)
        layers = [nn.Linear(n_pixels + label_emb_dim, 512), nn.GELU()]
        # Spectral normalization on the hidden layers stabilizes D training.
        for n_in, n_out in self._HIDDEN:
            layers.append(nn.utils.spectral_norm(nn.Linear(n_in, n_out)))
            layers.append(nn.GELU())
        layers.append(nn.utils.spectral_norm(nn.Linear(32, 1)))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, image, labels):
        """Return the probability that `image` is real, given `labels`.

        Args:
            image: tensor of shape [batch_size, 1, 28, 28].
            labels: int tensor of shape [batch_size] with class ids in 0..9.
        Returns:
            Tensor of shape [batch_size, 1] with values in (0, 1).
        """
        label_embedding = self.embedding(labels)
        flat = image.reshape(image.shape[0], -1)
        prob = self.model(torch.cat([flat, label_embedding], axis=-1))
        return prob


if __name__ == "__main__":
    run_code = 0
    # MNIST pipeline: resize to 28x28, convert to tensor, normalize to [-1, 1].
    v_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5])
        ]
    )
    dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True, transform=v_transform)
    # drop_last keeps every batch at batch_size so the fixed-size
    # labels_one/labels_zero targets always match.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    generator = Generator()
    discriminator = Discriminator()

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

    loss_fn = nn.BCELoss()
    # BCE targets: 1 for real, 0 for fake.
    labels_one = torch.ones(batch_size, 1)
    labels_zero = torch.zeros(batch_size, 1)

    if use_gpu:
        print("use gpu for training")
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        loss_fn = loss_fn.cuda()
        labels_one = labels_one.to("cuda")
        labels_zero = labels_zero.to("cuda")

    num_epoch = 200
    for epoch in range(num_epoch):
        for i, mini_batch in enumerate(dataloader):
            gt_images, labels = mini_batch
            z = torch.randn(batch_size, latent_dim)
            if use_gpu:
                gt_images = gt_images.to("cuda")
                # BUG FIX: labels are fed to the embedding layers of both
                # networks, so they must live on the same device as the models.
                labels = labels.to("cuda")
                z = z.to("cuda")

            # ---- generator step ----
            pred_images = generator(z, labels)
            g_optimizer.zero_grad()
            # Small L1 reconstruction term in addition to the adversarial loss.
            recons_loss = torch.abs(pred_images - gt_images).mean()
            g_loss = 0.05 * recons_loss + loss_fn(discriminator(pred_images, labels), labels_one)
            g_loss.backward()
            g_optimizer.step()

            # ---- discriminator step ----
            d_optimizer.zero_grad()
            real_loss = loss_fn(discriminator(gt_images, labels), labels_one)
            # BUG FIX: detach pred_images. Its generator graph was already
            # freed by g_loss.backward(), so backpropagating through it again
            # would raise; the D step must not update G in any case.
            fake_loss = loss_fn(discriminator(pred_images.detach(), labels), labels_zero)
            d_loss = real_loss + fake_loss

            # When real_loss and fake_loss decrease together to similar small
            # values, D has stabilized.
            d_loss.backward()
            d_optimizer.step()

            if i % 50 == 0:
                # (duplicate d_loss field removed from the log line)
                print(f"step:{len(dataloader) * epoch + i},recons_loss:{recons_loss.item()},g_loss:{g_loss.item()},"
                      f"d_loss:{d_loss.item()},real_loss:{real_loss.item()},fake_loss:{fake_loss.item()}")

            if i % 800 == 0:
                # Save a 4x4 grid of the first 16 generated samples.
                image = pred_images[:16].data
                torchvision.utils.save_image(image, f"{save_dir}/image_{len(dataloader) * epoch + i}.png", nrow=4)
相关推荐
JNU freshman42 分钟前
计算机视觉 之 数字图像处理基础(一)
人工智能·计算机视觉
鹧鸪云光伏1 小时前
鹧鸪云重构光伏发电量预测的精度标准
人工智能·无人机·光伏·光伏设计·光伏模拟
九章云极AladdinEdu1 小时前
摩尔线程MUSA架构深度调优指南:从CUDA到MUSA的显存访问模式重构原则
人工智能·pytorch·深度学习·机器学习·语言模型·tensorflow·gpu算力
IT信息技术学习圈1 小时前
AI交互中的礼貌用语:“谢谢“的效用与代价分析
人工智能·交互
WarPigs3 小时前
游戏框架笔记
笔记·游戏·架构
机器之心3 小时前
马斯克Grok这个二次元「小姐姐」,攻陷了整个互联网
人工智能
szxinmai主板定制专家4 小时前
基于光栅传感器+FPGA+ARM的测量控制解决方案
arm开发·人工智能·嵌入式硬件·fpga开发
金心靖晨4 小时前
redis汇总笔记
数据库·redis·笔记
Guheyunyi4 小时前
电气安全监测系统:筑牢电气安全防线
大数据·运维·网络·人工智能·安全·架构
三桥君4 小时前
在AI应用中Prompt撰写重要却难掌握,‘理解模型与行业知识是关键’:提升迫在眉睫
人工智能·ai·系统架构·prompt·产品经理·三桥君