Rethinking BiSeNet For Real-time Semantic Segmentation细节损失函数学习

更多损失函数学习见:深度学习pytorch之22种损失函数数学公式和代码定义

本博客为文章"Rethinking BiSeNet For Real-time Semantic Segmentation"损失函数学习和使用

论文地址:https://arxiv.org/pdf/2104.13188

代码地址:https://github.com/MichaelFan01/STDC-Seg

概述

细节图像中,细节像素的数量远少于非细节像素,细节预测是一个类别不平衡问题。因为加权交叉熵总是导致粗糙的结果,文章采用二元交叉熵和Dice损失来联合优化细节学习,Dice损失度量了预测图与真实标签之间的重叠度,同时它对前景/背景像素的数量不敏感,这意味着它可以缓解类别不平衡问题。

损失代码

python 复制代码
class DiceLoss(nn.Module):
    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, labels):
        logits = torch.flatten(logits, 1)
        labels = torch.flatten(labels, 1)

        intersection = torch.sum(logits * labels, dim=1)
        loss = 1 - ((2 * intersection + self.smooth) / (logits.sum(1) + labels.sum(1) + self.smooth))

        return torch.mean(loss)
class DetailLoss(nn.Module):
    '''Implement detail loss used in paper
       `Rethinking BiSeNet For Real-time Semantic Segmentation`'''
    def __init__(self, dice_loss_coef=1., bce_loss_coef=1., smooth=1):
        super().__init__()
        self.dice_loss_coef = dice_loss_coef
        self.bce_loss_coef = bce_loss_coef
        self.dice_loss_fn = DiceLoss(smooth)
        self.bce_loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, logits, labels):
        loss = self.dice_loss_coef * self.dice_loss_fn(logits, labels) + \
               self.bce_loss_coef * self.bce_loss_fn(logits, labels)

        return loss

公式原理

细节损失函数,由 Dice 损失和二元交叉熵损失组成:
L detail ( p d , g d ) = L dice ( p d , g d ) + L bce ( p d , g d ) \text{L}{\text{detail}}(p_d, g_d) = \text{L}{\text{dice}}(p_d, g_d) + \text{L}_{\text{bce}}(p_d, g_d) Ldetail(pd,gd)=Ldice(pd,gd)+Lbce(pd,gd)

其中:

  • p d ∈ R H × W p_d \in \mathbb{R}^{H \times W} pd∈RH×W为预测的细节图;

  • g d ∈ R H × W g_d \in \mathbb{R}^{H \times W} gd∈RH×W为对应的真实细节图;

  • L bce L_{\text{bce}} Lbce是二元交叉熵损失;

  • L dice L_{\text{dice}} Ldice是 Dice 损失,定义如下。

Dice损失公式:
L dice ( p d , g d ) = 1 − 2 ∑ i = 1 H × W p d i g d i + ϵ ∑ i = 1 H × W ( p d i ) 2 + ∑ i = 1 H × W ( g d i ) 2 + ϵ \text{L}{\text{dice}}(p_d, g_d) = 1 - \frac{2 \sum{i=1}^{H \times W} p_d^i g_d^i + \epsilon}{\sum_{i=1}^{H \times W} (p_d^i)^2 + \sum_{i=1}^{H \times W} (g_d^i)^2 + \epsilon} Ldice(pd,gd)=1−∑i=1H×W(pdi)2+∑i=1H×W(gdi)2+ϵ2∑i=1H×Wpdigdi+ϵ

其中:

  • ϵ \epsilon ϵ 为平滑项,防止分母为零;

  • p d i p_d^i pdi和 g d i g_d^i gdi 分别表示预测图和真实图中第 i i i个像素的值;

  • 分子为预测与真实的逐像素乘积之和的 2 倍;

  • 分母为预测值的平方和与真实值的平方和之和。

使用说明

见博客 DIY损失函数--以自适应边界损失为例

相关推荐
aitoolhub6 天前
AI在线设计中的Prompt技巧:如何让输出更精准
人工智能·计算机视觉·prompt·aigc·语义分割·设计语言
这张生成的图像能检测吗6 天前
(论文速读)CCASeg:基于卷积交叉注意的语义分割多尺度上下文解码
人工智能·深度学习·计算机视觉·语义分割
一瞬祈望9 天前
⭐ 深度学习入门体系(第 7 篇): 什么是损失函数?
人工智能·深度学习·cnn·损失函数
这张生成的图像能检测吗18 天前
(论文速读)RoShuNet:一个轻量级的基于卷积神经网络的可见图像特征提取器
人工智能·深度学习·计算机视觉·语义分割·目标追踪·分类模型
oliveray19 天前
基于 OpenVINO 优化的 GroundingDINO + EfficientSAM 视频分割追踪
人工智能·目标检测·语义分割·openvino
怎么全是重名25 天前
Stacked U-Nets: A No-Frills Approach to Natural Image Segmentation
深度学习·神经网络·计算机视觉·语义分割
donkey_19931 个月前
ShiftwiseConv: Small Convolutional Kernel with Large Kernel Effect
人工智能·深度学习·目标检测·计算机视觉·语义分割·实例分割
猛码Memmat1 个月前
SAM 3: Segment Anything with Concepts
计算机视觉·sam·语义分割
最晚的py1 个月前
正规方程法
损失函数·正规方程法
最晚的py1 个月前
机器学习--损失函数
人工智能·python·机器学习·损失函数