pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型

2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction='mean',label_smoothing=0.0)

增加加权的CE_loss代码实现

python 复制代码
# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None):
      super(CrossEntropyLoss2d, self).__init__()
       self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')
    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

语义分割类别计算

python 复制代码
class CE_w_loss(nn.Module):
    def __init__(self,ignore_index=255):
        super(CE_w_loss, self).__init__()
        self.ignore_index = ignore_index
        # self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
    def forward(self, outputs, targets):
        class_num = outputs.shape[1]
        # print("class_num :",class_num )
        # # 计算每个类别在整个 batch 中的像素数占比
        class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别
        class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)
        # # 根据类别占比计算权重
        class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重
        # # print("class_weights :",class_weights)
        #
        # 定义交叉熵损失函数,并使用动态计算的类别权重
        criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)

        # 计算损失
        loss = criterion(outputs, targets)
        print(loss.item())  # 打印损失值
        return loss

    np.random.seed(666)
    pred = np.ones((2, 5, 256,256))
    seg = np.ones((2, 5, 256, 256)) # 灰度
    label = np.ones((2, 256, 256))  # 灰度

    pred = torch.from_numpy(pred)
    seg = torch.from_numpy(seg).int()  # 灰度
    label = torch.from_numpy(label).long()
     ce = CE_w_loss()
    loss = ce(pred, label)
    print("loss:",loss.item())

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226

2\] \[3\]

相关推荐
才思喷涌的小书虫3 分钟前
打破 3D 感知瓶颈:OVSeg3R 如何推动开集 3D 实例分割应用落地
人工智能·目标检测·计算机视觉·3d·具身智能·数据标注·图像标注
invicinble8 分钟前
对于后端要和linux打交道要掌握的点
linux·运维·python
喵手10 分钟前
Python爬虫零基础入门【第三章:Requests 静态爬取入门·第4节】列表页→详情页:两段式采集(90%项目都这样)!
爬虫·python·python爬虫实战·python爬虫工程化实战·python爬虫零基础入门·requests静态爬取·两段式采集
言之。11 分钟前
2026 年 1 月 15 日 - 21 日国内外 AI 科技大事及热点深度整理报告
人工智能·科技
zzZ··*12 分钟前
自动登录上海大学校园
python·网络协议·selenium
weisian15113 分钟前
进阶篇-4-数学篇-3--深度解析AI中的向量概念:从生活到代码,一文吃透核心逻辑
人工智能·python·生活·向量
这儿有一堆花13 分钟前
AI视频生成的底层逻辑与技术架构
人工智能·音视频
写代码的【黑咖啡】14 分钟前
Python中的Msgpack:高效二进制序列化库
开发语言·python
Fairy要carry15 分钟前
面试-Encoder-Decoder预训练思路
人工智能
杭州泽沃电子科技有限公司15 分钟前
“不速之客”的威胁:在线监测如何筑起抵御小动物的智能防线
人工智能·在线监测