Rethinking BiSeNet For Real-time Semantic Segmentation细节损失函数学习

更多损失函数学习见:深度学习pytorch之22种损失函数数学公式和代码定义

本博客为文章"Rethinking BiSeNet For Real-time Semantic Segmentation"损失函数学习和使用

论文地址:https://arxiv.org/pdf/2104.13188

代码地址:https://github.com/MichaelFan01/STDC-Seg

概述

细节图像中,细节像素的数量远少于非细节像素,细节预测是一个类别不平衡问题。因为加权交叉熵总是导致粗糙的结果,文章采用二元交叉熵和Dice损失来联合优化细节学习,Dice损失度量了预测图与真实标签之间的重叠度,同时它对前景/背景像素的数量不敏感,这意味着它可以缓解类别不平衡问题。

损失代码

python 复制代码
class DiceLoss(nn.Module):
    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, labels):
        logits = torch.flatten(logits, 1)
        labels = torch.flatten(labels, 1)

        intersection = torch.sum(logits * labels, dim=1)
        loss = 1 - ((2 * intersection + self.smooth) / (logits.sum(1) + labels.sum(1) + self.smooth))

        return torch.mean(loss)
class DetailLoss(nn.Module):
    '''Implement detail loss used in paper
       `Rethinking BiSeNet For Real-time Semantic Segmentation`'''
    def __init__(self, dice_loss_coef=1., bce_loss_coef=1., smooth=1):
        super().__init__()
        self.dice_loss_coef = dice_loss_coef
        self.bce_loss_coef = bce_loss_coef
        self.dice_loss_fn = DiceLoss(smooth)
        self.bce_loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, logits, labels):
        loss = self.dice_loss_coef * self.dice_loss_fn(logits, labels) + \
               self.bce_loss_coef * self.bce_loss_fn(logits, labels)

        return loss

公式原理

细节损失函数,由 Dice 损失和二元交叉熵损失组成:
L detail ( p d , g d ) = L dice ( p d , g d ) + L bce ( p d , g d ) \text{L}{\text{detail}}(p_d, g_d) = \text{L}{\text{dice}}(p_d, g_d) + \text{L}_{\text{bce}}(p_d, g_d) Ldetail(pd,gd)=Ldice(pd,gd)+Lbce(pd,gd)

其中:

  • p d ∈ R H × W p_d \in \mathbb{R}^{H \times W} pd∈RH×W为预测的细节图;

  • g d ∈ R H × W g_d \in \mathbb{R}^{H \times W} gd∈RH×W为对应的真实细节图;

  • L bce L_{\text{bce}} Lbce是二元交叉熵损失;

  • L dice L_{\text{dice}} Ldice是 Dice 损失,定义如下。

Dice损失公式:
L dice ( p d , g d ) = 1 − 2 ∑ i = 1 H × W p d i g d i + ϵ ∑ i = 1 H × W ( p d i ) 2 + ∑ i = 1 H × W ( g d i ) 2 + ϵ \text{L}{\text{dice}}(p_d, g_d) = 1 - \frac{2 \sum{i=1}^{H \times W} p_d^i g_d^i + \epsilon}{\sum_{i=1}^{H \times W} (p_d^i)^2 + \sum_{i=1}^{H \times W} (g_d^i)^2 + \epsilon} Ldice(pd,gd)=1−∑i=1H×W(pdi)2+∑i=1H×W(gdi)2+ϵ2∑i=1H×Wpdigdi+ϵ

其中:

  • ϵ \epsilon ϵ 为平滑项,防止分母为零;

  • p d i p_d^i pdi和 g d i g_d^i gdi 分别表示预测图和真实图中第 i i i个像素的值;

  • 分子为预测与真实的逐像素乘积之和的 2 倍;

  • 分母为预测值的平方和与真实值的平方和之和。

使用说明

见博客 DIY损失函数--以自适应边界损失为例

相关推荐
深图智能10 天前
轻量级语义分割算法:演进与创新
算法·计算机视觉·语义分割
xidianjiapei00112 天前
一文读懂深度学习中的损失函数quantifying loss —— 作用、分类和示例代码
人工智能·深度学习·分类·损失函数·交叉熵
杀生丸学AI22 天前
【三维分割】LangSplat: 3D Language Gaussian Splatting(CVPR 2024 highlight)
人工智能·3d·大模型·aigc·svd·语义分割·视频生成
万里守约1 个月前
【论文阅读】SAM-CP:将SAM与组合提示结合起来的多功能分割
论文阅读·图像分割·多模态·语义分割·实例分割·图像大模型
丶21362 个月前
【分类】【损失函数】处理类别不平衡:CEFL 和 CEFL2 损失函数的实现与应用
人工智能·分类·损失函数
余胜辉2 个月前
【深度学习】交叉熵:从理论到实践
人工智能·深度学习·机器学习·损失函数·交叉熵
AICurator3 个月前
SAM2训练自己的数据集
深度学习·语义分割·sam2
chencjiajy3 个月前
机器学习基础:极大似然估计与交叉熵
深度学习·机器学习·损失函数
王亭_6663 个月前
深度学习中损失函数(loss function)介绍
人工智能·pytorch·深度学习·损失函数