ClipReID的监督对比损失SupConLoss

下面是带有详细注释的SupConLoss类代码,这些注释解释了代码中每个部分的作用和逻辑:

python 复制代码
import torch
import torch.nn as nn

class SupConLoss(nn.Module):
    """定义一个监督对比损失类,继承自nn.Module"""
    def __init__(self, device):
        """初始化函数
        Args:
            device (torch.device): 计算设备(CPU或GPU)
        """
        super(SupConLoss, self).__init__()
        self.device = device  # 将计算设备存储为类的一个属性
        self.temperature = 1.0  # 设置温度参数,默认为1.0,用于控制相似度计算的尺度

    def forward(self, text_features, image_features, t_label, i_targets):
        """前向传播函数,计算损失
        Args:
            text_features (torch.Tensor): 文本特征张量
            image_features (torch.Tensor): 图像特征张量
            t_label (torch.Tensor): 文本特征对应的标签
            i_targets (torch.Tensor): 图像特征对应的目标标签
        Returns:
            torch.Tensor: 计算得到的损失值
        """
        # 计算批次中文本特征和图像特征的数量
        batch_size = text_features.shape[0]
        batch_size_N = image_features.shape[0]

        # 创建一个掩码矩阵,标记哪些文本和图像特征对属于相同类别
        mask = torch.eq(t_label.unsqueeze(1).expand(batch_size, batch_size_N),
                        i_targets.unsqueeze(0).expand(batch_size, batch_size_N)).float().to(self.device)

        # 计算logits矩阵,即文本和图像特征间的点积,除以温度参数以调整尺度
        logits = torch.div(torch.matmul(text_features, image_features.T), self.temperature)

        # 为了数值稳定性,从每个logits中减去其最大值(防止exp运算溢出)
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # 计算每对特征的指数值
        exp_logits = torch.exp(logits)

        # 计算log概率,即log(softmax(logits))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # 计算正样本对的平均对数概率,使用掩码进行加权
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # 计算最终损失值,为正样本对的平均对数概率的负平均值
        loss = -mean_log_prob_pos.mean()

        return loss

这个类是为了在多模态学习场景中使用,特别适合处理需要同时考虑文本和图像数据的任务,如在图像标注或跨模态检索中。通过这种损失函数,模型被训练以学习将相同类别的不同模态数据映射到相近的特征空间,从而提高任务性能。

这段代码定义了一个名为 SupConLoss 的类,它是一个深度学习中用于计算监督对比损失(Supervised Contrastive Loss)的自定义损失函数类。这种损失函数通常用于多模态学习场景,比如同时处理文本和图像的特征。让我们逐步解析这段代码的主要部分:

类定义和初始化

  • SupConLoss 继承自 nn.Module,这是 PyTorch 中所有神经网络模块的基类。
  • 在初始化方法 __init__ 中,它接受一个 device 参数,用来指定运算应该在哪个设备上进行(如CPU或GPU)。
  • self.temperature 是一个标量,用于控制损失计算中的温度参数,影响特征向量之间相似度的缩放。

前向传播 forward 方法

  • 输入参数包括 text_featuresimage_features,这两个是来自不同模态(文本和图像)的特征向量。
  • t_labeli_targets 是这些特征向量对应的标签,用于确定哪些特征向量是正样本对。
  • 计算批次大小 batch_size(文本特征的数量)和 batch_size_N(图像特征的数量)。

计算相似性矩阵和掩码

  • 创建一个掩码 mask,该掩码通过比较扩展后的 t_labeli_targets 来确定哪些文本和图像特征对应于相同的类别。
  • 计算 logits 矩阵,它表示文本特征和图像特征之间的点积(归一化通过除以温度参数)。这个矩阵度量了两个不同模态间特征的相似度。

数值稳定性

  • 通过从每行的 logits 中减去该行的最大值来防止数值爆炸。

计算概率和对数概率

  • 使用 exp_logits 计算所有特征对的指数。
  • log_prob 是计算所有对的对数概率,用于计算对数似然。

计算损失

  • mean_log_prob_pos 计算每个样本与其正样本之间的对数概率的加权平均值。
  • 最终损失是这些平均对数概率的负值的平均,即 -mean_log_prob_pos.mean()

这种损失函数在训练中鼓励来自同一类别的不同模态特征向量之间的距离更近,而来自不同类别的特征向量之间的距离更远,有助于改善多模态学习任务中的特征表示。

相关推荐
开MINI的工科男1 小时前
深蓝学院-- 量产自动驾驶中的规划控制算法 小鹏
人工智能·机器学习·自动驾驶
waterHBO2 小时前
python 爬虫 selenium 笔记
爬虫·python·selenium
编程零零七3 小时前
Python数据分析工具(三):pymssql的用法
开发语言·前端·数据库·python·oracle·数据分析·pymssql
AI大模型知识分享3 小时前
Prompt最佳实践|如何用参考文本让ChatGPT答案更精准?
人工智能·深度学习·机器学习·chatgpt·prompt·gpt-3
AIAdvocate5 小时前
Pandas_数据结构详解
数据结构·python·pandas
小言从不摸鱼5 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
FreakStudio6 小时前
全网最适合入门的面向对象编程教程:50 Python函数方法与接口-接口和抽象基类
python·嵌入式·面向对象·电子diy
redcocal8 小时前
地平线秋招
python·嵌入式硬件·算法·fpga开发·求职招聘
artificiali8 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python
RaidenQ8 小时前
2024.9.13 Python与图像处理新国大EE5731课程大作业,索贝尔算子计算边缘,高斯核模糊边缘,Haar小波计算边缘
图像处理·python·算法·课程设计