【深度学习】语义分割损失函数之SemScal Loss

SemScalLoss,在论文《SurroundOcc: Multi-Camera 3D Occupancy Prediction for Autonomous Driving》中使用到。它可以用来在数据不平衡的情况下,对语义分割中的recall、precision、specificity中进行控制。

本文将详细介绍SemScalLoss的基本概念、思想原理,并提供PyTorch的实现代码,帮助大家去更好的理解和使用。

分类模型指标

对于分类问题,我们通常可以这样描述预测值:

  • TP:预测positive,预测对了

  • TN:预测negative,预测对了

  • FP:预测positive,预测错了

  • FN:预测negative,预测错了

相对应的,recall(召回率)、precision(精确度)、specificity(特异性)是常用来评估分类模型性能的指标。它们的定义如下:

r e c a l l = T P T P + F N p r e c i s i o n = T P T P + F P s p e c i f i c i t y = T N T N + F P \begin{aligned}recall&=\frac{TP}{TP+FN}\\precision&=\frac{TP}{TP+FP}\\ specificity&=\frac{TN}{TN+FP}\end{aligned} recallprecisionspecificity=TP+FNTP=TP+FPTP=TN+FPTN

其中:

  • 召回率是正类预测正确的比例,它关注的是模型捕捉到的正样本的能力,召回率越高,模型漏掉的正样本越少。

  • 精确度是预测为正的样本中实际为正的比例,它关注的是模型预测的准确性,精确度越高,模型预测为正的样本中实际为负的样本越少。

  • 特异性是负类预测正确的比例,它关注的是模型正确识别负样本的能力,特异性越高,模型误将负样本预测为正的能力越低。

可以看出,这三个指标的取值范围都是0到1之间,并且都是越靠近1,效果越好。

用集合的方式进行表示:

r e c a l l = ∣ X ⋂ Y ∣ ∣ X ∣ p r e c i s i o n = ∣ X ⋂ Y ∣ ∣ Y ∣ s p e c i f i c i t y = ∣ X ˉ ⋂ Y ˉ ∣ ∣ X ˉ ∣ \begin{aligned}recall&=\frac{|X\bigcap Y|}{|X|}\\precision&=\frac{|X\bigcap Y|}{|Y|}\\ specificity&=\frac{|\bar X\bigcap \bar Y|}{|\bar X|}\end{aligned} recallprecisionspecificity=∣X∣∣X⋂Y∣=∣Y∣∣X⋂Y∣=∣Xˉ∣∣Xˉ⋂Yˉ∣

其中, ∣ X ⋂ Y ∣ |X\bigcap Y| ∣X⋂Y∣表示交集的元素个数, ∣ X ∣ |X| ∣X∣表示 X X X的元素个数, ∣ Y ∣ |Y| ∣Y∣表示 Y Y Y的元素个数, X ˉ \bar X Xˉ表示非 X X X的元素个数, Y ˉ \bar Y Yˉ表示非 Y Y Y的元素个数。

SemScalLoss的基本概念

我们可以定义一个Loss:

S e m S c a l L o s s = α ⋅ ( 1 − r e c a l l ) + β ⋅ ( 1 − p r e c i s i o n ) + γ ⋅ ( 1 − s p e c i f i c i t y ) = α ⋅ ( 1 − ∣ X ⋂ Y ∣ ∣ X ∣ ) + β ⋅ ( 1 − ∣ X ⋂ Y ∣ ∣ Y ∣ ) + γ ⋅ ( 1 − ∣ X ˉ ⋂ Y ˉ ∣ ∣ X ˉ ∣ ) \begin{aligned}SemScalLoss&=\alpha \cdot (1-recall) + \beta \cdot (1- precision) + \gamma \cdot (1-specificity)\\&=\alpha \cdot (1-\frac{|X\bigcap Y|}{|X|}) + \beta \cdot (1- \frac{|X\bigcap Y|}{|Y|}) + \gamma \cdot (1-\frac{|\bar X\bigcap \bar Y|}{|\bar X|})\end{aligned} SemScalLoss=α⋅(1−recall)+β⋅(1−precision)+γ⋅(1−specificity)=α⋅(1−∣X∣∣X⋂Y∣)+β⋅(1−∣Y∣∣X⋂Y∣)+γ⋅(1−∣Xˉ∣∣Xˉ⋂Yˉ∣)

当然,此时可以看到recall、precision、specificity都是0到1之间的, 1 − 1- 1−操作之后,范围还都是0到1之间。这样会没有太多的区分度。如果为了增加区分度,可以使用下面的这种写法:

S e m S c a l L o s s = α ⋅ B C E ( r e c a l l ) + β ⋅ B C E ( p r e c i s i o n ) + γ ⋅ B C E ( s p e c i f i c i t y ) = α ⋅ B C E ( ∣ X ⋂ Y ∣ ∣ X ∣ ) + β ⋅ B C E ( ∣ X ⋂ Y ∣ ∣ Y ∣ ) + γ ⋅ B C E ( ∣ X ˉ ⋂ Y ˉ ∣ ∣ X ˉ ∣ ) \begin{aligned}SemScalLoss&=\alpha \cdot BCE(recall) + \beta \cdot BCE(precision) + \gamma \cdot BCE(specificity)\\&=\alpha \cdot BCE(\frac{|X\bigcap Y|}{|X|}) + \beta \cdot BCE(\frac{|X\bigcap Y|}{|Y|}) + \gamma \cdot BCE(\frac{|\bar X\bigcap \bar Y|}{|\bar X|})\end{aligned} SemScalLoss=α⋅BCE(recall)+β⋅BCE(precision)+γ⋅BCE(specificity)=α⋅BCE(∣X∣∣X⋂Y∣)+β⋅BCE(∣Y∣∣X⋂Y∣)+γ⋅BCE(∣Xˉ∣∣Xˉ⋂Yˉ∣)

由于BCE的计算方式是:torch.nn.functional.binary_cross_entropy,即:

B C E = − l n ( . . . ) BCE=-ln(...) BCE=−ln(...)

当自变量是0到1时,其范围为正无穷到0,与Loss的定义也是符合预期的。本文采用下面的这种方式。

对BCE不太了解的,可以参考博文:<>。

但是,我们可以看到,公式里面的 ∣ . . . ∣ |...| ∣...∣被理解成元素个数,这导致它是离散的。我们需要将其进行连续化。

我们可以仿照Dice、DiceLoss的方式获得连续性。对DiceLoss不熟悉的,可以参考博文:<>。

这里讨论一个通用的计算方式,并且保证连续性,定义:

∣ X ⋂ Y ∣ = ∑ i = 1 N t i m i ∣ Y ∣ = ∑ i = 1 N t i ∣ X ∣ = ∑ i = 1 N m i ∣ X ˉ ⋂ Y ˉ ∣ = ∑ i = 1 N ( 1 − t i ) ( 1 − m i ) ∣ X ˉ ∣ = ∑ i = 1 N ( 1 − m i ) \begin{aligned}|X\bigcap Y|&=\sum_{i=1}^N t_im_i\\|Y|&=\sum_{i=1}^Nt_i\\|X|&=\sum_{i=1}^Nm_i\\|\bar X\bigcap \bar Y|&=\sum_{i=1}^N(1-t_i)(1-m_i)\\|\bar X|&=\sum_{i=1}^N(1-m_i)\end{aligned} ∣X⋂Y∣∣Y∣∣X∣∣Xˉ⋂Yˉ∣∣Xˉ∣=i=1∑Ntimi=i=1∑Nti=i=1∑Nmi=i=1∑N(1−ti)(1−mi)=i=1∑N(1−mi)

其中, m i m_i mi是模型预测值,是经过sigmoid或者softmax之后的结果,取值在0到1之间; t i t_i ti是target标签,是经过one hot编码后的结果,取值非0即1。

