深度学习| 交叉熵损失函数(包含代码实现)

前言:因为我深度学习主要用于图像分割,所以交叉熵损失函数主要侧重在图像分割。

交叉熵损失函数

介绍

交叉熵损失函数(Cross-Entropy Loss)是深度学习中常用的一种损失函数,特别是处理分类问题。该函数起源于信息论中的交叉熵概念,用于衡量两个概率分布间的差异,可以衡量预估概率分布和真实样品对应概率分布之间的差异。

从这个概念上理解会让人感觉很抽象,建议直接从公式来进行理解。

公式

二分类
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i l o g ( y ^ i ) + ( 1 − y i ) l o g ( 1 − y ^ i ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_ilog(\widehat{y}_i)+(1-y_i)log(1-\widehat{y}_i)] H(y,y )=−N1i=1∑N[yilog(y i)+(1−yi)log(1−y i)]

其中y是真实标签, y ^ \widehat{y} y 是预测值,N是样本的数量。每个样本都会计算一个损失,然后对所有样本的损失求平均。

对于图像来说,这里的N可以看作是图像像素点的个数, y ^ \widehat{y} y 是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。

多分类

多分类就是二分类的延申,理解的原理都是一样的。
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i 1 l o g ( y ^ i 1 ) + y i 2 l o g ( y ^ i 2 ) + . . . + y i m l o g ( y ^ i m ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_{i1}log(\widehat{y}{i1})+y{i2}log(\widehat{y}{i2})+...+y{im}log(\widehat{y}{im})] H(y,y )=−N1i=1∑N[yi1log(y i1)+yi2log(y i2)+...+yimlog(y im)]
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum
{i=1}^N\sum_{j=1}^m[y_{ij}log(\widehat{y}_{ij})] H(y,y )=−N1i=1∑Nj=1∑m[yijlog(y ij)]

这里的 y ^ \widehat{y} y 和y是one-hot编码目标向量,例如 y i = [ y i 1 , y i 2 , . . . , y i m ] y_i=[y_{i1},y_{i2},...,y_{im}] yi=[yi1,yi2,...,yim]。

交叉熵函数存在什么问题

之前我的一篇博客提过使用交叉熵函数面对类别不均衡的时候会出现问题,导致结果会偏向更常见的类别,对少类别的识别非常差。

产生这点的原因是因为交叉熵的特点就是"平等"地看待每一个样本,无论什么类别权重都是一样的。所以当正负样本不均衡时,大量简单的负样本会占据主导地位,少量的难样本和正样本就会分不出来。

带权重的交叉熵函数

由于交叉熵函数在应对类别不均衡会出现问题,于是就有了带权重的交叉熵函数。

带权重的交叉熵函数(Weighted cross entropy,WCE)会在计算交叉熵函数的时候,给不同类别前面加入一个权重。

公式
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ w j y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^m[w_{j} y_{ij}log(\widehat{y}_{ij})] H(y,y )=−N1i=1∑Nj=1∑m[wjyijlog(y ij)]

其中 w j w_j wj表示对j类别的权重,用于增大在预测图上占比例小的类别,公式如下:
w j = N − ∑ 1 N y ^ i j ∑ 1 N y ^ i j w_j= \frac{N-\sum_{1}^N\widehat{y}{ij}}{\sum{1}^N\widehat{y}_{ij}} wj=∑1Ny ijN−∑1Ny ij

补充:除了带权重的交叉熵函数能解决样本类别不均衡,还有DiceLoss和FocalLoss能用来解决。

代码

PyTorch的话有自带的库能解决:

python 复制代码
import torch
import torch.nn as nn

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss(weight, size_average)

    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

nn.CrossEntropyLoss(weight, size_average)

weight:可以指定一个一维的Tensor,用来设置每个类别的权重。用C表示类别的个数,Tensor的长度应该为C。

size_average:bool类型数据,默认情况下为True,此时损失是每个minibatch的平均;如果设置成False,则对每个minibatch求和。

相关推荐
eqwaak02 分钟前
量子计算与AI音乐——解锁无限可能的音色宇宙
人工智能·爬虫·python·自动化·量子计算
Blossom.1187 分钟前
量子计算与经典计算的融合与未来
人工智能·深度学习·机器学习·计算机视觉·量子计算
跳跳糖炒酸奶30 分钟前
第四章、Isaacsim在GUI中构建机器人(1): 添加简单对象
人工智能·python·ubuntu·机器人
猿饵块36 分钟前
机器人--ros2--IMU
人工智能
硅谷秋水37 分钟前
MoLe-VLA:通过混合层实现的动态跳层视觉-语言-动作模型实现高效机器人操作
人工智能·深度学习·机器学习·计算机视觉·语言模型·机器人
LS_learner38 分钟前
小智机器人关键函数解析,Application::OutputAudio()处理音频数据的输出的函数
人工智能·嵌入式硬件
2301_764441331 小时前
基于神经网络的肾脏疾病预测模型
人工智能·深度学习·神经网络
子燕若水1 小时前
用gpt-4o 生成图的教程和常用提示词
人工智能
weixin_442424031 小时前
Opencv计算机视觉编程攻略-第七节 提取直线、轮廓和区域
人工智能·opencv·计算机视觉
x-cmd1 小时前
[250401] OpenAI 向免费用户开放 GPT-4o 图像生成功能 | Neovim 0.11 新特性解读
人工智能·gpt·文生图·openai·命令行·neovim