深度学习-交叉熵

交叉熵(Cross-Entropy) 是信息论中的一个核心概念,在深度学习中,它是最常用、最重要的损失函数 之一,尤其擅长处理分类问题

简单理解,交叉熵可以用来衡量两个概率分布之间的差异 。在模型训练中,它衡量的就是:模型预测的概率分布 ,与真实的概率分布(通常是One-hot编码的标签)之间的差距。

  • 预测越准,交叉熵损失越小。
  • 预测越离谱,交叉熵损失越大。

1. 从公式理解

对于单个样本,交叉熵损失的公式如下:

\\text{CrossEntropy} = -\\sum_{i} y_i \\log(p_i)

  • ( y_i ) :表示第 ( i ) 个类别的真实标签。在分类任务中,真实类别为1,其他为0。
  • ( p_i ) :表示模型预测样本属于第 ( i ) 个类别的概率,取值范围在0到1之间。
  • ( \log ):是自然对数。

由于 ( y_i ) 只在真实类别(比如第 ( c ) 类)上为1,其他全为0,所以这个公式可以简化为:

\\text{Loss} = -\\log(p_c)

这个简化公式非常直观地说明了交叉熵的工作原理 :损失的大小,完全由模型给正确类别预测的概率 ( p_c ) 决定。

  • 当模型预测正确类别的概率 ( p_c = 1 ) 时,( -\log(1) = 0 ),损失为0。
  • 当 ( p_c = 0.5 ) 时,( -\log(0.5) \approx 0.693 )。
  • 当 ( p_c = 0.1 ) 时,( -\log(0.1) \approx 2.302 )。
  • 当 ( p_c ) 趋近于 0 时,( -\log(p_c) ) 会趋近于正无穷。

2. 一个具体例子

假设你有一个图像分类 任务,图片是一只。分类类别有:猫、狗、鸟。

  • 真实标签 (One-hot 编码)[1, 0, 0] (猫)
  • 模型A的预测 (很准)[0.9, 0.05, 0.05] 。损失 = ( -\log(0.9) \approx 0.105 )。
  • 模型B的预测 (不太准)[0.4, 0.5, 0.1] 。损失 = ( -\log(0.4) \approx 0.916 )。
  • 模型C的预测 (完全错误)[0.05, 0.9, 0.05] 。损失 = ( -\log(0.05) \approx 3.0 )。

可以看到,模型A(预测正确概率高)的损失很小,而模型C(预测错误)的损失非常大。通过反向传播,交叉熵损失函数会驱使模型不断提高对正确类别的预测概率。

3. 为什么在分类任务中如此有效?

交叉熵之所以被广泛使用,主要有三个优势:

  1. 梯度更大,学习更快

    与均方误差(MSE)等损失函数相比,当模型的预测结果与真实标签相差甚远时,交叉熵能提供一个很大的梯度,模型会进行大幅度的修正,从而快速改进。而MSE在初期错误率很高时梯度可能会很小,导致学习缓慢。

  2. 结合Softmax,天然适配多分类

    在神经网络中,最后一层输出的原始数值(logits)通常无法直接视为概率。交叉熵损失函数常常与 Softmax 激活函数配合使用。Softmax能把logits转换成和为1的概率分布,这和交叉熵对输入的预期(概率分布)是天作之合。

4. CrossEntropyLoss vs. BCELoss

在使用PyTorch等框架时,你会遇到几种名称相似但功能不同的交叉熵损失,需要注意区分:

损失函数 适用任务 最后一层激活函数 标签形式 说明
nn.CrossEntropyLoss 多分类 (互斥类别) 无需 (或Linear) 类别索引 (如 1 最常用 ,内部融合了LogSoftmax和NLLLoss,不需要在输出层再加Softmax。
nn.BCELoss 二分类 或多标签 Sigmoid 0/1 概率值 需要手动在输出层加Sigmoid。
nn.BCEWithLogitsLoss 二分类 或多标签 无需 (或Linear) 0/1 数值 推荐,比BCELoss更数值稳定,内部融合了Sigmoid和BCELoss。

代码示例(使用 nn.CrossEntropyLoss 进行多分类):

python 复制代码
import torch
import torch.nn as nn

# 假设有3个类别,2个样本
logits = torch.tensor([[2.0, 1.0, 0.1],      # 模型对样本1的输出
                       [0.5, 2.5, 0.3]])      # 模型对样本2的输出
# 真实标签:样本1属于第0类,样本2属于第1类
labels = torch.tensor([0, 1])

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')   # 输出: 例如 0.774

总结

一句话总结
交叉熵是分类任务中衡量"预测概率分布"与"真实概率分布"差异的标准工具。模型通过最小化交叉熵来让自己对正确类别的预测概率趋近于1,从而学会正确分类。

简单来说,交叉熵是一个聪明的"教练":当你学得差时,它给你严厉的惩罚(大损失,大梯度);当你学得好时,它给你温柔的鼓励(小损失,小梯度),引导模型快速收敛。

相关推荐
Element_南笙2 小时前
VGG网络-深度学习经典架构解析
网络·深度学习·架构
陶陶然Yay4 小时前
神经网络卷积层梯度公式推导
人工智能·深度学习·神经网络
隔壁大炮6 小时前
Day06-08.CNN概述介绍
人工智能·pytorch·深度学习·算法·计算机视觉·cnn·numpy
β添砖java6 小时前
深度学习(8)过拟合、欠拟合
人工智能·深度学习
QiZhang | UESTC7 小时前
从基础 RoPE 到 YaRN:源码学习路线揭秘
pytorch·深度学习·学习
HackTorjan8 小时前
深度解析雪花算法及其高性能优化策略
人工智能·深度学习·算法·性能优化·dreamweaver
STLearner9 小时前
AI论文速读 | QuitoBench:支付宝高质量开源时间序列预测基准测试集
大数据·论文阅读·人工智能·深度学习·学习·机器学习·开源
aidesignplus10 小时前
从平方到线性:Mamba如何挑战Transformer的长序列效率瓶颈?
人工智能·python·深度学习·vim·transformer
AI医影跨模态组学11 小时前
Ann Oncol(IF=65.4)广东省人民医院放射科刘再毅&阿里巴巴达摩院等团队:基于非增强CT与深度学习的结直肠癌检测
人工智能·深度学习·论文·医学影像