pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型

2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction='mean',label_smoothing=0.0)

增加加权的CE_loss代码实现

python 复制代码
# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None):
      super(CrossEntropyLoss2d, self).__init__()
       self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')
    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

语义分割类别计算

python 复制代码
class CE_w_loss(nn.Module):
    def __init__(self,ignore_index=255):
        super(CE_w_loss, self).__init__()
        self.ignore_index = ignore_index
        # self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
    def forward(self, outputs, targets):
        class_num = outputs.shape[1]
        # print("class_num :",class_num )
        # # 计算每个类别在整个 batch 中的像素数占比
        class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别
        class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)
        # # 根据类别占比计算权重
        class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重
        # # print("class_weights :",class_weights)
        #
        # 定义交叉熵损失函数,并使用动态计算的类别权重
        criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)

        # 计算损失
        loss = criterion(outputs, targets)
        print(loss.item())  # 打印损失值
        return loss

    np.random.seed(666)
    pred = np.ones((2, 5, 256,256))
    seg = np.ones((2, 5, 256, 256)) # 灰度
    label = np.ones((2, 256, 256))  # 灰度

    pred = torch.from_numpy(pred)
    seg = torch.from_numpy(seg).int()  # 灰度
    label = torch.from_numpy(label).long()
     ce = CE_w_loss()
    loss = ce(pred, label)
    print("loss:",loss.item())

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226

2\] \[3\]

相关推荐
XINVRY-FPGA1 分钟前
XCVU47P-2FSVH2892E Xilinx Virtex UltraScale+ FPGA AMD
c语言·c++·人工智能·嵌入式硬件·阿里云·fpga开发·fpga
___波子 Pro Max.1 小时前
python list去重
python·list
Ai财富密码2 小时前
机器学习 (ML) 基础入门指南
人工智能·神经网络·机器学习·机器人·ml
华科易迅3 小时前
人工智能学习38-VGG训练
人工智能·学习·人工智能学习38-vgg训练
狐凄3 小时前
Python实例题:基于边缘计算的智能物联网系统
python·物联网·边缘计算
m0_537437573 小时前
【深度学习基础与概念】笔记(一)深度学习革命
人工智能·笔记·深度学习
@十八子德月生3 小时前
第十章——8天Python从入门到精通【itheima】-99~101-Python基础综合案例-数据可视化(案例介绍=JSON格式+pyecharts简介)
大数据·python·信息可视化·pycharm·echarts·数据可视化
W说编程3 小时前
算法导论第二十四章 深度学习前沿:从序列建模到创造式AI
c语言·人工智能·python·深度学习·算法·性能优化
动能小子ohhh4 小时前
html实现登录与注册功能案例(不写死且只使用js)
开发语言·前端·javascript·python·html
hao_wujing4 小时前
RNN工作原理和架构
人工智能