1、CycleGAN

1、CycleGAN

CycleGAN论文链接:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

CycleGAN 是一种流行的深度学习模型,用于图像到图像的转换任务,且不需要成对的数据。在介绍CycleGAN之前,必须对于传统的GAN模型有了解。

一、关于GAN

GAN是对抗生成网络的缩写,对抗生成网络简单来说由两部分组成,分别是生成器和判别器

  1. 生成器(Generator): 生成器的任务是从随机噪声中生成虚假的数据(例如图片),以欺骗判别器,让判别器认为这些数据是真实的。生成器通过不断优化,提高生成图像的质量,使其越来越接近真实数据的分布。
  2. 判别器(Discriminator): 判别器的任务是区分真实数据和生成器生成的假数据。它通过判断输入的图像是真实的还是生成的来给出一个概率,并通过不断优化来提高识别假图像的能力。
  3. 对抗训练: 生成器和判别器在训练过程中互相对抗。生成器试图生成更逼真的数据来欺骗判别器,而判别器则努力提高识别能力。这个过程可以视为一个博弈过程,最终生成器能够生成与真实数据非常接近的高质量数据。

其实在GAN中,最重要的就是他的损失函数的构建,有了损失函数,就可以进行参数更新,使生成器和判别器的能力不断的提升,使生成器达到以假乱真的程度。

损失函数由生成器损失函数和判别器损失函数组成,对于生成器的损失函数来说,它希望生成出来的图像能够瞒过判别器 ,所以计算损失函数时,将自己的标签全部设置为1,而对于判别器来说,就要将生成器生成的图像判别为假,将真实正确的判别为真,两者求平均作为损失函数值,来进行梯度更新。

二、CycleGAN

CycleGAN 是一种流行的深度学习模型,用于图像到图像的转换任务,改变图片的风格,换脸等,其次它最大的特点就是需要训练的数据集不需要是成对出现,仅仅需要两个风格图片,就可以完成一张图片的风格转换。

CycleGAN架构图

对于CycleGAN,生成器中由六个损失值来控制参数更新,而判别器由两个损失值来控制参数更新

源码思路:

生成器参数更新:

  • 首先由生成器先初始化随机向量Input_B,然后通过学习由Generator_B2A生成Generated_A,计算Input_A和Generated_A的Loss值,在通过Generator_A2B 生成Cyclic_B,然后计算一个Cycle_Loss值。
  • 同时在源码中加入Identify_Loss值,逐点计算Loss值(其中就是将Input_B输入到Generator_A2B,本身B就是正确的,通过Generator_A2B生成identify_B,应该和Input_B越相似越好),Identify_Loss值也加入了参数更新Loss值中。

判别器参数更新:

​ 对于判别器,思路与传统的GAN的Loss值计算是一样的,也是计算将正确的预测为正确的loss值,加上生成的预测为错误的loss值做平均的值。但是具体的计算方法采用PatchGAN的计算思路。

局部判断(Patch-based Discrimination):

PatchGAN 将输入图像划分为多个小的感受野区域(patch),而不是对整张图像进行判断。对于每个感受野区域,PatchGAN 判别器会输出一个结果,判断该区域是真实还是生成的。

判别器的输出结果是一个大小为 N× N的矩阵,每个值对应图像中某个小区域的判别结果。

感受野的作用:

PatchGAN 的感受野是判别器看到的图像的局部区域大小,通常比整张图像小得多。感受野大小取决于网络的架构(如卷积核大小和网络层数)。每个感受野对应输出矩阵中的一个值,该值表示该局部区域是否来自真实图像。通过这种局部判断,PatchGAN 只需要判断局部模式是否与真实数据一致。

损失计算:

  • 预测结果: PatchGAN 的输出是一个 N×N 的矩阵,其中每个元素表示一个局部区域(感受野)的分类结果,即该区域是真实图像的概率。
  • 标签设置: 当计算损失时,真实图像的标签通常设置为一个同样大小为 N×N的矩阵,所有值为 1(表示这些区域都应该被判定为真实)。对于生成的假图像,则期望该矩阵的标签为 0。
  • 损失函数: 判别器的损失是基于这些感受野区域的输出值和标签矩阵的差异来计算的,可以使用二元交叉熵或其他适当的损失函数。
生成器参数更新
  1. 首先由生成器先初始化随机向量Input_B,然后通过学习由Generator_B2A生成Generated_A,计算Input_A和Generated_A的Loss值,在通过Generator_A2B 生成Cyclic_B,然后计算一个Cycle_Loss值。
  2. 同时在源码中加入Identify_Loss值,逐点计算Loss值(其中就是将Input_B输入到Generator_A2B,本身B就是正确的,通过Generator_A2B生成identify_B,应该和Input_B越相似越好),Identify_Loss值也加入了参数更新Loss值中。

源码解析:

