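For more on loss functions, see: 22 Loss Functions in PyTorch for Deep Learning: Mathematical Formulas and Code Definitions
This post studies and applies the loss function from the paper "Rethinking BiSeNet For Real-time Semantic Segmentation".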
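Overview
In a detail map, detail pixels are far fewer than non-detail pixels, so detail prediction is a class-imbalance problem. Because weighted cross-entropy always leads to coarse results according to the paper, the paper jointly optimizes detail learning with binary cross-entropy and Dice loss. Dice loss measures the overlap between the predicted map and the ground truth, and it is insensitive to the number of foreground/background pixels, which means it can alleviate the class-imbalance problem.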
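The imbalance argument above can be illustrated with a small numeric sketch. The detail map here is synthetic (a single one-pixel edge), and the Dice term is written inline in its plain-sum form with a smoothing term of 1; both are assumptions made for illustration, not values from the paper.

```python
import torch
import torch.nn.functional as F

# Synthetic 32x32 detail map: one single-pixel-wide edge (32 of 1024 pixels).
labels = torch.zeros(1, 1, 32, 32)
labels[..., 16, :] = 1.0

# A degenerate prediction that calls every pixel "background".
logits = torch.full_like(labels, -4.0)
probs = torch.sigmoid(logits)

# BCE averages over all pixels, so the many easy background pixels dominate.
bce = F.binary_cross_entropy_with_logits(logits, labels)

# Dice depends on the overlap with the (small) foreground region.
intersection = (probs * labels).sum()
dice = 1 - (2 * intersection + 1) / (probs.sum() + labels.sum() + 1)

print(f"BCE:  {bce.item():.3f}")
print(f"Dice: {dice.item():.3f}")
```

Although this prediction misses every detail pixel, its BCE stays small, while its Dice loss is close to 1. That is the sense in which the Dice term keeps the sparse foreground from being ignored.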
Loss code
```python
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth  # smoothing term, avoids division by zero

    def forward(self, probs, labels):
        # probs: probabilities in [0, 1]; labels: binary targets of the same shape
        probs = torch.flatten(probs, 1)
        labels = torch.flatten(labels, 1)
        intersection = torch.sum(probs * labels, dim=1)
        loss = 1 - ((2 * intersection + self.smooth) / (probs.sum(1) + labels.sum(1) + self.smooth))
        return torch.mean(loss)


class DetailLoss(nn.Module):
    '''Implement detail loss used in paper
       `Rethinking BiSeNet For Real-time Semantic Segmentation`'''

    def __init__(self, dice_loss_coef=1., bce_loss_coef=1., smooth=1):
        super().__init__()
        self.dice_loss_coef = dice_loss_coef
        self.bce_loss_coef = bce_loss_coef
        self.dice_loss_fn = DiceLoss(smooth)
        self.bce_loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, logits, labels):
        # BCEWithLogitsLoss applies the sigmoid internally; Dice needs
        # probabilities, so the logits go through a sigmoid first.
        probs = torch.sigmoid(logits)
        loss = self.dice_loss_coef * self.dice_loss_fn(probs, labels) + \
            self.bce_loss_coef * self.bce_loss_fn(logits, labels)
        return loss
```
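A quick way to sanity-check a loss like this is to run it on synthetic data. The sketch below uses a compact functional equivalent so that it runs on its own; the name `detail_loss`, the tensor shapes, and the random "edge" labels are assumptions for illustration, and the Dice term is computed on sigmoid probabilities.

```python
import torch
import torch.nn.functional as F


def detail_loss(logits, labels, smooth=1.0):
    """Functional sketch of BCE + Dice on a batch of detail logits.

    Assumes float tensors of identical shape, e.g. (N, 1, H, W),
    with labels in {0, 1}.
    """
    bce = F.binary_cross_entropy_with_logits(logits, labels)
    # Dice is computed on probabilities, so squash the logits first.
    probs = torch.sigmoid(logits).flatten(1)
    target = labels.flatten(1)
    intersection = (probs * target).sum(dim=1)
    dice = 1 - (2 * intersection + smooth) / (probs.sum(dim=1) + target.sum(dim=1) + smooth)
    return bce + dice.mean()


torch.manual_seed(0)
labels = (torch.rand(2, 1, 64, 64) > 0.95).float()  # sparse synthetic "edges"
logits = torch.randn(2, 1, 64, 64, requires_grad=True)

loss = detail_loss(logits, labels)
loss.backward()  # gradients flow back to the logits

# A confident, correct prediction should score far lower than a random one.
good_logits = (labels * 2 - 1) * 10  # +10 on edges, -10 elsewhere
good_loss = detail_loss(good_logits, labels)
print(loss.item(), good_loss.item())
```

The random prediction should give a loss well above that of the confident, correct one, which should be close to zero.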
Formulas and principles
The detail loss consists of the Dice loss and the binary cross-entropy loss:
$$L_{\text{detail}}(p_d, g_d) = L_{\text{dice}}(p_d, g_d) + L_{\text{bce}}(p_d, g_d)$$
where:
- $p_d \in \mathbb{R}^{H \times W}$ is the predicted detail map;
- $g_d \in \mathbb{R}^{H \times W}$ is the corresponding ground-truth detail map;
- $L_{\text{bce}}$ is the binary cross-entropy loss;
- $L_{\text{dice}}$ is the Dice loss, defined below.
Dice loss formula:
$$L_{\text{dice}}(p_d, g_d) = 1 - \frac{2 \sum_{i=1}^{H \times W} p_d^i g_d^i + \epsilon}{\sum_{i=1}^{H \times W} (p_d^i)^2 + \sum_{i=1}^{H \times W} (g_d^i)^2 + \epsilon}$$
where:
- $\epsilon$ is a smoothing term that prevents division by zero;
- $p_d^i$ and $g_d^i$ are the values of the $i$-th pixel in the predicted map and the ground-truth map, respectively;
- the numerator is twice the sum of the pixel-wise products of prediction and ground truth, plus $\epsilon$;
- the denominator is the sum of the squared predictions plus the sum of the squared ground-truth values, plus $\epsilon$.
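Note that the formula squares both terms in the denominator, whereas the `DiceLoss` class in the code section sums the values without squaring. For binary labels and hard 0/1 predictions the two forms coincide; for soft predictions the squared form gives a smaller denominator and hence a lower loss. A minimal sketch comparing the two, with hypothetical function names and a hand-made 4-pixel example:

```python
import torch


def dice_loss_squared(probs, labels, eps=1.0):
    """Dice loss with squared terms in the denominator (as in the formula)."""
    probs, labels = probs.flatten(1), labels.flatten(1)
    intersection = (probs * labels).sum(dim=1)
    denominator = (probs ** 2).sum(dim=1) + (labels ** 2).sum(dim=1)
    return (1 - (2 * intersection + eps) / (denominator + eps)).mean()


def dice_loss_plain(probs, labels, eps=1.0):
    """Dice loss with plain sums in the denominator."""
    probs, labels = probs.flatten(1), labels.flatten(1)
    intersection = (probs * labels).sum(dim=1)
    denominator = probs.sum(dim=1) + labels.sum(dim=1)
    return (1 - (2 * intersection + eps) / (denominator + eps)).mean()


labels = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
probs = torch.tensor([[0.8, 0.6, 0.3, 0.1]])  # soft predictions

plain = dice_loss_plain(probs, labels)      # 1 - 3.8 / 4.8, about 0.2083
squared = dice_loss_squared(probs, labels)  # 1 - 3.8 / 4.1, about 0.0732
print(plain.item(), squared.item())

# With hard predictions equal to the labels, both variants reach zero.
print(dice_loss_plain(labels, labels).item(), dice_loss_squared(labels, labels).item())
```

Which variant to use is a design choice; the two are not interchangeable when comparing loss values across implementations.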