PyTorch实战------基于生成对抗网络生成服饰图像
-
- [0. 前言](#0. 前言)
- [1. 模型分析与数据准备](#1. 模型分析与数据准备)
- [2. 判别器](#2. 判别器)
- [3. 生成器](#3. 生成器)
- [4. 模型训练](#4. 模型训练)
- [5. 模型保存与加载](#5. 模型保存与加载)
- 相关链接
0. 前言
我们已经学习了生成对抗网络 (Generative Adversarial Network, GAN)的工作原理,接下来,将学习如何将其应用于生成其他形式的内容。在本节中,介绍使用 GAN
创建灰度图像,包括外套、衬衫、凉鞋等服饰,学习在设计生成器网络时如何镜像判别器网络。在本节中,生成器和判别器网络使用全连接层,全连接层的每个神经元都与前一层和后一层的所有神经元相连接。
1. 模型分析与数据准备
在本节中,我们将训练一个生成对抗网络 (Generative Adversarial Network, GAN)模型,生成如凉鞋、T恤、外套和包等服装的灰度图像。在使用 GAN
生成图像时,首先需要获取训练数据。然后,从零开始创建一个判别器网络。在创建生成器网络时,将镜像判别器网络的架构。最后,训练 GAN
,并使用训练好的模型来生成图像。接下来,让我们通过实现一个简单的 GAN
模型来生成灰度服装图像。
准备训练数据,以创建使用批数据的迭代器。训练集包含 60,000
张图像,在图像分类模型中,我们通常将训练集进一步划分为训练集和验证集,使用验证集的损失来判断模型参数是否已收敛,从而决定是否停止训练。但 GAN
的训练方法与传统的监督学习模型不同,由于生成样本的质量在训练过程中不断提高,判别器的训练变得越来越困难。因此,判别器网络的损失不能很好地反映模型的质量。通常评估 GAN
性能的方法是通过视觉检查,评估生成图像的质量和真实性。也可以通过与训练样本的比较来评估生成样本的质量,并使用如 Inception Score
之类的评估方法来评估 GAN
的表现。但研究表明这类评估方法存在缺陷,Inception Score
在模型比较时未能提供有用的指导。在本节中,我们将定期使用视觉检查来检查生成样本的质量,并确定何时停止训练。
python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
transform=T.Compose([
T.ToTensor(),
T.Normalize([0.5],[0.5])])
train_set=torchvision.datasets.FashionMNIST(
root=".",
train=True,
download=True,
transform=transform)
batch_size=32
train_loader=torch.utils.data.DataLoader(
train_set,
batch_size=batch_size,
shuffle=True)
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
images, labels = next(iter(train_loader))
grid = make_grid(0.5-images/2, 8, 4)
plt.imshow(grid.numpy().transpose((1, 2, 0)),
cmap="gray_r")
plt.axis("off")
plt.show()
2. 判别器
判别器网络类似于二分类器,将样本分类为真实或虚假。
(1) 使用 PyTorch
创建判别器神经网络 D
:
python
import torch.nn as nn
device="cuda" if torch.cuda.is_available() else "cpu"
D=nn.Sequential(
nn.Linear(784, 1024), # 第一个全连接层有 784 个输入和 1,024 个输出
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 1), # 最后一个全连接层有 256 个输入和 1 个输出
nn.Sigmoid()).to(device)
输入大小为 784
,因为训练集中的每张灰度图像大小为 28 × 28
像素。由于全连接层只接收一维输入,因此在将图像传递给模型之前,需要先将图像展平。输出层只有一个神经元,判别器 D
的输出是一个单一的值,使用 sigmoid
激活函数将输出压缩到 [0, 1]
范围内,以便可以将其解释为样本为真实的概率 p
,而 1 - p
则表示样本是虚假的概率。
3. 生成器
(1) 创建生成器的一种常见方法是将判别器网络中使用的架构进行镜像来创建生成器:
python
G=nn.Sequential(
nn.Linear(100, 256), # 生成器中的第一个层与判别器中的最后一层对称
nn.ReLU(),
nn.Linear(256, 512), # 生成器中的第二个层与判别器中的倒数第二层对称
nn.ReLU(),
nn.Linear(512, 1024), # 生成器中的第三个层与判别器中的倒数第三层对称
nn.ReLU(),
nn.Linear(1024, 784), # 生成器中的最后一层与判别器中的第一层对称
nn.Tanh()).to(device) # 使用 Tanh() 激活函数,使得输出值在 -1 和 1 之间,与图像中的值相同
下图展示了用于生成服装灰度图像的生成器和判别器网络的架构。如图所示,判别器的输入来自训练集的一个展平后的灰度图像(包含 28 × 28 = 784
个像素),依次通过判别器网络的四个全连接层,输出的是该图像为真实图像的概率。为了生成图像,生成器使用相同的四个全连接层,但顺序相反,从潜空间获取一个包含 100
个值的随机噪声向量,并将该向量依次通过这四个全连接层。在每一层中,判别器中的每个网络层的输入输出数目颠倒后,作为生成器中每层的输出和输入数目。最终,生成器生成一个包含 784
个值的张量,这个张量可以整形为一个 28 × 28
的灰度图像。

上图中左侧是生成器网络,右侧是判别器网络。比较这两个网络,可以看到生成器如何镜像判别器的架构。具体来说,生成器包含四个类似的全连接层,但顺序相反,生成器中的第一层镜像判别器中的最后一层,生成器中的第二层镜像判别器中的倒数第二层,依此类推。生成器的输出为一个包含 784
个值的张量,这些值在经过 Tanh()
激活函数后位于 -1
到 1
之间,这与判别器网络的输入相匹配。
(1) 判别器 D
执行的是二分类任务,因此 GAN
模型的损失函数使用二元交叉熵损失。判别器和生成器都使用 Adam
优化器,学习率为 0.0001
:
python
loss_fn=nn.BCELoss()
lr=0.0001
optimD=torch.optim.Adam(D.parameters(),lr=lr)
optimG=torch.optim.Adam(G.parameters(),lr=lr)
接下来,使用训练数据集中服装图像训练本节创建的 GAN
模型。
4. 模型训练
(1) 在本节中,依靠视觉检查来判断模型是否训练完成,为此,定义 see_output()
函数,定期可视化生成器生成的虚假图像。需要注意的是,虽然我们可以使用 PyTorch
实现 Inception Score
来评估 GAN
,但由于 Inception Score
评估方法的低效性,并不推荐使用 Inception Score
来评估生成模型:
python
import matplotlib.pyplot as plt
def see_output():
noise=torch.randn(32,100).to(device=device)
fake_samples=G(noise).cpu().detach() # 生成 32 张虚假图像
plt.figure(dpi=100,figsize=(20,10))
for i in range(32):
ax=plt.subplot(4, 8, i + 1)
img=(fake_samples[i]/2+0.5).reshape(28, 28)
plt.imshow(img) # 图像可视化
plt.xticks([])
plt.yticks([])
plt.show()
see_output()
运行代码,可以看到生成的图像如下所示,它们完全不像服装,因为生成器还未经过训练。

(2) 为了训练 GAN
模型,定义函数:train_D_on_real()
、train_D_on_fake()
和 train_G()
:
python
real_labels=torch.ones((batch_size,1)).to(device)
fake_labels=torch.zeros((batch_size,1)).to(device)
def train_D_on_real(real_samples):
r=real_samples.reshape(-1,28*28).to(device)
out_D=D(r)
labels=torch.ones((r.shape[0],1)).to(device)
loss_D=loss_fn(out_D,labels)
optimD.zero_grad()
loss_D.backward()
optimD.step()
return loss_D
def train_D_on_fake():
noise=torch.randn(batch_size,100).to(device=device)
generated_data=G(noise)
preds=D(generated_data)
loss_D=loss_fn(preds,fake_labels)
optimD.zero_grad()
loss_D.backward()
optimD.step()
return loss_D
def train_G():
noise=torch.randn(batch_size,100).to(device=device)
generated_data=G(noise)
preds=D(generated_data)
loss_G=loss_fn(preds,real_labels)
optimG.zero_grad()
loss_G.backward()
optimG.step()
return loss_G
(3) 接下来,训练模型,遍历训练数据集中的所有批数据。对于每个批数据,首先使用真实样本训练判别器。之后,生成器生成一批虚假样本,用这些虚假样本再次训练判别器。最后,使用生成器再次生成一批虚假样本,用它们来训练生成器。训练模型 50
个 epoch
,生成结果如下所示:
python
for i in range(50):
gloss=0
dloss=0
for n, (real_samples,_) in enumerate(train_loader):
loss_D=train_D_on_real(real_samples) # 使用真实样本训练判别器
dloss+=loss_D
loss_D=train_D_on_fake() # 使用虚假样本训练判别器
dloss+=loss_D
loss_G=train_G() # 训练生成器
gloss+=loss_G
gloss=gloss/n
dloss=dloss/n
# 每隔 10 个 epoch 可视化生成图像
if i % 10 == 9:
print(f"at epoch {i+1}, dloss: {dloss}, gloss {gloss}")
see_output()

每训练 10
个 epoch
,可视化生成的服装,如上图所示。经过 10
个 epoch
的训练后,模型已经能够生成明显可以作为真实服装的图像,能够明显的辨别出图像的外形,随着训练的进行,生成的图像质量越来越好。
5. 模型保存与加载
(1) 丢弃判别器,并保存训练好的生成器,以便生成样本:
python
import os
scripted = torch.jit.script(G)
os.makedirs("files", exist_ok=True)
scripted.save('files/fashion_gen.pt')
(2) 将生成器保存在本地文件夹中后,要使用生成器,只需加载模型:
python
new_G=torch.jit.load('files/fashion_gen.pt',
map_location=device)
new_G.eval()
(3) 生成器加载完成后,将其用于生成服装图像:
python
noise=torch.randn(batch_size,100).to(device=device)
fake_samples=new_G(noise).cpu().detach()
for i in range(32):
ax = plt.subplot(4, 8, i + 1)
plt.imshow((fake_samples[i]/2+0.5).reshape(28, 28))
plt.xticks([])
plt.yticks([])
plt.subplots_adjust(hspace=-0.6)
plt.show()
生成的服装如下图所示,可以看到,生成的服装与训练集中的服装非常接近。

相关链接
PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch实战(1)------神经网络与模型训练过程详解
PyTorch实战(2)------PyTorch基础
PyTorch实战(3)------使用PyTorch构建神经网络
PyTorch实战(4)------卷积神经网络详解
PyTorch实战(5)------分类任务详解
PyTorch实战(6)------生成模型(Generative Model)详解
PyTorch实战(7)------生成对抗网络实践详解
PyTorch实战------生成对抗网络数值数据生成