深度学习| 交叉熵损失函数(包含代码实现)

前言:因为我深度学习主要用于图像分割,所以交叉熵损失函数主要侧重在图像分割。

交叉熵损失函数

介绍

交叉熵损失函数(Cross-Entropy Loss)是深度学习中常用的一种损失函数,特别是处理分类问题。该函数起源于信息论中的交叉熵概念,用于衡量两个概率分布间的差异,可以衡量预估概率分布和真实样品对应概率分布之间的差异。

从这个概念上理解会让人感觉很抽象,建议直接从公式来进行理解。

公式

二分类
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i l o g ( y ^ i ) + ( 1 − y i ) l o g ( 1 − y ^ i ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_ilog(\widehat{y}_i)+(1-y_i)log(1-\widehat{y}_i)] H(y,y )=−N1i=1∑N[yilog(y i)+(1−yi)log(1−y i)]

其中y是真实标签, y ^ \widehat{y} y 是预测值,N是样本的数量。每个样本都会计算一个损失,然后对所有样本的损失求平均。

对于图像来说,这里的N可以看作是图像像素点的个数, y ^ \widehat{y} y 是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。

多分类

多分类就是二分类的延申,理解的原理都是一样的。
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i 1 l o g ( y ^ i 1 ) + y i 2 l o g ( y ^ i 2 ) + . . . + y i m l o g ( y ^ i m ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_{i1}log(\widehat{y}{i1})+y{i2}log(\widehat{y}{i2})+...+y{im}log(\widehat{y}{im})] H(y,y )=−N1i=1∑N[yi1log(y i1)+yi2log(y i2)+...+yimlog(y im)]
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum
{i=1}^N\sum_{j=1}^m[y_{ij}log(\widehat{y}_{ij})] H(y,y )=−N1i=1∑Nj=1∑m[yijlog(y ij)]

这里的 y ^ \widehat{y} y 和y是one-hot编码目标向量,例如 y i = [ y i 1 , y i 2 , . . . , y i m ] y_i=[y_{i1},y_{i2},...,y_{im}] yi=[yi1,yi2,...,yim]。

交叉熵函数存在什么问题

之前我的一篇博客提过使用交叉熵函数面对类别不均衡的时候会出现问题,导致结果会偏向更常见的类别,对少类别的识别非常差。

产生这点的原因是因为交叉熵的特点就是"平等"地看待每一个样本,无论什么类别权重都是一样的。所以当正负样本不均衡时,大量简单的负样本会占据主导地位,少量的难样本和正样本就会分不出来。

带权重的交叉熵函数

由于交叉熵函数在应对类别不均衡会出现问题,于是就有了带权重的交叉熵函数。

带权重的交叉熵函数(Weighted cross entropy,WCE)会在计算交叉熵函数的时候,给不同类别前面加入一个权重。

公式
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m [ w j y i j l o g ( y ^ i j ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^m[w_{j} y_{ij}log(\widehat{y}_{ij})] H(y,y )=−N1i=1∑Nj=1∑m[wjyijlog(y ij)]

其中 w j w_j wj表示对j类别的权重,用于增大在预测图上占比例小的类别,公式如下:
w j = N − ∑ 1 N y ^ i j ∑ 1 N y ^ i j w_j= \frac{N-\sum_{1}^N\widehat{y}{ij}}{\sum{1}^N\widehat{y}_{ij}} wj=∑1Ny ijN−∑1Ny ij

补充:除了带权重的交叉熵函数能解决样本类别不均衡,还有DiceLoss和FocalLoss能用来解决。

代码

PyTorch的话有自带的库能解决:

python 复制代码
import torch
import torch.nn as nn

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss(weight, size_average)

    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

nn.CrossEntropyLoss(weight, size_average)

weight:可以指定一个一维的Tensor,用来设置每个类别的权重。用C表示类别的个数,Tensor的长度应该为C。

size_average:bool类型数据,默认情况下为True,此时损失是每个minibatch的平均;如果设置成False,则对每个minibatch求和。

相关推荐
泰迪智能科技0124 分钟前
高校深度学习视觉应用平台产品介绍
人工智能·深度学习
盛派网络小助手1 小时前
微信 SDK 更新 Sample,NCF 文档和模板更新,更多更新日志,欢迎解锁
开发语言·人工智能·后端·架构·c#
Eric.Lee20211 小时前
Paddle OCR 中英文检测识别 - python 实现
人工智能·opencv·计算机视觉·ocr检测
cd_farsight1 小时前
nlp初学者怎么入门?需要学习哪些?
人工智能·自然语言处理
AI明说1 小时前
评估大语言模型在药物基因组学问答任务中的表现:PGxQA
人工智能·语言模型·自然语言处理·数智药师·数智药学
Focus_Liu1 小时前
NLP-UIE(Universal Information Extraction)
人工智能·自然语言处理
PowerBI学谦2 小时前
使用copilot轻松将电子邮件转为高效会议
人工智能·copilot
audyxiao0012 小时前
AI一周重要会议和活动概览
人工智能·计算机视觉·数据挖掘·多模态
Jeremy_lf2 小时前
【生成模型之三】ControlNet & Latent Diffusion Models论文详解
人工智能·深度学习·stable diffusion·aigc·扩散模型