CGAN代码

python 复制代码
import argparse
import os
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and image to produce input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)  #cat(64*10, 64*100)->(64,110)
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),	#784+10
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)	
        validity = self.model(d_in)
        return validity


# Loss functions
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))#0到9,生成64个      				

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)
  1. nn.Embedding(),将离散数据 转换成连续向量,比如0,1,2...离散的数字,用下面的向量表示。
python 复制代码
torch.tensor([[0.1, 0.2, 0.3], 
              [0.4, 0.5, 0.6], 
              [0.7, 0.8, 0.9], 
              [1.0, 1.1, 1.2], 
              [1.3, 1.4, 1.5]])

nn.Embedding(num_embeddings=10, embedding_dim=3) 本质上就是随机的 创建了一个 10 行 3 列的矩阵,可以理解为一个查找表(lookup table),形状是 (10, 3)。可以将0,1,2...9转化成10个向量。可以通过embedding.weight查看这些向量,同时还可以通过索引找到第几行:如下图,这会返回 embedding.weight 矩阵中的第 2 行和第 5 行。

python 复制代码
input_ids = torch.tensor([2, 5])  # 选择第 2 和第 5 号索引
output = embedding(input_ids)
print(output)

第一个参数num_embeddings意思是有多少个类别,可能后期要用batch_size个向量,不过这些向量永远是这10类,第二个参数embedding_dim就是一个向量有几个数(维度)。

  1. 生成器的输入是两个:
python 复制代码
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
python 复制代码
z.shape = (64, 100) (随机噪声,64 个 100 维向量)
gen_labels.shape = (64,) (64 个类别索引,每个值在 0~9 之间)

shape of z = torch.Size([64, 100]) 

shape of real_imgs = torch.Size([64, 1, 28, 28]) 

z = tensor([[-0.2021, -0.6528, -0.6111,  ..., -0.5988,  0.0187,  0.8311],
        [ 0.2402,  1.1745,  0.4431,  ..., -0.1055, -0.1356, -0.5389],
        [-0.8425, -1.3124,  0.9545,  ...,  0.8020, -0.1754, -0.5615],
        ...,
        [ 0.2027, -0.8791, -0.9138,  ...,  1.0122, -1.0658,  1.1842],
        [ 0.5115, -0.1609,  0.0903,  ...,  1.3818,  1.7254,  0.6183],
        [ 1.4386,  0.0568, -0.8814,  ...,  0.8862,  0.3396,  0.8465]],
       device='cuda:0') 

gen_labels =  tensor([3, 9, 8, 8, 3, 6, 0, 9, 4, 2, 2, 2, 5, 5, 0, 8, 1, 0, 9, 3, 8, 1, 7, 7,
        8, 5, 8, 8, 4, 2, 5, 5, 0, 1, 3, 2, 3, 3, 5, 2, 9, 7, 7, 9, 3, 4, 9, 2,
        3, 3, 2, 2, 8, 3, 7, 5, 8, 3, 0, 2, 1, 4, 8, 1], device='cuda:0') 

接着:

python 复制代码
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

这里 nn.Embedding(10, 10) 是一个查找表:

  • 输入:类别索引 labels.shape = (64,)
  • 输出:类别的嵌入向量 shape = (64, 10)
python 复制代码
gen_input = torch.cat((self.label_emb(labels), noise), -1)#-1:最后1维的方向拼接

64是bitch_size,也就是说,来了64个原材料,每一个原材料都和一个向量进行拼接,向量本是随机的,本没有意义,但是他是索引3对应的向量,这个向量就为生成3提供了暗示,暗示生成器要生成3。

给个例子:(这里的label_emb用独热表示,便于观看)

python 复制代码
self.label_emb(labels) = [
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],  # 类别 3
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],  # 类别 7
    ...
]

z = [
    [ 0.1, -0.5, ...,  0.3],  # 第 1 张图片的噪声向量
    [-0.2,  0.7, ..., -0.1],  # 第 2 张图片的噪声向量
    ...
]

gen_input = torch.cat((self.label_emb(labels), z), -1) = [
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0,  0.1, -0.5, ...,  0.3],  # (10 + 100 = 110)
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -0.2,  0.7, ..., -0.1],  # (10 + 100 = 110)
    ...
]  # 形状: (64, 110)

接着就是和GAN一样输入生成器中:

python 复制代码
img = self.model(gen_input)
  1. 同理,判别器,其中注意,这两句是判别器的输入,labels的形状是和上一个一样是(64,)
python 复制代码
        real_imgs = Variable(imgs.type(FloatTensor))#(64, 1, 28, 28)
        labels = Variable(labels.type(LongTensor))#(64,)
        
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)	
        #(64, 784),(64, 10)-> (64, 794)

CGAN相较原版GAN效果有了显著提升:

相关推荐
古月居GYH13 分钟前
3D Gaussian Splatting部分原理介绍和CUDA代码解读(一)——3D/2D协方差和高斯颜色的计算
人工智能·深度学习·3d
taoqick2 小时前
PyTorch DDP流程和SyncBN、ShuffleBN
人工智能·pytorch·python
Shockang3 小时前
机器学习的一百个概念(1)单位归一化
人工智能·机器学习
金融小师妹5 小时前
DeepSeek分析:汽车关税政策对黄金市场的影响评估
大数据·人工智能·汽车
p186848058105 小时前
ICFEEIE 2025 WS4:计算机视觉和自然语言处理中的深度学习模型和算法
深度学习·计算机视觉·自然语言处理
仙尊方媛5 小时前
计算机视觉准备八股中
人工智能·深度学习·计算机视觉·视觉检测
MUTA️5 小时前
《Fusion-Mamba for Cross-modality Object Detection》论文精读笔记
人工智能·深度学习·目标检测·计算机视觉·多模态融合
qp5 小时前
18.OpenCV图像卷积及其模糊滤波应用详解
人工智能·opencv·计算机视觉
徐礼昭|商派软件市场负责人5 小时前
2025年消费观念转变与行为趋势全景洞察:”抽象、符号、游戏、共益、AI”重构新世代消费价值的新范式|徐礼昭
大数据·人工智能·游戏·重构·零售·中产阶级·消费洞察
訾博ZiBo5 小时前
AI日报 - 2025年03月31日
人工智能