计算Dice损失的函数

计算Dice损失的函数

python 复制代码
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n,c, h, w = inputs.size()    #
    nt,ht, wt, ct = target.size()  #nt,
    
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss

这段代码是用于计算二分类问题的混淆矩阵(Confusion Matrix)中的True Positives(TP),False Positives(FP)和False Negatives(FN)。在混淆矩阵中,TP表示模型正确预测为正类的数量,FP表示模型错误地预测为正类的数量,FN表示实际为正类但模型没有预测为正类的数量。

让我们分解这段代码来理解每个部分的作用:

  1. temp_target[..., :-1] * temp_inputs

    • temp_target[..., :-1] 获取 temp_target 张量中除了最后一维之外的所有元素。:-1 是一个切片操作,它表示从开始到倒数第二个元素。
    • temp_inputs 是模型的预测输出。
    • 这两个张量进行元素相乘,只有当 temp_target 的最后一维等于 1 时,才会乘以 temp_inputs 对应的位置的值。这模拟了只有当预测和真实标签都为正类(1)时,才认为是真正的正类检测。
  2. torch.sum(..., axis=[0,1])

    • 这是一个求和操作,计算在指定维度上(这里是第0维和第1维)的总和。
    • axis=[0,1] 表示在第0维和第1维上进行求和。通常,第0维代表批量大小(batch size),第1维代表序列长度(sequence length)。
    • 这样做的效果是将所有正类预测的和(TP)汇总起来,无论它们在批量中的哪个位置或序列中。
  3. tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])

    • 最终,tp 保存了所有正类预测的数量。
  4. fp = torch.sum(temp_inputs, axis=[0,1]) - tp

    • torch.sum(temp_inputs, axis=[0,1]) 计算了所有预测为正类的数量,无论它们是否真的是正类。
    • 然后从中减去 tp,得到假正类的数量(FP),即模型错误地预测为正类的数量。
  5. fn = torch.sum(temp_target[...,:-1], axis=[0,1]) - tp

    • torch.sum(temp_target[...,:-1], axis=[0,1]) 计算了实际为正类的数量,无论模型是否预测它们为正类。
    • 然后从中减去 tp,得到假负类的数量(FN),即实际为正类但模型没有预测为正类的数量。

综上所述,这段代码通过计算TP、FP和FN,来评估模型在二分类任务中的性能。这些值是计算精确度(Precision)、召回率(Recall)和F1得分的关键。

相关推荐
小袁拒绝摆烂1 小时前
OpenCV-python灰度变化和直方图修正类型
python·opencv·计算机视觉
gogoMark4 小时前
口播视频怎么剪!利用AI提高口播视频剪辑效率并增强”网感”
人工智能·音视频
Dxy12393102164 小时前
Python 条件语句详解
开发语言·python
龙泉寺天下行走4 小时前
Python 翻译词典小程序
python·oracle·小程序
2201_754918414 小时前
OpenCV 特征检测全面解析与实战应用
人工智能·opencv·计算机视觉
践行见远5 小时前
django之视图
python·django·drf
love530love6 小时前
Windows避坑部署CosyVoice多语言大语言模型
人工智能·windows·python·语言模型·自然语言处理·pycharm
985小水博一枚呀6 小时前
【AI大模型学习路线】第二阶段之RAG基础与架构——第七章(【项目实战】基于RAG的PDF文档助手)技术方案与架构设计?
人工智能·学习·语言模型·架构·大模型
白熊1887 小时前
【图像生成大模型】Wan2.1:下一代开源大规模视频生成模型
人工智能·计算机视觉·开源·文生图·音视频
weixin_514548897 小时前
一种开源的高斯泼溅实现库——gsplat: An Open-Source Library for Gaussian Splatting
人工智能·计算机视觉·3d