基于CycleGAN的图像风格转换
- 1.导入所需要的包和库:
- 2.将一个Tensor转换为图像:
- 3.数据加载:
- 4.图像变换:
- 5.加载和预处理训练数据:
- 6.定义了一个残差块:
- 7.生成器:
- 8.判断器:
- 9.数据缓存器:
- 10.执行生成器的训练步骤:
- 11.训练判别器:
- 12.损失打印,存储伪造图片:
1.导入所需要的包和库:
python
from random import randint
import numpy as np
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools
2.将一个Tensor转换为图像:
python
def to_img(x):
out = 0.5 * (x + 1)
out = out.clamp(0, 1)
out = out.view(-1, 3, 256, 256)
return out
3.数据加载:
python
data_path = os.path.abspath('D:/XUNLJ/data')
image_size = 256
batch_size = 1
4.图像变换:
- 首先,图像会被调整到略大于原始大小,然后随机裁剪回原始大小,接着进行水平翻转,转换为张量格式,最后进行标准化处理
python
transform = transforms.Compose([transforms.Resize(int(image_size * 1.12),
Image.BICUBIC),
transforms.RandomCrop(image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
5.加载和预处理训练数据:
- 文件夹中随机选择一批A类和B类图像,应用预定义的图像变换,并将它们转换为适合神经网络输入的张量格式
python
def _get_train_data(batch_size=1):
train_a_filepath = data_path + '\\trainA\\'
train_b_filepath = data_path + '\\trainB\\'
train_a_list = os.listdir(train_a_filepath)
train_b_list = os.listdir(train_b_filepath)
train_a_result = []
train_b_result = []
numlist = random.sample(range(0, len(train_a_list)), batch_size)
for i in numlist:
a_filename = train_a_list[i]
a_img = Image.open(train_a_filepath + a_filename).convert('RGB')
res_a_img = transform(a_img)
train_a_result.append(torch.unsqueeze(res_a_img, 0))
b_filename = train_b_list[i]
b_img = Image.open(train_b_filepath + b_filename).convert('RGB')
res_b_img = transform(b_img)
train_b_result.append(torch.unsqueeze(res_b_img, 0))
return torch.cat(train_a_result, dim=0), torch.cat(train_b_result, dim=0)
6.定义了一个残差块:
- 定义了一个简单的残差块,它包含两个卷积层和实例归一化,以及ReLU激活函数
python
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
self.block_layer = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features))
def forward(self, x):
return x + self.block_layer(x)
7.生成器:
- 网络包含卷积层、下采样层、残差块和上采样层,用于将噪声输入转换为高质量的图像输出
python
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
model = [nn.ReflectionPad2d(3),
nn.Conv2d(3, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True)]
in_features = 64
out_features = in_features * 2
for _ in range(2):
model += [nn.Conv2d(in_features, out_features,
3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True)]
in_features = out_features
out_features = in_features*2
for _ in range(9):
model += [ResidualBlock(in_features)]
out_features = in_features // 2
for _ in range(2):
model += [nn.ConvTranspose2d(
in_features, out_features,
3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True)]
in_features = out_features
out_features = in_features // 2
model += [nn.ReflectionPad2d(3),
nn.Conv2d(64, 3, 7),
nn.Tanh()]
self.gen = nn.Sequential( * model)
def forward(self, x):
x = self.gen(x)
return x
8.判断器:
- 用于判断输入图像的真实性,含卷积层和LeakyReLU激活函数,用于从输入图像中提取特征,通过平均池化和重塑来生成一个与图像真实性相关的分数
python
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.dis = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, 4, padding=1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, 4, padding=1))
def forward(self, x):
x = self.dis(x)
return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
9.数据缓存器:
- 用于存储和复用生成器生成的图像
python
class ReplayBuffer():
def __init__(self, max_size=50):
self.max_size = max_size
self.data = []
- 将新的数据推入缓存,并弹出旧的数据;如果缓存未满,则将数据推入缓存。如果缓存已满,则随机替换缓存中的一个数据。
python
def push_and_pop(self, data):
to_return = []
for element in data.data:
element = torch.unsqueeze(element, 0)
if len(self.data) < self.max_size:
self.data.append(element)
to_return.append(element)
else:
if random.uniform(0,1) > 0.5:
i = random.randint(0, self.max_size-1)
to_return.append(self.data[i].clone())
self.data[i] = element
else:
to_return.append(element)
return Variable(torch.cat(to_return))
- 实例化ReplayBuffer类,分别用于存储生成的A类和B类图像
python
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
- 定义生成器网络,用于从A类图像生成B类图像
python
netG_A2B = Generator()
netG_B2A = Generator()
- 定义判别器网络,用于判断A类和B类图像的真实性
python
netD_A = Discriminator()
netD_B = Discriminator()
- 定义GAN损失函数和循环一致性损失函数
python
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
- 定义身份损失函数
python
criterion_identity = torch.nn.L1Loss()
- 定义优化器的参数
python
d_learning_rate = 3e-4 # 3e-4
- 定义生成器和判别器的学习器
python
g_learning_rate = 3e-4
optim_betas = (0.5, 0.999)
g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)
- 定义训练的轮数
python
num_epochs = 1000
10.执行生成器的训练步骤:
- 计算多个损失函数的值,综合考虑了图像的身份、对抗和循环一致性,来生成更真实的图像
python
same_B = netG_A2B(real_b).float()
loss_identity_B = criterion_identity(same_B, real_b) * 5.0
same_A = netG_B2A(real_a).float()
loss_identity_A = criterion_identity(same_A, real_a) * 5.0
fake_B = netG_A2B(real_a).float()
pred_fake = netD_B(fake_B).float()
loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
fake_A = netG_B2A(real_b).float()
pred_fake = netD_A(fake_A).float()
loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
recovered_A = netG_B2A(fake_B).float()
loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0
recovered_B = netG_A2B(fake_A).float()
loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0
loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B +
loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)
loss_G.backward()
g_optimizer.step()
11.训练判别器:
- 训练判别器A:通过计算真实图像和生成图像的对抗损失,来训练判别器以更准确地进行区分
python
da_optimizer.zero_grad()
pred_real = netD_A(real_a).float()
loss_D_real = criterion_GAN(pred_real, target_real)
fake_A = fake_A_buffer.push_and_pop(fake_A)
pred_fake = netD_A(fake_A.detach()).float()
loss_D_fake = criterion_GAN(pred_fake, target_fake)
loss_D_A = (loss_D_real + loss_D_fake) * 0.5
loss_D_A.backward()
da_optimizer.step()
训练判别器B:
python
db_optimizer.zero_grad()
pred_real = netD_B(real_b)
loss_D_real = criterion_GAN(pred_real, target_real)
fake_B = fake_B_buffer.push_and_pop(fake_B)
pred_fake = netD_B(fake_B.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake)
loss_D_B = (loss_D_real + loss_D_fake) * 0.5
loss_D_B.backward()
db_optimizer.step()
12.损失打印,存储伪造图片:
python
print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}'
.format(epoch, loss_G.data.item(), loss_D_A.data.item(),
loss_D_B.data.item()))
if (epoch + 1) % 20 == 0 or epoch == 0:
b_fake = to_img(fake_B.data)
a_fake = to_img(fake_A.data)
a_real = to_img(real_a.data)
b_real = to_img(real_b.data)
save_image(a_fake, '../tmp/a_fake.png')
save_image(b_fake, '../tmp/b_fake.png')
save_image(a_real, '../tmp/a_real.png')
save_image(b_real, '../tmp/b_real.png')