【Pytorch】分类问题交叉熵

1️⃣ 为什么分类问题不用 MSE(均方误差)?

表格

复制

场景 标签 预测 MSE 损失
分类 [0,0,1] [0.3,0.3,0.4] (0.4-1)²=0.36
分类 [0,0,1] [0.1,0.2,0.7] (0.7-1)²=0.09

看起来合理,但:

  1. sigmoid/softmax 输出在 0/1 附近梯度几乎为零梯度消失

  2. MSE 把"概率差"当数值差" → 不符合概率直觉;

  3. 收敛慢,还容易卡在鞍点。


2️⃣ 交叉熵(Cross Entropy)思想

一句话 :衡量「真实分布 p」与「预测分布 q」之间的信息差距

公式(离散版):

CE(p,q) = − Σ p(i) log q(i)

  • p 是 one-hot 标签(比如 [0,0,1])

  • q 是 softmax 输出(比如 [0.1,0.2,0.7])

因为 p 只有一个 1,其余为 0,所以求和只剩一项

CE = − log q(正确类)

直观

  • 若 q(正确类)=0.7 → CE ≈ 0.36

  • 若 q(正确类)=0.98 → CE ≈ 0.02
    预测越准,损失越小 ,且梯度不饱和(后面会算给你看)。


3️⃣ 手推一条二分类例子

表格

复制

样本 真实 y 预测 p
1 0.8
0 0.1

二元交叉熵(BCE):

L = − [y log p + (1−y) log(1−p)]

样本 1(猫):

L = − [1·log0.8 + 0·log0.2] = −log0.8 ≈ 0.223

样本 2(狗):

L = − [0·log0.1 + 1·log0.9] = −log0.9 ≈ 0.105

平均损失 ≈ 0.164,预测越离谱,值越大


4️⃣ PyTorch 一行代码算完

Python

复制代码
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 0.5]])   # 模型输出(未归一化)
target = torch.tensor([1])                  # 正确类别索引

loss = F.cross_entropy(logits, target)
print(loss.item())          #  tensor(0.8309)

内部干了啥

  1. softmax(logits) → 概率

  2. log(softmax) → 对数概率

  3. -log q(正确类) → 损失


5️⃣ 数值稳定性技巧

不要手写:

Python

复制

复制代码
prob = F.softmax(logits)
log_prob = torch.log(prob)
loss = F.nll_loss(log_prob, target)

推荐直接用:

Python

复制代码
loss = F.cross_entropy(logits, target)

内部实现 log-sum-exp 技巧 ,避免 log(softmax) 造成数值溢出。


6️⃣ 对比实验(直观感受)

表格

复制

方法 损失曲线 梯度大小 收敛速度
MSE 平坦区早 极小
CE 无平坦区 稳定

7️⃣ 小结口诀(背下来)

分类用 CE,回归用 MSE
CE = −log q(对类)
PyTorch:F.cross_entropy(logits, target)
别手写 softmax+log!


8️⃣ 课后 5 分钟动手

  1. F.cross_entropy 算一条三分类样本。

  2. logits 乘 10 再算一次,观察损失变化。

  3. 对比 F.mse_lossF.cross_entropy 的梯度大小(.grad)。

相关推荐
前端摸鱼匠1 分钟前
面试题3:自注意力机制(Self-Attention)的计算流程是什么?
人工智能·ai·面试·职场和发展
出门吃三碗饭5 分钟前
CARLA: 如何在 CARLA 中回放自动驾驶场景
人工智能·机器学习·自动驾驶
Axis tech6 分钟前
第二届人形机器人半程马拉松即将于4月开赛,对比去年技术进步有哪些?
人工智能·机器人
志栋智能6 分钟前
超自动化巡检,如何成为业务稳定的“压舱石”?
大数据·运维·网络·人工智能·自动化
lifallen8 分钟前
从零推导一个现代 ReAct Agent框架
人工智能·算法·语言模型
我的offer在哪里8 分钟前
腾讯 Ardot 深度博客:AI 重构 UI/UX 全链路,从 “描述即界面” 到设计工业化的腾讯范式
人工智能·ui·重构
AEIC学术交流中心8 分钟前
【快速EI检索 | IEEE出版】第六届信号图像处理与通信国际学术会议(ICSIPC 2026)
图像处理·人工智能
康世行10 分钟前
IDEA集成AI辅助工具推荐(好用不卡顿)
java·人工智能·intellij-idea
柯儿的天空10 分钟前
【OpenClaw 全面解析:从零到精通】第007篇:流量枢纽——OpenClaw Gateway 网关深度解析
人工智能·gpt·ai作画·gateway·aigc·ai编程·ai写作
人道领域11 分钟前
2026年Q1大模型深度复盘:OpenAI,Gemini2.0,字节跳动,与“多模态Agent”元年
人工智能·ai·google·chatgpt·gemini