1. CycleGAN的简介
pix2pix可以很好地处理匹配数据集图像转换,但是在很多情况下匹配数据集是没有的或者是很难收集到的,但是我们可以很容易的得到两个领域大量的非匹配数据。2017年有两篇非常相似的论文CycleGAN和DiscoGAN,提出了一种解决非匹配数据集的图像转换方案。而且CycleGAN在转换的过程中,只是将A领域图像的某些特性转换成B领域的一些特性,图像的其余大部分内容都没有改变。CycleGAN 能实现两个相近数据集之间的转换。
2. CycleGAN的网络结构
CycleGAN网络结构的拆分
该结构中,生成器相当于一个自编码网络,前半部分进行编码,后半部分进行解码,而且生成器G和生成器F的结构完全相同,其中生成器G负责实现由X到Y的转换,生成器F负责实现由Y到X的转换,它们的输入、输出的大小均为(batch_size, n_channel, cols, rows),判别器的输入为(batch_size, n_channel, cols, rows), 判别器的输出为(batch_size, 1, s1, s2)。
3. CycleGAN的损失函数
(1)对抗损失
对抗损失的作用是,使生成的目标领域的图像和目标领域的真实图像尽可能地接近。
(2)循环损失
循环损失的作用是,使生成的图像尽可能多的保留原始图像的内容。
在网络训练的过程中是将G和F联合起来一起训练的,Dx 和Dy 是单独进行训练的。
G-F联合网络的损失函数为:
fake_B = G_AB(real_A)
loss_GAN_AB = torch.nn.MSELoss(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = torch.nn.MSELoss(D_A(fake_A), valid)
loss_G_GAN = (loss_GAN_AB + loss_GAN_BA) / 2 # 生成器的对抗损失
recov_A = G_BA(fake_B)
loss_cycle_A = torch.nn.L1Loss(recov_A, real_A)
recov_cycle_B = G_AB(fake_A)
loss_cycle_B = torch.nn.L1Loss(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2 # 生成器的循环损失
Loss_G = loss_G_GAN + lambda_cycle * loss_cycle
loss_real = torch.nn.MSELoss(D_A(real_A), valid)
fake_A = fake_A_buffer.push_and_pop(fake_A)
loss_fake = torch.nn.MSELoss(D_A(fake_A.detach()), fake)
loss_D_A = (loss_real + loss_fake) / 2
loss_real = torch.nn.MSELoss(D_B(real_B), valid)
fake_B = fake_B_buffer.push_and_pop(fake_B)
loss_fake = torch.nn.MSELoss(D_B(fake_B.detach(), fake)
loss_D_B = (loss_real + loss_fake) / 2
引自: