SemScalLoss,在论文《SurroundOcc: Multi-Camera 3D Occupancy Prediction for Autonomous Driving》中使用到。它可以用来在数据不平衡的情况下,对语义分割中的recall、precision、specificity中进行控制。
本文将详细介绍SemScalLoss的基本概念、思想原理,并提供PyTorch的实现代码,帮助大家去更好的理解和使用。
分类模型指标
对于分类问题,我们通常可以这样描述预测值:
-
TP:预测positive,预测对了
-
TN:预测negative,预测对了
-
FP:预测positive,预测错了
-
FN:预测negative,预测错了
相对应的,recall(召回率)、precision(精确度)、specificity(特异性)是常用来评估分类模型性能的指标。它们的定义如下:
r e c a l l = T P T P + F N p r e c i s i o n = T P T P + F P s p e c i f i c i t y = T N T N + F P \begin{aligned}recall&=\frac{TP}{TP+FN}\\precision&=\frac{TP}{TP+FP}\\ specificity&=\frac{TN}{TN+FP}\end{aligned} recallprecisionspecificity=TP+FNTP=TP+FPTP=TN+FPTN
其中:
-
召回率是正类预测正确的比例,它关注的是模型捕捉到的正样本的能力,召回率越高,模型漏掉的正样本越少。
-
精确度是预测为正的样本中实际为正的比例,它关注的是模型预测的准确性,精确度越高,模型预测为正的样本中实际为负的样本越少。
-
特异性是负类预测正确的比例,它关注的是模型正确识别负样本的能力,特异性越高,模型误将负样本预测为正的能力越低。
可以看出,这三个指标的取值范围都是0到1之间,并且都是越靠近1,效果越好。
用集合的方式进行表示:
r e c a l l = ∣ X ⋂ Y ∣ ∣ X ∣ p r e c i s i o n = ∣ X ⋂ Y ∣ ∣ Y ∣ s p e c i f i c i t y = ∣ X ˉ ⋂ Y ˉ ∣ ∣ X ˉ ∣ \begin{aligned}recall&=\frac{|X\bigcap Y|}{|X|}\\precision&=\frac{|X\bigcap Y|}{|Y|}\\ specificity&=\frac{|\bar X\bigcap \bar Y|}{|\bar X|}\end{aligned} recallprecisionspecificity=∣X∣∣X⋂Y∣=∣Y∣∣X⋂Y∣=∣Xˉ∣∣Xˉ⋂Yˉ∣
其中, ∣ X ⋂ Y ∣ |X\bigcap Y| ∣X⋂Y∣表示交集的元素个数, ∣ X ∣ |X| ∣X∣表示 X X X的元素个数, ∣ Y ∣ |Y| ∣Y∣表示 Y Y Y的元素个数, X ˉ \bar X Xˉ表示非 X X X的元素个数, Y ˉ \bar Y Yˉ表示非 Y Y Y的元素个数。
SemScalLoss的基本概念
我们可以定义一个Loss:
S e m S c a l L o s s = α ⋅ ( 1 − r e c a l l ) + β ⋅ ( 1 − p r e c i s i o n ) + γ ⋅ ( 1 − s p e c i f i c i t y ) = α ⋅ ( 1 − ∣ X ⋂ Y ∣ ∣ X ∣ ) + β ⋅ ( 1 − ∣ X ⋂ Y ∣ ∣ Y ∣ ) + γ ⋅ ( 1 − ∣ X ˉ ⋂ Y ˉ ∣ ∣ X ˉ ∣ ) \begin{aligned}SemScalLoss&=\alpha \cdot (1-recall) + \beta \cdot (1- precision) + \gamma \cdot (1-specificity)\\&=\alpha \cdot (1-\frac{|X\bigcap Y|}{|X|}) + \beta \cdot (1- \frac{|X\bigcap Y|}{|Y|}) + \gamma \cdot (1-\frac{|\bar X\bigcap \bar Y|}{|\bar X|})\end{aligned} SemScalLoss=α⋅(1−recall)+β⋅(1−precision)+γ⋅(1−specificity)=α⋅(1−∣X∣∣X⋂Y∣)+β⋅(1−∣Y∣∣X⋂Y∣)+γ⋅(1−∣Xˉ∣∣Xˉ⋂Yˉ∣)
当然,此时可以看到recall、precision、specificity都是0到1之间的, 1 − 1- 1−操作之后,范围还都是0到1之间。这样会没有太多的区分度。如果为了增加区分度,可以使用下面的这种写法:
S e m S c a l L o s s = α ⋅ B C E ( r e c a l l ) + β ⋅ B C E ( p r e c i s i o n ) + γ ⋅ B C E ( s p e c i f i c i t y ) = α ⋅ B C E ( ∣ X ⋂ Y ∣ ∣ X ∣ ) + β ⋅ B C E ( ∣ X ⋂ Y ∣ ∣ Y ∣ ) + γ ⋅ B C E ( ∣ X ˉ ⋂ Y ˉ ∣ ∣ X ˉ ∣ ) \begin{aligned}SemScalLoss&=\alpha \cdot BCE(recall) + \beta \cdot BCE(precision) + \gamma \cdot BCE(specificity)\\&=\alpha \cdot BCE(\frac{|X\bigcap Y|}{|X|}) + \beta \cdot BCE(\frac{|X\bigcap Y|}{|Y|}) + \gamma \cdot BCE(\frac{|\bar X\bigcap \bar Y|}{|\bar X|})\end{aligned} SemScalLoss=α⋅BCE(recall)+β⋅BCE(precision)+γ⋅BCE(specificity)=α⋅BCE(∣X∣∣X⋂Y∣)+β⋅BCE(∣Y∣∣X⋂Y∣)+γ⋅BCE(∣Xˉ∣∣Xˉ⋂Yˉ∣)
由于BCE的计算方式是:torch.nn.functional.binary_cross_entropy,即:
B C E = − l n ( . . . ) BCE=-ln(...) BCE=−ln(...)
当自变量是0到1时,其范围为正无穷到0,与Loss的定义也是符合预期的。本文采用下面的这种方式。
对BCE不太了解的,可以参考博文:<>。
但是,我们可以看到,公式里面的 ∣ . . . ∣ |...| ∣...∣被理解成元素个数,这导致它是离散的。我们需要将其进行连续化。
我们可以仿照Dice、DiceLoss的方式获得连续性。对DiceLoss不熟悉的,可以参考博文:<>。
这里讨论一个通用的计算方式,并且保证连续性,定义:
∣ X ⋂ Y ∣ = ∑ i = 1 N t i m i ∣ Y ∣ = ∑ i = 1 N t i ∣ X ∣ = ∑ i = 1 N m i ∣ X ˉ ⋂ Y ˉ ∣ = ∑ i = 1 N ( 1 − t i ) ( 1 − m i ) ∣ X ˉ ∣ = ∑ i = 1 N ( 1 − m i ) \begin{aligned}|X\bigcap Y|&=\sum_{i=1}^N t_im_i\\|Y|&=\sum_{i=1}^Nt_i\\|X|&=\sum_{i=1}^Nm_i\\|\bar X\bigcap \bar Y|&=\sum_{i=1}^N(1-t_i)(1-m_i)\\|\bar X|&=\sum_{i=1}^N(1-m_i)\end{aligned} ∣X⋂Y∣∣Y∣∣X∣∣Xˉ⋂Yˉ∣∣Xˉ∣=i=1∑Ntimi=i=1∑Nti=i=1∑Nmi=i=1∑N(1−ti)(1−mi)=i=1∑N(1−mi)
其中, m i m_i mi是模型预测值,是经过sigmoid或者softmax之后的结果,取值在0到1之间; t i t_i ti是target标签,是经过one hot编码后的结果,取值非0即1。
通过这种计算方式,以模型输出的概率值进行替代计算,从而使得SemScalLoss获得连续性。
关于 α \alpha α、 β \beta β、 γ \gamma γ的取值,它们是用来平衡recall、precision、specificity之间的关系的。比如:如果增加 α \alpha α,将使得模型更加关注于recall。同时,也不一定需要 γ \gamma γ参数,一切都需要按照实际情况进行设置。
代码实战
多分类问题:采用one-hot编码
python
class SemScalLoss(nn.Module):
"""
多分类SemScalLoss
"""
def __init__(self):
super(SemScalLoss, self).__init__()
def forward(self, pred, target, precision_weights=None, recall_weights=None):
"""
pred: 模型的输出, 未经过 Softmax, 形状为 [B, C, P] (批次大小、类别数、数据大小)
target: 标签, 形状为 [B, P], 取值范围为0到C-1
"""
# Get softmax probabilities
pred = F.softmax(pred, dim=1)
loss = torch.tensor(0, device=pred.device, dtype=torch.float32)
count = 0
eps = 1e-5
n_classes = pred.shape[1]
if precision_weights is not None:
assert (len(precision_weights) == n_classes), "num recall weights and num class mismatch!"
if recall_weights is not None:
assert (len(recall_weights) == n_classes), "num precision weights and num class mismatch!"
for i in range(0, n_classes):
p = pred[:, i]
completion_target = torch.ones_like(target)
completion_target[target != i] = 0
if torch.sum(completion_target) > 0:
count += 1.0
nominator = torch.sum(p * completion_target)
loss_class = 0
if torch.sum(p) > 0:
precision = nominator / (torch.sum(p))
loss_precision = F.binary_cross_entropy(precision, torch.ones_like(precision))
if precision_weights is not None:
loss_precision = loss_precision * precision_weights[i]
loss_class += loss_precision
if torch.sum(completion_target) > 0:
recall = nominator / (torch.sum(completion_target))
loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall))
if recall_weights is not None:
loss_recall = loss_recall * recall_weights[i]
loss_class += loss_recall
if torch.sum(1 - completion_target) > 0:
specificity = torch.sum((1 - p) * (1 - completion_target)) / (torch.sum((1 - completion_target)))
loss_specificity = F.binary_cross_entropy(specificity, torch.ones_like(specificity))
loss_class += loss_specificity
loss += loss_class
return loss / (count + eps)
需要注意的是,这边的代码写法应该与输入的shape、参数的shape有关系的,需要针对于具体的情形进行适配和修改。