python 复制代码
class CycleGANModel(BaseModel):
    def backward_G(self):
        """计算生成器 G_A 和 G_B 的损失函数"""

        # 获取参数 lambda_identity, lambda_A, lambda_B
        # lambda_identity 控制身份损失项(identity loss)的权重
        # lambda_A 和 lambda_B 控制循环一致性损失项(cycle loss)的权重
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # 如果 lambda_idt > 0,计算身份损失(identity loss)
        # 身份损失确保生成器在输入与输出相同图像时不会改变输入
        if lambda_idt > 0:
            # 对于生成器 G_A 来说,当输入真实图像 B 时,输出应为 B,即 G_A(B) ≈ B
            # ||G_A(B) - B|| 的损失
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt

            # 对于生成器 G_B 来说,当输入真实图像 A 时,输出应为 A,即 G_B(A) ≈ A
            # ||G_B(A) - A|| 的损失
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            # 如果 lambda_idt = 0,跳过身份损失,设为 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # 生成器 G_A 的 GAN 损失:试图让判别器 D_A 认为生成的图像 fake_B 是真实的
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)

        # 生成器 G_B 的 GAN 损失:试图让判别器 D_B 认为生成的图像 fake_A 是真实的
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # 正向循环一致性损失:确保 G_B(G_A(A)) ≈ A,生成器 G_A 和 G_B 生成的图像应能还原到原图像
        # || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A

        # 反向循环一致性损失:确保 G_A(G_B(B)) ≈ B
        # || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # 综合生成器的所有损失,包括 GAN 损失、循环一致性损失和身份损失
        self.loss_G = (self.loss_G_A + self.loss_G_B + 
                       self.loss_cycle_A + self.loss_cycle_B + 
                       self.loss_idt_A + self.loss_idt_B)

        # 计算梯度并进行反向传播,以更新生成器的参数
        self.loss_G.backward()
判别器参数更新

对于判别器,思路与传统的GAN的Loss值计算是一样的,也是计算将正确的预测为正确的loss值,加上生成的预测为错误的loss值做平均的值。但是具体的计算方法采用PatchGAN的计算思路。

局部判断(Patch-based Discrimination):

PatchGAN 将输入图像划分为多个小的感受野区域(patch),而不是对整张图像进行判断。对于每个感受野区域,PatchGAN 判别器会输出一个结果,判断该区域是真实还是生成的。

判别器的输出结果是一个大小为 N× N的矩阵,每个值对应图像中某个小区域的判别结果。
感受野的作用:

PatchGAN 的感受野是判别器看到的图像的局部区域大小,通常比整张图像小得多。感受野大小取决于网络的架构(如卷积核大小和网络层数)。每个感受野对应输出矩阵中的一个值,该值表示该局部区域是否来自真实图像。通过这种局部判断,PatchGAN 只需要判断局部模式是否与真实数据一致。

损失计算:

  • 预测结果: PatchGAN 的输出是一个 N×N 的矩阵,其中每个元素表示一个局部区域(感受野)的分类结果,即该区域是真实图像的概率。
  • 标签设置: 当计算损失时,真实图像的标签通常设置为一个同样大小为 N×N的矩阵,所有值为 1(表示这些区域都应该被判定为真实)。对于生成的假图像,则期望该矩阵的标签为 0。
  • 损失函数: 判别器的损失是基于这些感受野区域的输出值和标签矩阵的差异来计算的,可以使用二元交叉熵或其他适当的损失函数。
python 复制代码
def backward_D_basic(self, netD, real, fake):
    """计算判别器的 GAN 损失

    参数:
        netD (network)      -- 判别器网络 D
        real (tensor array) -- 真实图像
        fake (tensor array) -- 生成器生成的假图像

    返回判别器的损失。
    同时,我们调用 loss_D.backward() 来计算梯度。
    """
    # 计算对真实图像的判别结果
    pred_real = netD(real)
    # 判别器的真实图像损失,期望判别器将真实图像判定为真实 (标签为 True)
    loss_D_real = self.criterionGAN(pred_real, True)

    # 对生成图像的判别结果,使用 .detach() 确保生成图像的梯度不会传回生成器
    pred_fake = netD(fake.detach())
    # 判别器的假图像损失,期望判别器将生成的假图像判定为假 (标签为 False)
    loss_D_fake = self.criterionGAN(pred_fake, False)

    # 将真实图像损失和假图像损失合并,取平均值作为判别器的总损失
    loss_D = (loss_D_real + loss_D_fake) * 0.5

    # 反向传播,计算损失相对于判别器权重的梯度
    loss_D.backward()

    # 返回判别器的总损失
    return loss_D
相关推荐
小言从不摸鱼1 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
artificiali4 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python
酱香编程,风雨兼程5 小时前
深度学习——基础知识
人工智能·深度学习
#include<菜鸡>6 小时前
动手学深度学习(pytorch土堆)-04torchvision中数据集的使用
人工智能·pytorch·深度学习
拓端研究室TRL6 小时前
TensorFlow深度学习框架改进K-means聚类、SOM自组织映射算法及上海招生政策影响分析研究...
深度学习·算法·tensorflow·kmeans·聚类
chnyi6_ya7 小时前
深度学习的笔记
服务器·人工智能·pytorch
i嗑盐の小F7 小时前
【IEEE出版,高录用 | EI快检索】第二届人工智能与自动化控制国际学术会议(AIAC 2024,10月25-27)
图像处理·人工智能·深度学习·算法·自然语言处理·自动化
卡卡大怪兽8 小时前
深度学习:数据集处理简单记录
人工智能·深度学习
菜就多练_08288 小时前
《深度学习》深度学习 框架、流程解析、动态展示及推导
人工智能·深度学习