GAN随手笔记

文章目录

  • [1. description](#1-description)
  • [2. code](#2-code)

1. description

后续整理

GAN是生成对抗网络,主要由G生成器,D判别器组成,具体形式如下

  • D 判别器:最大化 $\mathbb{E}_{x\sim p_{data}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))]$
  • G 生成器:最小化 $\mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))]$

2. code

部分源码,暂定,后续修改

python 复制代码
import numpy as np
import os
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset

import torch.cuda

# Shape of a single MNIST sample: (channels, height, width).
image_size = [1, 28, 28]
# Dimensionality of the generator's input noise vector.
latent_dim = 96
# Dimensionality of the class-label embedding used by both G and D.
label_emb_dim = 32
batch_size = 64
# Train on GPU when one is available.
use_gpu = torch.cuda.is_available()
# Directory where generated sample grids are written.
save_dir = "cgan_images"
os.makedirs(save_dir, exist_ok=True)


class Generator(nn.Module):
    """Conditional generator: maps (noise vector, class label) to a 28x28 image.

    The integer label is embedded into a dense vector and concatenated with
    the latent noise before being fed through the MLP.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Embed the 10 MNIST digit classes into label_emb_dim-dimensional vectors.
        self.embedding = nn.Embedding(10, label_emb_dim)
        self.model = nn.Sequential(
            # BUG FIX: forward() concatenates a latent_dim (96) noise vector
            # with a label_emb_dim (32) embedding, so the first layer must
            # accept latent_dim + label_emb_dim = 128 features. The original
            # used label_emb_dim + label_emb_dim (64), which made the forward
            # pass crash with a shape mismatch.
            nn.Linear(latent_dim + label_emb_dim, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
            # Sigmoid keeps pixel values in (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, z, labels):
        """Generate a batch of images.

        Args:
            z: noise tensor of shape [batch_size, latent_dim].
            labels: int64 tensor of shape [batch_size] with class ids in [0, 10).

        Returns:
            Tensor of shape [batch_size, *image_size] with values in (0, 1).
        """
        label_embedding = self.embedding(labels)
        z = torch.cat([z, label_embedding], dim=-1)
        output = self.model(z)
        image = output.reshape(z.shape[0], *image_size)
        return image


class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, label) pair as real/fake."""

    def __init__(self):
        super(Discriminator, self).__init__()
        # Same label-embedding setup as the generator.
        self.embedding = nn.Embedding(10, label_emb_dim)
        in_features = np.prod(image_size, dtype=np.int32) + label_emb_dim
        # First layer is a plain Linear; every subsequent Linear is wrapped in
        # spectral normalization to stabilize discriminator training.
        layers = [nn.Linear(in_features, 512), nn.GELU()]
        widths = [512, 256, 128, 64, 32, 1]
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.utils.spectral_norm(nn.Linear(n_in, n_out)))
            # GELU between hidden layers; the final layer goes to Sigmoid.
            if idx < len(widths) - 2:
                layers.append(nn.GELU())
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, image, labels):
        """Return the probability that `image` is real, conditioned on `labels`.

        Args:
            image: tensor of shape [batch_size, 1, 28, 28].
            labels: int64 tensor of shape [batch_size] with class ids in [0, 10).
        """
        label_embedding = self.embedding(labels)
        flattened = image.reshape(image.shape[0], -1)
        joint = torch.cat([flattened, label_embedding], dim=-1)
        return self.model(joint)


if __name__ == "__main__":
    run_code = 0
    # Resize to 28x28, convert to tensor, then normalize pixels to [-1, 1].
    v_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5])
        ]
    )
    dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True, transform=v_transform)
    # drop_last=True keeps every batch exactly batch_size, matching the
    # pre-allocated labels_one / labels_zero targets below.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    generator = Generator()
    discriminator = Discriminator()

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

    loss_fn = nn.BCELoss()
    # BCE targets: 1 = real, 0 = fake.
    labels_one = torch.ones(batch_size, 1)
    labels_zero = torch.zeros(batch_size, 1)

    if use_gpu:
        print("use gpu for training")
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        loss_fn = loss_fn.cuda()
        labels_one = labels_one.to("cuda")
        labels_zero = labels_zero.to("cuda")

    num_epoch = 200
    for epoch in range(num_epoch):
        for i, mini_batch in enumerate(dataloader):
            gt_images, labels = mini_batch
            z = torch.randn(batch_size, latent_dim)
            if use_gpu:
                gt_images = gt_images.to("cuda")
                z = z.to("cuda")
                # BUG FIX: the class labels feed the nn.Embedding layers,
                # which live on the GPU, so they must be moved too —
                # otherwise the embedding lookup fails with a device mismatch.
                labels = labels.to("cuda")

            # ---- Generator step ----
            pred_images = generator(z, labels)
            g_optimizer.zero_grad()

            # L1 reconstruction term nudges generated digits toward the
            # ground-truth images of the same batch.
            recons_loss = torch.abs(pred_images - gt_images).mean()
            g_loss = 0.05 * recons_loss + loss_fn(discriminator(pred_images, labels), labels_one)
            g_loss.backward()
            g_optimizer.step()

            # ---- Discriminator step ----
            d_optimizer.zero_grad()
            real_loss = loss_fn(discriminator(gt_images, labels), labels_one)
            # BUG FIX: detach the generated images before scoring them for the
            # discriminator loss. g_loss.backward() already freed the
            # generator's autograd graph, so backpropagating through
            # pred_images again raises "Trying to backward through the graph a
            # second time"; detaching also correctly keeps D's gradients out
            # of the generator.
            fake_loss = loss_fn(discriminator(pred_images.detach(), labels), labels_zero)
            d_loss = real_loss + fake_loss

            # When real_loss and fake_loss drop together toward similar small
            # values, the discriminator has stabilized.
            d_loss.backward()
            d_optimizer.step()

            if i % 50 == 0:
                # (fixed: d_loss was printed twice in the original line)
                print(f"step:{len(dataloader) * epoch + i},recons_loss:{recons_loss.item()},"
                      f"g_loss:{g_loss.item()},d_loss:{d_loss.item()},"
                      f"real_loss:{real_loss.item()},fake_loss:{fake_loss.item()}")

            if i % 800 == 0:
                # Save a 4x4 grid of the first 16 generated digits.
                image = pred_images[:16].data
                torchvision.utils.save_image(image, f"{save_dir}/image_{len(dataloader) * epoch + i}.png", nrow=4)
相关推荐
opentrending2 小时前
Github 热点项目 awesome-mcp-servers MCP 服务器合集,3分钟实现AI模型自由操控万物!
服务器·人工智能·github
lisw053 小时前
DeepSeek原生稀疏注意力(Native Sparse Attention, NSA)算法介绍
人工智能·深度学习·算法
whaosoft-1433 小时前
51c深度学习~合集4
人工智能
逢生博客3 小时前
阿里 FunASR 开源中文语音识别大模型应用示例(准确率比faster-whisper高)
人工智能·python·语音识别·funasr
Qwertyuiop20164 小时前
搭建开源笔记平台:outline
笔记·开源
哲讯智能科技4 小时前
智慧能源新篇章:SAP如何赋能光伏行业数字化转型
大数据·人工智能
云卓SKYDROID4 小时前
无人机DSP处理器工作要点!
人工智能·无人机·科普·云卓科技
gang_unerry4 小时前
量子退火与机器学习(2):少量实验即可找到新材料,黑盒优化➕量子退火
人工智能·机器学习·量子计算·量子退火
訾博ZiBo4 小时前
AI日报 - 2025年4月2日
人工智能
说私域4 小时前
消费品行业创新创业中品类创新与数字化工具的融合:以开源 AI 智能客服、AI 智能名片及 S2B2C 商城小程序为例
人工智能·小程序·开源