深度学习| 交叉熵损失函数(包含代码实现)

前言:因为我深度学习主要用于图像分割,所以交叉熵损失函数主要侧重在图像分割。

交叉熵损失函数

介绍

交叉熵损失函数(Cross-Entropy Loss)是深度学习中常用的一种损失函数,特别是处理分类问题。该函数起源于信息论中的交叉熵概念,用于衡量两个概率分布间的差异,可以衡量预估概率分布和真实样品对应概率分布之间的差异。

从这个概念上理解会让人感觉很抽象,建议直接从公式来进行理解。

公式

二分类
H ( y , y ^ ) = − 1 N ∑ i = 1 N y i l o g ( y \^ i ) + ( 1 − y i ) l o g ( 1 − y \^ i ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^Ny_ilog(\\widehat{y}_i)+(1-y_i)log(1-\\widehat{y}_i) H(y,y )=−N1i=1∑Nyilog(y i)+(1−yi)log(1−y i)

其中y是真实标签, y ^ \widehat{y} y 是预测值,N是样本的数量。每个样本都会计算一个损失,然后对所有样本的损失求平均。

对于图像来说,这里的N可以看作是图像像素点的个数, y ^ \widehat{y} y 是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。

多分类

多分类就是二分类的延申,理解的原理都是一样的。
H ( y , y ^ ) = − 1 N ∑ i = 1 N y i 1 l o g ( y \^ i 1 ) + y i 2 l o g ( y \^ i 2 ) + . . . + y i m l o g ( y \^ i m ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^Ny_{i1}log(\\widehat{y}_{i1})+y_{i2}log(\\widehat{y}_{i2})+...+y_{im}log(\\widehat{y}_{im}) H(y,y )=−N1i=1∑Nyi1log(y i1)+yi2log(y i2)+...+yimlog(y im)
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m y i j l o g ( y \^ i j ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^my_{ij}log(\\widehat{y}_{ij}) H(y,y )=−N1i=1∑Nj=1∑myijlog(y ij)

这里的 y ^ \widehat{y} y 和y是one-hot编码目标向量,例如 y i = y i 1 , y i 2 , . . . , y i m y_i=y_{i1},y_{i2},...,y_{im} yi=yi1,yi2,...,yim

交叉熵函数存在什么问题

之前我的一篇博客提过使用交叉熵函数面对类别不均衡的时候会出现问题,导致结果会偏向更常见的类别,对少类别的识别非常差。

产生这点的原因是因为交叉熵的特点就是"平等"地看待每一个样本,无论什么类别权重都是一样的。所以当正负样本不均衡时,大量简单的负样本会占据主导地位,少量的难样本和正样本就会分不出来。

带权重的交叉熵函数

由于交叉熵函数在应对类别不均衡会出现问题,于是就有了带权重的交叉熵函数。

带权重的交叉熵函数(Weighted cross entropy,WCE)会在计算交叉熵函数的时候,给不同类别前面加入一个权重。

公式
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m w j y i j l o g ( y \^ i j ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^mw_{j} y_{ij}log(\\widehat{y}_{ij}) H(y,y )=−N1i=1∑Nj=1∑mwjyijlog(y ij)

其中 w j w_j wj表示对j类别的权重,用于增大在预测图上占比例小的类别,公式如下:
w j = N − ∑ 1 N y ^ i j ∑ 1 N y ^ i j w_j= \frac{N-\sum_{1}^N\widehat{y}{ij}}{\sum{1}^N\widehat{y}_{ij}} wj=∑1Ny ijN−∑1Ny ij

补充:除了带权重的交叉熵函数能解决样本类别不均衡,还有DiceLoss和FocalLoss能用来解决。

代码

PyTorch的话有自带的库能解决:

python 复制代码
import torch
import torch.nn as nn

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss(weight, size_average)

    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

nn.CrossEntropyLoss(weight, size_average)

weight:可以指定一个一维的Tensor,用来设置每个类别的权重。用C表示类别的个数,Tensor的长度应该为C。

size_average:bool类型数据,默认情况下为True,此时损失是每个minibatch的平均;如果设置成False,则对每个minibatch求和。

相关推荐
Raink老师4 小时前
【AI面试临阵磨枪-79】实时数据 RAG:订单、商家、物流、天气、动态库存
人工智能·面试·职场和发展
脑极体5 小时前
点亮星河AI+鸿蒙,一座艺术场馆的日神觉醒
人工智能·华为·harmonyos
Cosolar5 小时前
Chroma向量库面试学习指南
数据库·人工智能·面试·职场和发展·数据库架构
BUG指挥官5 小时前
Claude Code的自动化编程
人工智能
意图共鸣5 小时前
意图共鸣科技《认知智能白皮书》——感知与执行分离:认知架构(CA)如何重塑大模型底层结构
人工智能·架构
等一个人的@5 小时前
让数据自己开口:数睿通智库新增智能问数模块
人工智能·自然语言处理
ZGi.ai5 小时前
人工审查节点:让自动化工作流多一步人工把关
运维·人工智能·自动化·人机协同·智能体工作流·人工审查
王莎莎-MinerU6 小时前
MinerU 深度技术解析:从架构原理到生产部署的全面指南
css·人工智能·自然语言处理·架构·ocr·个人开发
盘古信息IMS6 小时前
盘古信息IMS V6 8.0重磅发布:以薪火AI数智平台点燃离散制造数智化引擎
大数据·人工智能·制造
weilaieqi16 小时前
从音响制造到AI家庭娱乐生态:不见不散AI智能K歌音响亮相第二十届深圳国际金融博览会
人工智能·制造·娱乐