深度学习-交叉熵

交叉熵(Cross-Entropy) 是信息论中的一个核心概念,在深度学习中,它是最常用、最重要的损失函数 之一,尤其擅长处理分类问题

简单理解,交叉熵可以用来衡量两个概率分布之间的差异 。在模型训练中,它衡量的就是:模型预测的概率分布 ,与真实的概率分布(通常是One-hot编码的标签)之间的差距。

  • 预测越准,交叉熵损失越小。
  • 预测越离谱,交叉熵损失越大。

1. 从公式理解

对于单个样本,交叉熵损失的公式如下:

\\text{CrossEntropy} = -\\sum_{i} y_i \\log(p_i)

  • ( y_i ) :表示第 ( i ) 个类别的真实标签。在分类任务中,真实类别为1,其他为0。
  • ( p_i ) :表示模型预测样本属于第 ( i ) 个类别的概率,取值范围在0到1之间。
  • ( \log ):是自然对数。

由于 ( y_i ) 只在真实类别(比如第 ( c ) 类)上为1,其他全为0,所以这个公式可以简化为:

\\text{Loss} = -\\log(p_c)

这个简化公式非常直观地说明了交叉熵的工作原理 :损失的大小,完全由模型给正确类别预测的概率 ( p_c ) 决定。

  • 当模型预测正确类别的概率 ( p_c = 1 ) 时,( -\log(1) = 0 ),损失为0。
  • 当 ( p_c = 0.5 ) 时,( -\log(0.5) \approx 0.693 )。
  • 当 ( p_c = 0.1 ) 时,( -\log(0.1) \approx 2.302 )。
  • 当 ( p_c ) 趋近于 0 时,( -\log(p_c) ) 会趋近于正无穷。

2. 一个具体例子

假设你有一个图像分类 任务,图片是一只。分类类别有:猫、狗、鸟。

  • 真实标签 (One-hot 编码)[1, 0, 0] (猫)
  • 模型A的预测 (很准)[0.9, 0.05, 0.05] 。损失 = ( -\log(0.9) \approx 0.105 )。
  • 模型B的预测 (不太准)[0.4, 0.5, 0.1] 。损失 = ( -\log(0.4) \approx 0.916 )。
  • 模型C的预测 (完全错误)[0.05, 0.9, 0.05] 。损失 = ( -\log(0.05) \approx 3.0 )。

可以看到,模型A(预测正确概率高)的损失很小,而模型C(预测错误)的损失非常大。通过反向传播,交叉熵损失函数会驱使模型不断提高对正确类别的预测概率。

3. 为什么在分类任务中如此有效?

交叉熵之所以被广泛使用,主要有三个优势:

  1. 梯度更大,学习更快

    与均方误差(MSE)等损失函数相比,当模型的预测结果与真实标签相差甚远时,交叉熵能提供一个很大的梯度,模型会进行大幅度的修正,从而快速改进。而MSE在初期错误率很高时梯度可能会很小,导致学习缓慢。

  2. 结合Softmax,天然适配多分类

    在神经网络中,最后一层输出的原始数值(logits)通常无法直接视为概率。交叉熵损失函数常常与 Softmax 激活函数配合使用。Softmax能把logits转换成和为1的概率分布,这和交叉熵对输入的预期(概率分布)是天作之合。

4. CrossEntropyLoss vs. BCELoss

在使用PyTorch等框架时,你会遇到几种名称相似但功能不同的交叉熵损失,需要注意区分:

损失函数 适用任务 最后一层激活函数 标签形式 说明
nn.CrossEntropyLoss 多分类 (互斥类别) 无需 (或Linear) 类别索引 (如 1 最常用 ,内部融合了LogSoftmax和NLLLoss,不需要在输出层再加Softmax。
nn.BCELoss 二分类 或多标签 Sigmoid 0/1 概率值 需要手动在输出层加Sigmoid。
nn.BCEWithLogitsLoss 二分类 或多标签 无需 (或Linear) 0/1 数值 推荐,比BCELoss更数值稳定,内部融合了Sigmoid和BCELoss。

代码示例(使用 nn.CrossEntropyLoss 进行多分类):

python 复制代码
import torch
import torch.nn as nn

# 假设有3个类别,2个样本
logits = torch.tensor([[2.0, 1.0, 0.1],      # 模型对样本1的输出
                       [0.5, 2.5, 0.3]])      # 模型对样本2的输出
# 真实标签:样本1属于第0类,样本2属于第1类
labels = torch.tensor([0, 1])

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')   # 输出: 例如 0.774

总结

一句话总结
交叉熵是分类任务中衡量"预测概率分布"与"真实概率分布"差异的标准工具。模型通过最小化交叉熵来让自己对正确类别的预测概率趋近于1,从而学会正确分类。

简单来说,交叉熵是一个聪明的"教练":当你学得差时,它给你严厉的惩罚(大损失,大梯度);当你学得好时,它给你温柔的鼓励(小损失,小梯度),引导模型快速收敛。

相关推荐
AI技术控7 小时前
《Transformers are Inherently Succinct》论文解读:从“能表达什么”到“多紧凑地表达”
人工智能·python·深度学习·机器学习·自然语言处理
Robot_Nav9 小时前
深度学习与强化学习面试八股文知识点汇总
人工智能·深度学习·强化学习
一颗牙牙11 小时前
安装mmcv
开发语言·python·深度学习
paperClub12 小时前
AACR 2026 · AI诊断:深度学习在肿瘤早期检测中的应用
人工智能·深度学习
AI医影跨模态组学13 小时前
NPJ Precis Oncol(IF=8)中国科学院深圳先进技术研究院吴红艳教授等团队:深度可解释放射基因组学解析乳腺MRI肿瘤微环境
人工智能·深度学习·论文·医学·医学影像
大模型最新论文速读14 小时前
05-15 · LLM 最新论文速览
论文阅读·人工智能·深度学习·机器学习·自然语言处理
数智工坊14 小时前
【DINOv2论文阅读】:无需监督的通用视觉特征提取器——机器人VLA模型的“眼睛“基石
论文阅读·人工智能·深度学习·计算机视觉·transformer
一切皆是因缘际会14 小时前
AI低代码开发实战:轻量化部署与多场景落地
人工智能·深度学习·低代码·机器学习·ai·架构
EnCi Zheng15 小时前
09-斯坦福CS336作业 [特殊字符]
人工智能·pytorch·python·深度学习·神经网络
Hali_Botebie15 小时前
【量化】Post-training quantization for vision transformer.
人工智能·深度学习·transformer