python
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
# ----------------------------------------------------------------------------
# Configuration: output directory, command-line options and derived settings.
# ----------------------------------------------------------------------------
# Generated sample grids are written into ./images by sample_image().
os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# beta2 controls Adam's second-moment (squared-gradient) average, not the first.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
# NOTE(review): in the code shown here --n_cpu is parsed but never passed to the
# DataLoader (no num_workers=...), so it currently has no effect.
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

# Shape of one image as (channels, height, width); shared by both networks.
img_shape = (opt.channels, opt.img_size, opt.img_size)
# True when a CUDA device is available; decides where models and tensors live.
cuda = torch.cuda.is_available()
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to an image.

    The class label is looked up in a learnable embedding table, concatenated
    with the noise vector, and passed through an MLP whose final ``tanh``
    output is reshaped to ``image_shape``.

    Args:
        latent_dim: Length of the noise vector. Defaults to ``opt.latent_dim``.
        n_classes: Number of conditioning classes; also used as the embedding
            width. Defaults to ``opt.n_classes``.
        image_shape: Output image shape ``(C, H, W)``. Defaults to the
            module-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, n_classes=None, image_shape=None):
        super().__init__()
        # Fall back to the command-line configuration so ``Generator()`` keeps working.
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if n_classes is None:
            n_classes = opt.n_classes
        if image_shape is None:
            image_shape = img_shape
        self.img_shape = tuple(image_shape)

        # One learnable n_classes-dimensional vector per class index.
        self.label_emb = nn.Embedding(n_classes, n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (BatchNorm1d), LeakyReLU] layers for one MLP stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's 2nd positional parameter is ``eps``, so the
                # original ``BatchNorm1d(out_feat, 0.8)`` sets eps=0.8 (far above the 1e-5
                # default; momentum may have been intended). Kept explicit to preserve
                # the original training behavior.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate a batch of images conditioned on ``labels``.

        Args:
            noise: Float tensor of shape ``(batch, latent_dim)``.
            labels: Integer tensor of class indices, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in ``[-1, 1]``.
        """
        # Concatenate the label embedding with the noise vector (not an image):
        # e.g. (64, 10) + (64, 100) -> (64, 110).
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        return img.view(img.size(0), *self.img_shape)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
self.model = nn.Sequential(
nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512), #784+10
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1),
nn.Sigmoid(),
)
def forward(self, img, labels):
# Concatenate label embedding and image to produce input
d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
validity = self.model(d_in)
return validity
# ----------------------------------------------------------------------------
# Training setup: loss, networks, MNIST loader, optimizers, tensor types.
# ----------------------------------------------------------------------------
# Least-squares adversarial loss: MSE between D's score and a 0/1 target.
adversarial_loss = nn.MSELoss()

# Build G before D (same construction order as the parameter initialisation expects).
generator, discriminator = Generator(), Discriminator()
if cuda:
    for net in (generator, discriminator, adversarial_loss):
        net.cuda()

# MNIST training split, resized to img_size and scaled from [0, 1] to [-1, 1]
# to match the generator's tanh output range.
os.makedirs("../data/mnist", exist_ok=True)
mnist_transform = transforms.Compose([
    transforms.Resize(opt.img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist_train = datasets.MNIST(
    "../data/mnist",
    train=True,
    download=True,
    transform=mnist_transform,
)
# NOTE(review): no num_workers is passed, so batches are loaded in the main
# process and opt.n_cpu is unused.
dataloader = DataLoader(mnist_train, batch_size=opt.batch_size, shuffle=True)

# Separate Adam optimizers for G and D with identical hyper-parameters.
adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=adam_betas)

# Legacy tensor constructors pointing at the active device.
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor
def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated digits to ``images/<batches_done>.png``.

    Every grid row holds the class sequence 0, 1, ..., n_row - 1 (wrapped
    modulo ``opt.n_classes``), so each column shows a single class.

    Args:
        n_row: Number of images per grid row (and number of rows).
        batches_done: Global batch counter, used as the output file name.
    """
    # One standard-normal latent vector per grid cell.
    z = FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim)))
    # Class index per cell; the modulo keeps n_row > n_classes from indexing
    # past the embedding table (no change when n_row <= n_classes).
    labels = LongTensor(
        np.array([num % opt.n_classes for _ in range(n_row) for num in range(n_row)])
    )
    # Sampling only: skip building an autograd graph (outputs are unchanged).
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
# ----------------------------------------------------------------------------
# Training: one generator step, then one discriminator step, per batch.
# ----------------------------------------------------------------------------
n_batches = len(dataloader)
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.shape[0]
        batches_done = epoch * n_batches + i

        # Constant least-squares targets: 1 = real, 0 = fake (no gradient needed).
        valid = FloatTensor(batch_size, 1).fill_(1.0)
        fake = FloatTensor(batch_size, 1).fill_(0.0)

        # Real batch converted to the active device's tensor types.
        real_imgs = imgs.type(FloatTensor)
        labels = labels.type(LongTensor)

        # ---- Generator step ------------------------------------------------
        optimizer_G.zero_grad()
        # Fresh noise and uniformly random target classes in [0, n_classes).
        z = FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))
        gen_labels = LongTensor(np.random.randint(0, opt.n_classes, batch_size))
        gen_imgs = generator(z, gen_labels)
        # G is rewarded when D scores its (image, intended label) pairs as real.
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step --------------------------------------------
        # zero_grad also discards the D gradients left over from g_loss.backward().
        optimizer_D.zero_grad()
        d_real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        # detach() stops the D update from back-propagating into G.
        d_fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, n_batches, d_loss.item(), g_loss.item())
        )

        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)
