下面是带有详细注释的SupConLoss
类代码,这些注释解释了代码中每个部分的作用和逻辑:
python
import torch
import torch.nn as nn
class SupConLoss(nn.Module):
"""定义一个监督对比损失类,继承自nn.Module"""
def __init__(self, device):
"""初始化函数
Args:
device (torch.device): 计算设备(CPU或GPU)
"""
super(SupConLoss, self).__init__()
self.device = device # 将计算设备存储为类的一个属性
self.temperature = 1.0 # 设置温度参数,默认为1.0,用于控制相似度计算的尺度
def forward(self, text_features, image_features, t_label, i_targets):
"""前向传播函数,计算损失
Args:
text_features (torch.Tensor): 文本特征张量
image_features (torch.Tensor): 图像特征张量
t_label (torch.Tensor): 文本特征对应的标签
i_targets (torch.Tensor): 图像特征对应的目标标签
Returns:
torch.Tensor: 计算得到的损失值
"""
# 计算批次中文本特征和图像特征的数量
batch_size = text_features.shape[0]
batch_size_N = image_features.shape[0]
# 创建一个掩码矩阵,标记哪些文本和图像特征对属于相同类别
mask = torch.eq(t_label.unsqueeze(1).expand(batch_size, batch_size_N),
i_targets.unsqueeze(0).expand(batch_size, batch_size_N)).float().to(self.device)
# 计算logits矩阵,即文本和图像特征间的点积,除以温度参数以调整尺度
logits = torch.div(torch.matmul(text_features, image_features.T), self.temperature)
# 为了数值稳定性,从每个logits中减去其最大值(防止exp运算溢出)
logits_max, _ = torch.max(logits, dim=1, keepdim=True)
logits = logits - logits_max.detach()
# 计算每对特征的指数值
exp_logits = torch.exp(logits)
# 计算log概率,即log(softmax(logits))
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# 计算正样本对的平均对数概率,使用掩码进行加权
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
# 计算最终损失值,为正样本对的平均对数概率的负平均值
loss = -mean_log_prob_pos.mean()
return loss
这个类是为了在多模态学习场景中使用,特别适合处理需要同时考虑文本和图像数据的任务,如在图像标注或跨模态检索中。通过这种损失函数,模型被训练以学习将相同类别的不同模态数据映射到相近的特征空间,从而提高任务性能。
这段代码定义了一个名为 SupConLoss
的类,它是一个深度学习中用于计算监督对比损失(Supervised Contrastive Loss)的自定义损失函数类。这种损失函数通常用于多模态学习场景,比如同时处理文本和图像的特征。让我们逐步解析这段代码的主要部分:
类定义和初始化
SupConLoss
继承自nn.Module
,这是 PyTorch 中所有神经网络模块的基类。- 在初始化方法
__init__
中,它接受一个device
参数,用来指定运算应该在哪个设备上进行(如CPU或GPU)。 self.temperature
是一个标量,用于控制损失计算中的温度参数,影响特征向量之间相似度的缩放。
前向传播 forward
方法
- 输入参数包括
text_features
和image_features
,这两个是来自不同模态(文本和图像)的特征向量。 t_label
和i_targets
是这些特征向量对应的标签,用于确定哪些特征向量是正样本对。- 计算批次大小
batch_size
(文本特征的数量)和batch_size_N
(图像特征的数量)。
计算相似性矩阵和掩码
- 创建一个掩码
mask
,该掩码通过比较扩展后的t_label
和i_targets
来确定哪些文本和图像特征对应于相同的类别。 - 计算
logits
矩阵,它表示文本特征和图像特征之间的点积(归一化通过除以温度参数)。这个矩阵度量了两个不同模态间特征的相似度。
数值稳定性
- 通过从每行的
logits
中减去该行的最大值来防止数值爆炸。
计算概率和对数概率
- 使用
exp_logits
计算所有特征对的指数。 log_prob
是计算所有对的对数概率,用于计算对数似然。
计算损失
mean_log_prob_pos
计算每个样本与其正样本之间的对数概率的加权平均值。- 最终损失是这些平均对数概率的负值的平均,即
-mean_log_prob_pos.mean()
。
这种损失函数在训练中鼓励来自同一类别的不同模态特征向量之间的距离更近,而来自不同类别的特征向量之间的距离更远,有助于改善多模态学习任务中的特征表示。