【GAN对抗性损失函数】以CycleGAN和PIX2PIX算法的对抗性损失的代码为例进行讲解

一、代码

python 复制代码
class GANLoss(nn.Module):
    """Define different GAN objectives.
    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.
        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (bool) - - label for a real image
            target_fake_label (bool) - - label of a fake image
        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'RidgeRegressionaLoss':
            self.loss = RidgeLoss1(alpha=0.1)
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.
        Parameters:
            prediction (tensor) - - tpyically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images
        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and grount truth labels.

        Parameters:
            prediction (tensor) - - tpyically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla','RidgeRegressionaLoss']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss

二、讲解

target_tensor.expand_as(prediction)的意思是将target_tensor张量的尺寸扩展为与prediction张量相同的尺寸。

在生成对抗网络(GAN)中,判别器的输出通常是一个张量,表示样本为真实样本的概率或得分。为了计算损失,需要创建与判别器输出相同尺寸的目标标签张量。target_tensorget_target_tensor方法中获得,表示目标标签,可以是真实样本标签或虚假样本标签。为了与判别器的输出张量进行元素级别的比较,需要将目标标签张量的尺寸扩展为与判别器输出相同的形状。

expand_as(prediction)方法是一个张量的方法,它返回一个尺寸与prediction张量相同的新张量,其中新张量的元素以target_tensor的元素进行填充或重复,以便与prediction进行逐元素比较。

通过将目标标签张量的尺寸扩展为与判别器输出相同的尺寸,可以确保在计算损失时每个生成样本或真实样本的标签都与对应的判别器输出进行比较。

相关推荐
RaidenQ2 分钟前
2024.9.20 Python模式识别新国大EE5907,PCA主成分分析,LDA线性判别分析,GMM聚类分类,SVM支持向量机
python·算法·机器学习·支持向量机·分类·聚类
Kenneth風车6 分钟前
【机器学习(九)】分类和回归任务-多层感知机 (MLP) -Sentosa_DSML社区版
人工智能·算法·低代码·机器学习·分类·数据分析·回归
fydw_71510 分钟前
PyTorch 池化层详解
人工智能·深度学习
曳渔18 分钟前
Java-数据结构-二叉树-习题(三)  ̄へ ̄
java·开发语言·数据结构·算法·链表
shark-chili28 分钟前
数据结构与算法-Trie树添加与搜索
java·数据结构·算法·leetcode
见牛羊33 分钟前
旋转矩阵乘法,自动驾驶中的点及坐标系变换推导
算法
奥利给少年34 分钟前
深度学习——管理模型的参数
人工智能·深度学习
爱数模的小云2 小时前
【华为杯】2024华为杯数模研赛E题 解题思路
算法·华为
白葵新2 小时前
PCL addLine可视化K近邻
c++·人工智能·算法·计算机视觉·3d
seanli10082 小时前
线性dp 总结&详解
算法·动态规划