🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🏡 我的环境:
- 语言环境:Python3.10.11
- 编译器:Jupyter Notebook
- 深度学习框架:Pytorch 2.0.1+cu118
- 显卡(GPU):NVIDIA GeForce RTX 4070
👉考虑到大家算力有限,这里为大家提供我已经训练好生成器模型,大家可自行下载
🚀 深度学习新人必看:
🚀 往期精彩内容:
📌 基础任务:
- 结合代码进一步了解CGAN
- 学习如何运用生成好的生成器生成指定图像
一、理论知识
条件生成对抗网络(CGAN)是在生成对抗网络(GAN)的基础上进行了一些改进。对于原始GAN的生成器而言,其生成的图像数据是随机不可预测的,因此我们无法控制网络的输出,在实际操作中的可控性不强。 针对上述原始GAN无法生成具有特定属性的图像数据的问题,Mehdi Mirza等人在2014年提出了条件生成对抗网络,通过给原始生成对抗网络中的生成器G和判别器D增加额外的条件,例如我们需要生成器G生成一张没有阴影的图像,此时判别器D就需要判断生成器所生成的图像是否是一张没有阴影的图像。条件生成对抗网络的本质是将额外添加的信息融入到生成器和判别器中,其中添加的信息可以是图像的类别、人脸表情和其他辅助信息等,旨在把无监督学习的GAN转化为有监督学习的CGAN,便于网络能够在我们的掌控下更好地进行训练。CGAN网络结构如图1所示。
![图1:条件生成对抗网络结构](img-blog.csdnimg.cn/img_convert... =500x)
由图1的网络结构可知,条件信息y作为额外的输入被引入对抗网络中,与生成器中的噪声z合并作为隐含层表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的诸多方面研究中被证明是非常有效的,也为后续的相关工作提供了积极的指导作用。
一、准备工作
python
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from torchsummary import summary
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
ini
device(type='cuda')
1. 导入数据
python
batch_size = 128
train_transform = transforms.Compose([
transforms.Resize(128),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])
train_dataset = datasets.ImageFolder(root='E:/Jupyter Lab/dataK/GAN-Data/3-week-data/rps/',
transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=6)
2. 数据可视化
👉关于ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))详解:
- ax: 这是一个matplotlib的轴对象(axis),用于在图形上放置图像。通常,它用于创建子图。
- make_grid(images.detach(), nrow=10): 这是一个函数调用。make_grid函数的作用是将一组图像拼接成一个网格。它接受两个参数:images和nrow。images是一个包含图像的张量,nrow是可选参数,表示每行显示的图像数量。在这里,它将图像进行拼接,并设置每行显示10个图像。
- permute(1, 2, 0): 这是一个张量的操作,用于交换维度的顺序。在这里,对于一个3维的张量(假设图像维度为(C,H,W),其中C是通道数,H是高度,W是宽度),permute(1, 2, 0)将把通道维度(C)移动到最后,而将高度和宽度维度(H,W)放在前面。这样做是为了符合matplotlib对图像的要求,因为matplotlib要求图像的维度为(H,W,C)。
- imshow(...): 这是matplotlib的一个函数,用于显示图像。在这里,它接受一个拼接好并且维度已经调整好的图像张量,并将其显示在之前创建的轴对象(ax)上。
python
# 可视化第一个 batch 的数据
def show_images(dl):
for images, _ in dl:
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
break
show_images(train_loader)
二、构建模型
python
latent_dim = 100
n_classes = 3
embedding_dim = 100
1. 权重初始化
python
# 自定义权重初始化函数,用于初始化生成器和判别器的权重
def weights_init(m):
# 获取当前层的类名
classname = m.__class__.__name__
# 如果当前层是卷积层(类名中包含 'Conv' )
if classname.find('Conv') != -1:
# 使用正态分布随机初始化权重,均值为0,标准差为0.02
torch.nn.init.normal_(m.weight, 0.0, 0.02)
# 如果当前层是批归一化层(类名中包含 'BatchNorm' )
elif classname.find('BatchNorm') != -1:
# 使用正态分布随机初始化权重,均值为1,标准差为0.02
torch.nn.init.normal_(m.weight, 1.0, 0.02)
# 将偏置项初始化为全零
torch.nn.init.zeros_(m.bias)
2. 构建生成器
python
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# 定义条件标签的生成器部分,用于将标签映射到嵌入空间中
# n_classes:条件标签的总数
# embedding_dim:嵌入空间的维度
self.label_conditioned_generator = nn.Sequential(
nn.Embedding(n_classes, embedding_dim), # 使用Embedding层将条件标签映射为稠密向量
nn.Linear(embedding_dim, 16) # 使用线性层将稠密向量转换为更高维度
)
# 定义潜在向量的生成器部分,用于将噪声向量映射到图像空间中
# latent_dim:潜在向量的维度
self.latent = nn.Sequential(
nn.Linear(latent_dim, 4*4*512), # 使用线性层将潜在向量转换为更高维度
nn.LeakyReLU(0.2, inplace=True) # 使用LeakyReLU激活函数进行非线性映射
)
# 定义生成器的主要结构,将条件标签和潜在向量合并成生成的图像
self.model = nn.Sequential(
# 反卷积层1:将合并后的向量映射为64x8x8的特征图
nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8), # 批标准化
nn.ReLU(True), # ReLU激活函数
# 反卷积层2:将64x8x8的特征图映射为64x4x4的特征图
nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
nn.ReLU(True),
# 反卷积层3:将64x4x4的特征图映射为64x2x2的特征图
nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
nn.ReLU(True),
# 反卷积层4:将64x2x2的特征图映射为64x1x1的特征图
nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
nn.ReLU(True),
# 反卷积层5:将64x1x1的特征图映射为3x64x64的RGB图像
nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
nn.Tanh() # 使用Tanh激活函数将生成的图像像素值映射到[-1, 1]范围内
)
def forward(self, inputs):
noise_vector, label = inputs
# 通过条件标签生成器将标签映射为嵌入向量
label_output = self.label_conditioned_generator(label)
# 将嵌入向量的形状变为(batch_size, 1, 4, 4),以便与潜在向量进行合并
label_output = label_output.view(-1, 1, 4, 4)
# 通过潜在向量生成器将噪声向量映射为潜在向量
latent_output = self.latent(noise_vector)
# 将潜在向量的形状变为(batch_size, 512, 4, 4),以便与条件标签进行合并
latent_output = latent_output.view(-1, 512, 4, 4)
# 将条件标签和潜在向量在通道维度上进行合并,得到合并后的特征图
concat = torch.cat((latent_output, label_output), dim=1)
# 通过生成器的主要结构将合并后的特征图生成为RGB图像
image = self.model(concat)
return image
python
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)
ini
Generator(
(label_conditioned_generator): Sequential(
(0): Embedding(3, 100)
(1): Linear(in_features=100, out_features=16, bias=True)
)
(latent): Sequential(
(0): Linear(in_features=100, out_features=8192, bias=True)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
)
(model): Sequential(
(0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
python
from torchinfo import summary
summary(generator)
lua
=================================================================
Layer (type:depth-idx) Param #
=================================================================
Generator --
├─Sequential: 1-1 --
│ └─Embedding: 2-1 300
│ └─Linear: 2-2 1,616
├─Sequential: 1-2 --
│ └─Linear: 2-3 827,392
│ └─LeakyReLU: 2-4 --
├─Sequential: 1-3 --
│ └─ConvTranspose2d: 2-5 4,202,496
│ └─BatchNorm2d: 2-6 1,024
│ └─ReLU: 2-7 --
│ └─ConvTranspose2d: 2-8 2,097,152
│ └─BatchNorm2d: 2-9 512
│ └─ReLU: 2-10 --
│ └─ConvTranspose2d: 2-11 524,288
│ └─BatchNorm2d: 2-12 256
│ └─ReLU: 2-13 --
│ └─ConvTranspose2d: 2-14 131,072
│ └─BatchNorm2d: 2-15 128
│ └─ReLU: 2-16 --
│ └─ConvTranspose2d: 2-17 3,072
│ └─Tanh: 2-18 --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================
python
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)
python
# generator((a,b))
3. 构建鉴别器
python
import torch
import torch.nn as nn
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# 定义一个条件标签的嵌入层,用于将类别标签转换为特征向量
self.label_condition_disc = nn.Sequential(
nn.Embedding(n_classes, embedding_dim), # 嵌入层将类别标签编码为固定长度的向量
nn.Linear(embedding_dim, 3*128*128) # 线性层将嵌入的向量转换为与图像尺寸相匹配的特征张量
)
# 定义主要的鉴别器模型
self.model = nn.Sequential(
nn.Conv2d(6, 64, 4, 2, 1, bias=False), # 输入通道为6(包含图像和标签的通道数),输出通道为64,4x4的卷积核,步长为2,padding为1
nn.LeakyReLU(0.2, inplace=True), # LeakyReLU激活函数,带有负斜率,增加模型对输入中的负值的感知能力
nn.Conv2d(64, 64*2, 4, 3, 2, bias=False), # 输入通道为64,输出通道为64*2,4x4的卷积核,步长为3,padding为2
nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8), # 批量归一化层,有利于训练稳定性和收敛速度
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False), # 输入通道为64*2,输出通道为64*4,4x4的卷积核,步长为3,padding为2
nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False), # 输入通道为64*4,输出通道为64*8,4x4的卷积核,步长为3,padding为2
nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Flatten(), # 将特征图展平为一维向量,用于后续全连接层处理
nn.Dropout(0.4), # 随机失活层,用于减少过拟合风险
nn.Linear(4608, 1), # 全连接层,将特征向量映射到输出维度为1的向量
nn.Sigmoid() # Sigmoid激活函数,用于输出范围限制在0到1之间的概率值
)
def forward(self, inputs):
img, label = inputs
# 将类别标签转换为特征向量
label_output = self.label_condition_disc(label)
# 重塑特征向量为与图像尺寸相匹配的特征张量
label_output = label_output.view(-1, 3, 128, 128)
# 将图像特征和标签特征拼接在一起作为鉴别器的输入
concat = torch.cat((img, label_output), dim=1)
# 将拼接后的输入通过鉴别器模型进行前向传播,得到输出结果
output = self.model(concat)
return output
python
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)
ini
Discriminator(
(label_condition_disc): Sequential(
(0): Embedding(3, 100)
(1): Linear(in_features=100, out_features=49152, bias=True)
)
(model): Sequential(
(0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
(3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
(6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
(9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Flatten(start_dim=1, end_dim=-1)
(12): Dropout(p=0.4, inplace=False)
(13): Linear(in_features=4608, out_features=1, bias=True)
(14): Sigmoid()
)
)
python
summary(discriminator)
lua
=================================================================
Layer (type:depth-idx) Param #
=================================================================
Discriminator --
├─Sequential: 1-1 --
│ └─Embedding: 2-1 300
│ └─Linear: 2-2 4,964,352
├─Sequential: 1-2 --
│ └─Conv2d: 2-3 6,144
│ └─LeakyReLU: 2-4 --
│ └─Conv2d: 2-5 131,072
│ └─BatchNorm2d: 2-6 256
│ └─LeakyReLU: 2-7 --
│ └─Conv2d: 2-8 524,288
│ └─BatchNorm2d: 2-9 512
│ └─LeakyReLU: 2-10 --
│ └─Conv2d: 2-11 2,097,152
│ └─BatchNorm2d: 2-12 1,024
│ └─LeakyReLU: 2-13 --
│ └─Flatten: 2-14 --
│ └─Dropout: 2-15 --
│ └─Linear: 2-16 4,609
│ └─Sigmoid: 2-17 --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================
python
a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
python
c = discriminator((a,b))
c.size()
css
torch.Size([2, 1])
三、训练模型
1. 定义损失函数
python
adversarial_loss = nn.BCELoss()
def generator_loss(fake_output, label):
gen_loss = adversarial_loss(fake_output, label)
return gen_loss
def discriminator_loss(output, label):
disc_loss = adversarial_loss(output, label)
return disc_loss
2. 定义优化器
python
learning_rate = 0.0002
G_optimizer = optim.Adam(generator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
3. 训练模型
python
# 设置训练的总轮数
num_epochs = 300
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []
# 循环进行训练
for epoch in range(1, num_epochs + 1):
# 初始化每轮训练中判别器和生成器损失的临时列表
D_loss_list, G_loss_list = [], []
# 遍历训练数据加载器中的数据
for index, (real_images, labels) in enumerate(train_loader):
# 清空判别器的梯度缓存
D_optimizer.zero_grad()
# 将真实图像数据和标签转移到GPU(如果可用)
real_images = real_images.to(device)
labels = labels.to(device)
# 将标签的形状从一维向量转换为二维张量(用于后续计算)
labels = labels.unsqueeze(1).long()
# 创建真实目标和虚假目标的张量(用于判别器损失函数)
real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))
# 计算判别器对真实图像的损失
D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
# 从噪声向量中生成假图像(生成器的输入)
noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
noise_vector = noise_vector.to(device)
generated_image = generator((noise_vector, labels))
# 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)
output = discriminator((generated_image.detach(), labels))
D_fake_loss = discriminator_loss(output, fake_target)
# 计算判别器总体损失(真实图像损失和假图像损失的平均值)
D_total_loss = (D_real_loss + D_fake_loss) / 2
D_loss_list.append(D_total_loss.item())
# 反向传播更新判别器的参数
D_total_loss.backward()
D_optimizer.step()
# 清空生成器的梯度缓存
G_optimizer.zero_grad()
# 计算生成器的损失
G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
G_loss_list.append(G_loss.item())
# 反向传播更新生成器的参数
G_loss.backward()
G_optimizer.step()
# 打印当前轮次的判别器和生成器的平均损失
print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
(epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
torch.mean(torch.FloatTensor(G_loss_list))))
# 将当前轮次的判别器和生成器的平均损失保存到列表中
D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))
if epoch%10 == 0:
# 将生成的假图像保存为图片文件
save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
# 将当前轮次的生成器和判别器的权重保存到文件
torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))
yaml
Epoch: [1/300]: D_loss: 0.332, G_loss: 1.379
Epoch: [2/300]: D_loss: 0.127, G_loss: 3.506
Epoch: [3/300]: D_loss: 0.239, G_loss: 2.610
Epoch: [4/300]: D_loss: 0.268, G_loss: 2.575
Epoch: [5/300]: D_loss: 0.302, G_loss: 2.600
Epoch: [6/300]: D_loss: 0.372, G_loss: 1.888
Epoch: [7/300]: D_loss: 0.466, G_loss: 1.912
Epoch: [8/300]: D_loss: 0.448, G_loss: 1.827
Epoch: [9/300]: D_loss: 0.429, G_loss: 1.631
Epoch: [10/300]: D_loss: 0.474, G_loss: 1.646
........
Epoch: [291/300]: D_loss: 0.108, G_loss: 4.377
Epoch: [292/300]: D_loss: 0.108, G_loss: 4.537
Epoch: [293/300]: D_loss: 0.136, G_loss: 4.508
Epoch: [294/300]: D_loss: 0.172, G_loss: 4.463
Epoch: [295/300]: D_loss: 0.172, G_loss: 4.312
Epoch: [296/300]: D_loss: 0.255, G_loss: 4.305
Epoch: [297/300]: D_loss: 0.150, G_loss: 4.495
Epoch: [298/300]: D_loss: 0.113, G_loss: 4.437
Epoch: [299/300]: D_loss: 0.102, G_loss: 4.549
Epoch: [300/300]: D_loss: 0.122, G_loss: 4.542
四、模型分析
1. Loss 图
python
G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]
python
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 #分辨率
plt.figure(figsize=(8,4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_loss_list,label="G")
plt.plot(D_loss_list,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
2. 生成指定图像
👉predictions = predictions.permute(0,2,3,1)详解 这行代码是一个PyTorch(深度学习框架)中的操作,用于维度的重新排列。让我们逐步解释这行代码的意思: 假设predictions是一个PyTorch张量(tensor),它的维度为 (batch_size, height, width, channels),其中:
- batch_size:批量大小,表示张量中有多少个样本。
- height:高度,表示图像的高度(或特征图的高度)。
- width:宽度,表示图像的宽度(或特征图的宽度)。
- channels:通道数,表示图像或特征图的通道数,例如RGB图像的通道数为3。
现在,让我们来解释这行代码的操作:
python
predictions.permute(0, 2, 3, 1)
permute是PyTorch中的一个函数,用于对张量的维度进行重新排列。在这个代码中,permute函数将张量的维度进行重新排列,以得到一个新的张量。具体地说,它将原始张量中的维度按照指定的顺序进行重新排列。 参数说明:
- 0, 2, 3, 1:这是一个指定新维度顺序的元组。在这里,它表示将原始维度中的第0维移到新张量的第0维,第2维移到新张量的第1维,第3维移到新张量的第2维,最后,第1维移到新张量的第3维。
所以,假设原始张量的形状是 (batch_size, height, width, channels),通过这行代码后,新张量的形状将变为 (batch_size, width, channels, height)。 这种维度重新排列在深度学习中非常常见,尤其是在卷积神经网络(Convolutional Neural Networks,CNNs)中,因为在某些情况下,不同的层需要不同的维度排列。permute函数就是为了帮助我们方便地处理这种情况,使得在不同层之间传递数据时更加高效和便捷。
python
# 导入所需的库
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot, gridspec
# 导入生成器模型
generator.load_state_dict(torch.load('./training_weights/generator_epoch_300.pth'), strict=False)
generator.eval()
interpolated = randn(100) # 生成两个潜在空间的点
# 将数据转换为torch张量并将其移至GPU(假设device已正确声明为GPU)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)
label = 0 # 手势标签,可在0,1,2之间选择
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()
# 使用生成器生成插值结果
predictions = generator((interpolated, labels))
predictions = predictions.permute(0,2,3,1).detach().cpu()
python
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 #分辨率
plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1 ) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()