交叉熵(Cross-Entropy) 是信息论中的一个核心概念,在深度学习中,它是最常用、最重要的损失函数 之一,尤其擅长处理分类问题。
简单理解,交叉熵可以用来衡量两个概率分布之间的差异 。在模型训练中,它衡量的就是:模型预测的概率分布 ,与真实的概率分布(通常是One-hot编码的标签)之间的差距。
- 预测越准,交叉熵损失越小。
- 预测越离谱,交叉熵损失越大。
1. 从公式理解
对于单个样本,交叉熵损失的公式如下:
\\text{CrossEntropy} = -\\sum_{i} y_i \\log(p_i)
- ( y_i ) :表示第 ( i ) 个类别的真实标签。在分类任务中,真实类别为1,其他为0。
- ( p_i ) :表示模型预测样本属于第 ( i ) 个类别的概率,取值范围在0到1之间。
- ( \log ):是自然对数。
由于 ( y_i ) 只在真实类别(比如第 ( c ) 类)上为1,其他全为0,所以这个公式可以简化为:
\\text{Loss} = -\\log(p_c)
这个简化公式非常直观地说明了交叉熵的工作原理 :损失的大小,完全由模型给正确类别预测的概率 ( p_c ) 决定。
- 当模型预测正确类别的概率 ( p_c = 1 ) 时,( -\log(1) = 0 ),损失为0。
- 当 ( p_c = 0.5 ) 时,( -\log(0.5) \approx 0.693 )。
- 当 ( p_c = 0.1 ) 时,( -\log(0.1) \approx 2.302 )。
- 当 ( p_c ) 趋近于 0 时,( -\log(p_c) ) 会趋近于正无穷。
2. 一个具体例子
假设你有一个图像分类 任务,图片是一只猫。分类类别有:猫、狗、鸟。
- 真实标签 (One-hot 编码) :
[1, 0, 0](猫) - 模型A的预测 (很准) :
[0.9, 0.05, 0.05]。损失 = ( -\log(0.9) \approx 0.105 )。 - 模型B的预测 (不太准) :
[0.4, 0.5, 0.1]。损失 = ( -\log(0.4) \approx 0.916 )。 - 模型C的预测 (完全错误) :
[0.05, 0.9, 0.05]。损失 = ( -\log(0.05) \approx 3.0 )。
可以看到,模型A(预测正确概率高)的损失很小,而模型C(预测错误)的损失非常大。通过反向传播,交叉熵损失函数会驱使模型不断提高对正确类别的预测概率。
3. 为什么在分类任务中如此有效?
交叉熵之所以被广泛使用,主要有三个优势:
-
梯度更大,学习更快
与均方误差(MSE)等损失函数相比,当模型的预测结果与真实标签相差甚远时,交叉熵能提供一个很大的梯度,模型会进行大幅度的修正,从而快速改进。而MSE在初期错误率很高时梯度可能会很小,导致学习缓慢。
-
结合Softmax,天然适配多分类
在神经网络中,最后一层输出的原始数值(logits)通常无法直接视为概率。交叉熵损失函数常常与 Softmax 激活函数配合使用。Softmax能把logits转换成和为1的概率分布,这和交叉熵对输入的预期(概率分布)是天作之合。
4. CrossEntropyLoss vs. BCELoss
在使用PyTorch等框架时,你会遇到几种名称相似但功能不同的交叉熵损失,需要注意区分:
| 损失函数 | 适用任务 | 最后一层激活函数 | 标签形式 | 说明 |
|---|---|---|---|---|
nn.CrossEntropyLoss |
多分类 (互斥类别) | 无需 (或Linear) | 类别索引 (如 1) |
最常用 ,内部融合了LogSoftmax和NLLLoss,不需要在输出层再加Softmax。 |
nn.BCELoss |
二分类 或多标签 | Sigmoid | 0/1 概率值 | 需要手动在输出层加Sigmoid。 |
nn.BCEWithLogitsLoss |
二分类 或多标签 | 无需 (或Linear) | 0/1 数值 | 推荐,比BCELoss更数值稳定,内部融合了Sigmoid和BCELoss。 |
代码示例(使用 nn.CrossEntropyLoss 进行多分类):
python
import torch
import torch.nn as nn
# 假设有3个类别,2个样本
logits = torch.tensor([[2.0, 1.0, 0.1], # 模型对样本1的输出
[0.5, 2.5, 0.3]]) # 模型对样本2的输出
# 真实标签:样本1属于第0类,样本2属于第1类
labels = torch.tensor([0, 1])
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 计算损失
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}') # 输出: 例如 0.774
总结
| 一句话总结 |
|---|
| 交叉熵是分类任务中衡量"预测概率分布"与"真实概率分布"差异的标准工具。模型通过最小化交叉熵来让自己对正确类别的预测概率趋近于1,从而学会正确分类。 |
简单来说,交叉熵是一个聪明的"教练":当你学得差时,它给你严厉的惩罚(大损失,大梯度);当你学得好时,它给你温柔的鼓励(小损失,小梯度),引导模型快速收敛。