深度学习| 交叉熵损失函数(包含代码实现)

前言:因为我深度学习主要用于图像分割,所以交叉熵损失函数主要侧重在图像分割。

交叉熵损失函数

介绍

交叉熵损失函数(Cross-Entropy Loss)是深度学习中常用的一种损失函数,特别是处理分类问题。该函数起源于信息论中的交叉熵概念,用于衡量两个概率分布间的差异,可以衡量预估概率分布和真实样品对应概率分布之间的差异。

从这个概念上理解会让人感觉很抽象,建议直接从公式来进行理解。

公式

二分类
H ( y , y ^ ) = − 1 N ∑ i = 1 N y i l o g ( y \^ i ) + ( 1 − y i ) l o g ( 1 − y \^ i ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^Ny_ilog(\\widehat{y}_i)+(1-y_i)log(1-\\widehat{y}_i) H(y,y )=−N1i=1∑Nyilog(y i)+(1−yi)log(1−y i)

其中y是真实标签, y ^ \widehat{y} y 是预测值,N是样本的数量。每个样本都会计算一个损失,然后对所有样本的损失求平均。

对于图像来说,这里的N可以看作是图像像素点的个数, y ^ \widehat{y} y 是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。

多分类

多分类就是二分类的延申,理解的原理都是一样的。
H ( y , y ^ ) = − 1 N ∑ i = 1 N y i 1 l o g ( y \^ i 1 ) + y i 2 l o g ( y \^ i 2 ) + . . . + y i m l o g ( y \^ i m ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^Ny_{i1}log(\\widehat{y}_{i1})+y_{i2}log(\\widehat{y}_{i2})+...+y_{im}log(\\widehat{y}_{im}) H(y,y )=−N1i=1∑Nyi1log(y i1)+yi2log(y i2)+...+yimlog(y im)
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m y i j l o g ( y \^ i j ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^my_{ij}log(\\widehat{y}_{ij}) H(y,y )=−N1i=1∑Nj=1∑myijlog(y ij)

这里的 y ^ \widehat{y} y 和y是one-hot编码目标向量,例如 y i = y i 1 , y i 2 , . . . , y i m y_i=y_{i1},y_{i2},...,y_{im} yi=yi1,yi2,...,yim

交叉熵函数存在什么问题

之前我的一篇博客提过使用交叉熵函数面对类别不均衡的时候会出现问题,导致结果会偏向更常见的类别,对少类别的识别非常差。

产生这点的原因是因为交叉熵的特点就是"平等"地看待每一个样本,无论什么类别权重都是一样的。所以当正负样本不均衡时,大量简单的负样本会占据主导地位,少量的难样本和正样本就会分不出来。

带权重的交叉熵函数

由于交叉熵函数在应对类别不均衡会出现问题,于是就有了带权重的交叉熵函数。

带权重的交叉熵函数(Weighted cross entropy,WCE)会在计算交叉熵函数的时候,给不同类别前面加入一个权重。

公式
H ( y , y ^ ) = − 1 N ∑ i = 1 N ∑ j = 1 m w j y i j l o g ( y \^ i j ) H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N\sum_{j=1}^mw_{j} y_{ij}log(\\widehat{y}_{ij}) H(y,y )=−N1i=1∑Nj=1∑mwjyijlog(y ij)

其中 w j w_j wj表示对j类别的权重,用于增大在预测图上占比例小的类别,公式如下:
w j = N − ∑ 1 N y ^ i j ∑ 1 N y ^ i j w_j= \frac{N-\sum_{1}^N\widehat{y}{ij}}{\sum{1}^N\widehat{y}_{ij}} wj=∑1Ny ijN−∑1Ny ij

补充:除了带权重的交叉熵函数能解决样本类别不均衡,还有DiceLoss和FocalLoss能用来解决。

代码

PyTorch的话有自带的库能解决:

python 复制代码
import torch
import torch.nn as nn

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss(weight, size_average)

    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

nn.CrossEntropyLoss(weight, size_average)

weight:可以指定一个一维的Tensor,用来设置每个类别的权重。用C表示类别的个数,Tensor的长度应该为C。

size_average:bool类型数据,默认情况下为True,此时损失是每个minibatch的平均;如果设置成False,则对每个minibatch求和。

相关推荐
Lei活在当下1 天前
【AI手记系列-2026/6/18】iSparto & Harness,Caveman 以及AI时代的生存指南
人工智能·llm·openai
冬奇Lab1 天前
每日一个开源项目(第134篇):Zvec - 阿里开源的嵌入式向量数据库,向量搜索界的 SQLite
数据库·人工智能·llm
冬奇Lab1 天前
Agent 系列(22):Context Engineering 深度——三种上下文管理策略的量化对比
人工智能·agent
hboot1 天前
AI工程师第二课 - 数据处理
人工智能·python·数据分析
程序员cxuan1 天前
DeepSeek 杀入多模态,识图功能正式上线!
人工智能·后端·程序员
米小虾1 天前
告别单打独斗:2026年多Agent协作架构实战指南
人工智能·agent
IT_陈寒1 天前
SpringBoot这个自动配置坑我跳了三次
前端·人工智能·后端
Larcher1 天前
AI Loop:让AI像人一样自主完成任务的核心机制
javascript·人工智能·设计模式
牧艺1 天前
从零到协同:构建类飞书在线文档系统的五个技术重难点
前端·人工智能