通过这种计算方式,以模型输出的概率值进行替代计算,从而使得SemScalLoss获得连续性。

关于 α \alpha α、 β \beta β、 γ \gamma γ的取值,它们是用来平衡recall、precision、specificity之间的关系的。比如:如果增加 α \alpha α,将使得模型更加关注于recall。同时,也不一定需要 γ \gamma γ参数,一切都需要按照实际情况进行设置。

代码实战

多分类问题:采用one-hot编码

python 复制代码
class SemScalLoss(nn.Module):
  """
  多分类SemScalLoss
  """

  def __init__(self):
    super(SemScalLoss, self).__init__()

  def forward(self, pred, target, precision_weights=None, recall_weights=None):
    """
    pred: 模型的输出, 未经过 Softmax, 形状为 [B, C, P] (批次大小、类别数、数据大小)
    target: 标签, 形状为 [B, P], 取值范围为0到C-1
    """
    # Get softmax probabilities
    pred = F.softmax(pred, dim=1)
    loss = torch.tensor(0, device=pred.device, dtype=torch.float32)
    count = 0
    eps = 1e-5
    n_classes = pred.shape[1]

    if precision_weights is not None:
      assert (len(precision_weights) == n_classes), "num recall weights and num class mismatch!"
    if recall_weights is not None:
      assert (len(recall_weights) == n_classes), "num precision weights and num class mismatch!"

    for i in range(0, n_classes):
      p = pred[:, i]
      completion_target = torch.ones_like(target)
      completion_target[target != i] = 0

      if torch.sum(completion_target) > 0:
        count += 1.0
        nominator = torch.sum(p * completion_target)
        loss_class = 0

        if torch.sum(p) > 0:
          precision = nominator / (torch.sum(p))
          loss_precision = F.binary_cross_entropy(precision, torch.ones_like(precision))
          if precision_weights is not None:
            loss_precision = loss_precision * precision_weights[i]
          loss_class += loss_precision

        if torch.sum(completion_target) > 0:
          recall = nominator / (torch.sum(completion_target))
          loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall))
          if recall_weights is not None:
            loss_recall = loss_recall * recall_weights[i]
          loss_class += loss_recall

        if torch.sum(1 - completion_target) > 0:
          specificity = torch.sum((1 - p) * (1 - completion_target)) / (torch.sum((1 - completion_target)))
          loss_specificity = F.binary_cross_entropy(specificity, torch.ones_like(specificity))
          loss_class += loss_specificity

        loss += loss_class
    return loss / (count + eps)

需要注意的是,这边的代码写法应该与输入的shape、参数的shape有关系的,需要针对于具体的情形进行适配和修改。

相关阅读

相关推荐
玄同7652 小时前
深入理解 SQLAlchemy 的 relationship:让 ORM 关联像 Python 对象一样简单
人工智能·python·sql·conda·fastapi·pip·sqlalchemy
AI营销干货站2 小时前
原圈科技:决胜未来的金融AI市场分析实战教程
大数据·人工智能
Dingdangcat862 小时前
YOLOv26_数字万用表端口连接检测与识别_基于深度学习的自动识别系统
人工智能·深度学习·yolo
新缸中之脑2 小时前
微调 BERT 实现命名实体识别
人工智能·深度学习·bert
向上的车轮2 小时前
飞桨PaddlePaddle:入门指南
人工智能·paddlepaddle
一招定胜负2 小时前
OpenCV实战:DNN风格迁移与CSRT物体追踪
人工智能·opencv·dnn
deng12042 小时前
【yolov1:开启目标检测的全新纪元】
人工智能·yolo·目标检测
宇擎智脑科技3 小时前
A2UI 技术原理深度解析:AI Agent 如何安全生成富交互 UI
人工智能·a2ui
kicikng3 小时前
智能体来了(西南总部)完整拆解:AI Agent 指挥官 + AI调度官架构图
大数据·人工智能·多智能体系统·ai agent指挥官·ai调度官