深度学习| 交叉熵损失函数(包含代码实现)

前言:因为我深度学习主要用于图像分割,所以交叉熵损失函数主要侧重在图像分割。

交叉熵损失函数

介绍

交叉熵损失函数(Cross-Entropy Loss)是深度学习中常用的一种损失函数,特别是处理分类问题。该函数起源于信息论中的交叉熵概念,用于衡量两个概率分布间的差异,可以衡量预估概率分布和真实样品对应概率分布之间的差异。

从这个概念上理解会让人感觉很抽象,建议直接从公式来进行理解。

公式

二分类
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i l o g ( y ^ i ) + ( 1 − y i ) l o g ( 1 − y ^ i ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_ilog(\widehat{y}_i)+(1-y_i)log(1-\widehat{y}_i)] H(y,y )=−N1i=1∑N[yilog(y i)+(1−yi)log(1−y i)]

其中y是真实标签, y ^ \widehat{y} y 是预测值,N是样本的数量。每个样本都会计算一个损失,然后对所有样本的损失求平均。

对于图像来说,这里的N可以看作是图像像素点的个数, y ^ \widehat{y} y 是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。

多分类

多分类就是二分类的延申,理解的原理都是一样的。
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i 1 l o g ( y ^ i 1 ) + y i 2 l o g ( y ^ i 2 ) + . . . + y i m l o g ( y ^ i m ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_{i1}log(\widehat{y}{i1})+y{i2}log(\widehat{y}{i2})+...+y{im}log(\widehat{y}{im})] H(y,y )=−N1i=1∑N[yi1log(y i1)+yi2log(y i2)+...+yimlog(y im)]
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum
{i=1}^N\sum_{j=1}^m[y_{ij}log(\widehat{y}_{ij})] H(y,y )=−N1i=1∑Nj=1∑m[yijlog(y ij)]

这里的 y ^ \widehat{y} y 和y是one-hot编码目标向量,例如 y i = [ y i 1 , y i 2 , . . . , y i m ] y_i=[y_{i1},y_{i2},...,y_{im}] yi=[yi1,yi2,...,yim]。

交叉熵函数存在什么问题

之前我的一篇博客提过使用交叉熵函数面对类别不均衡的时候会出现问题,导致结果会偏向更常见的类别,对少类别的识别非常差。

产生这点的原因是因为交叉熵的特点就是"平等"地看待每一个样本,无论什么类别权重都是一样的。所以当正负样本不均衡时,大量简单的负样本会占据主导地位,少量的难样本和正样本就会分不出来。

带权重的交叉熵函数

由于交叉熵函数在应对类别不均衡会出现问题,于是就有了带权重的交叉熵函数。

带权重的交叉熵函数(Weighted cross entropy,WCE)会在计算交叉熵函数的时候,给不同类别前面加入一个权重。

公式
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ w j y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^m[w_{j} y_{ij}log(\widehat{y}_{ij})] H(y,y )=−N1i=1∑Nj=1∑m[wjyijlog(y ij)]

其中 w j w_j wj表示对j类别的权重,用于增大在预测图上占比例小的类别,公式如下:
w j = N − ∑ 1 N y ^ i j ∑ 1 N y ^ i j w_j= \frac{N-\sum_{1}^N\widehat{y}{ij}}{\sum{1}^N\widehat{y}_{ij}} wj=∑1Ny ijN−∑1Ny ij

补充:除了带权重的交叉熵函数能解决样本类别不均衡,还有DiceLoss和FocalLoss能用来解决。

代码

PyTorch的话有自带的库能解决:

python 复制代码
import torch
import torch.nn as nn

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss(weight, size_average)

    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

nn.CrossEntropyLoss(weight, size_average)

weight:可以指定一个一维的Tensor,用来设置每个类别的权重。用C表示类别的个数,Tensor的长度应该为C。

size_average:bool类型数据,默认情况下为True,此时损失是每个minibatch的平均;如果设置成False,则对每个minibatch求和。

相关推荐
正在走向自律11 分钟前
Trae上手指南:AI编程从0到1的奇妙跃迁
人工智能
MILI元宇宙15 分钟前
DeepSeek R1开源模型的技术突破与AI产业格局的重构
人工智能·重构·开源
江苏泊苏系统集成有限公司1 小时前
半导体晶圆制造洁净厂房的微振控制方案-江苏泊苏系统集成有限公司
人工智能·深度学习·目标检测·机器学习·创业创新·制造·远程工作
猿小猴子2 小时前
主流 AI IDE 之一的 Windsurf 介绍
ide·人工智能
智联视频超融合平台3 小时前
无人机+AI视频联网:精准狙击,让‘罪恶之花’无处藏身
人工智能·网络协议·安全·系统安全·音视频·无人机
AiTEN_Robotics3 小时前
智能仓储落地:机器人如何通过自动化减少仓库操作失误?
人工智能·机器人·自动化
江湖有缘4 小时前
华为云Flexus+DeepSeek征文 | 初探华为云ModelArts Studio:部署DeepSeek-V3/R1商用服务的详细步骤
人工智能·华为云·modelarts
Vizio<4 小时前
基于FashionMnist数据集的自监督学习(生成式自监督学习AE算法)
人工智能·笔记·深度学习·神经网络·自监督学习
梅一一4 小时前
5款AI对决:Gemini学术封神,但日常办公我选它
大数据·人工智能·数据可视化
kyle~4 小时前
Pytorch---ImageFolder
人工智能·pytorch·python