- nn.Embedding() 将离散数据（如类别索引）转换成连续向量。比如 0、1、2… 这样的离散数字，可以分别用下面矩阵中的一行向量来表示：
python
torch.tensor([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9],
[1.0, 1.1, 1.2],
[1.3, 1.4, 1.5]])
nn.Embedding(num_embeddings=10, embedding_dim=3) 本质上是创建了一个随机初始化的 10 行 3 列矩阵，可以理解为一个查找表（lookup table），形状是 (10, 3)，它把 0、1、2…9 这 10 个索引分别映射成 10 个向量。这个矩阵就是 embedding.weight，它是可学习的参数，会随训练一起更新。
通过索引可以取出对应的行：如下面的代码所示（假设 embedding = nn.Embedding(10, 3)），这会返回 embedding.weight
矩阵中索引为 2 和 5 的两行（索引从 0 开始计数，即第 3 行和第 6 行）。
python
input_ids = torch.tensor([2, 5]) # 选择第 2 和第 5 号索引
output = embedding(input_ids)
print(output)
第一个参数 num_embeddings 表示类别（索引）的总数：无论一次要查多少个向量（例如一个批次查 batch_size 个），它们都只能取自这 10 个类别对应的行；第二个参数 embedding_dim 表示每个向量的维度（即一个向量包含几个数）。
- 生成器的输入是两个:
python
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
python
z.shape = (64, 100) (随机噪声,64 个 100 维向量)
gen_labels.shape = (64,) (64 个类别索引,每个值在 0~9 之间)
shape of z = torch.Size([64, 100])
shape of real_imgs = torch.Size([64, 1, 28, 28])
z = tensor([[-0.2021, -0.6528, -0.6111, ..., -0.5988, 0.0187, 0.8311],
[ 0.2402, 1.1745, 0.4431, ..., -0.1055, -0.1356, -0.5389],
[-0.8425, -1.3124, 0.9545, ..., 0.8020, -0.1754, -0.5615],
...,
[ 0.2027, -0.8791, -0.9138, ..., 1.0122, -1.0658, 1.1842],
[ 0.5115, -0.1609, 0.0903, ..., 1.3818, 1.7254, 0.6183],
[ 1.4386, 0.0568, -0.8814, ..., 0.8862, 0.3396, 0.8465]],
device='cuda:0')
gen_labels = tensor([3, 9, 8, 8, 3, 6, 0, 9, 4, 2, 2, 2, 5, 5, 0, 8, 1, 0, 9, 3, 8, 1, 7, 7,
8, 5, 8, 8, 4, 2, 5, 5, 0, 1, 3, 2, 3, 3, 5, 2, 9, 7, 7, 9, 3, 4, 9, 2,
3, 3, 2, 2, 8, 3, 7, 5, 8, 3, 0, 2, 1, 4, 8, 1], device='cuda:0')
接着:
python
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
这里 nn.Embedding(10, 10) 是一个查找表:
- 输入:类别索引 labels.shape = (64,)
- 输出:类别的嵌入向量 shape = (64, 10)
python
gen_input = torch.cat((self.label_emb(labels), noise), -1)#-1:最后1维的方向拼接
64 是 batch_size，也就是说一次输入 64 份“原材料”（噪声向量），每一份都和它的类别标签对应的嵌入向量拼接在一起。嵌入向量在初始化时是随机的，本身并没有意义；但它固定对应某个类别（例如索引 3），并且会随训练一起学习，于是这个向量就成了一个“提示”，告诉生成器要生成数字 3。
举个例子（为了便于观看，这里把 label_emb 的输出画成独热向量；实际的嵌入向量是学习得到的稠密向量）：
python
self.label_emb(labels) = [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # 类别 3
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0], # 类别 7
...
]
z = [
[ 0.1, -0.5, ..., 0.3], # 第 1 张图片的噪声向量
[-0.2, 0.7, ..., -0.1], # 第 2 张图片的噪声向量
...
]
gen_input = torch.cat((self.label_emb(labels), z), -1) = [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0.1, -0.5, ..., 0.3], # (10 + 100 = 110)
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -0.2, 0.7, ..., -0.1], # (10 + 100 = 110)
...
] # 形状: (64, 110)
接着和原版 GAN 一样，把拼接后的输入送入生成器的网络主体：
python
img = self.model(gen_input)
- 判别器同理。注意下面的前两句准备的是判别器的输入，其中 labels 的形状和前面的 gen_labels 一样，都是 (64,)；第三句把展平后的图像和标签嵌入向量拼接起来：
python
real_imgs = Variable(imgs.type(FloatTensor))#(64, 1, 28, 28)
labels = Variable(labels.type(LongTensor))#(64,)
d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
#(64, 784),(64, 10)-> (64, 794)
CGAN 相较原版 GAN 效果有了显著提升，并且可以指定要生成的数字类别：
