深度学习 - softmax交叉熵损失计算

示例代码

python 复制代码
import torch
from torch import nn

# 多分类交叉熵损失,使用nn.CrossEntropyLoss()实现。nn.CrossEntropyLoss()=softmax + 损失计算
def test1():
    # 设置真实值: 可以是热编码后的结果也可以不进行热编码
    # y_true = torch.tensor([[0, 1, 0], [0, 0, 1]], dtype=torch.float32)
    # 注意的类型必须是64位整型数据
    y_true = torch.tensor([1, 2], dtype=torch.int64)
    y_pred = torch.tensor([[0.2, 0.6, 0.2], [0.1, 0.8, 0.1]], dtype=torch.float32)
    # 实例化交叉熵损失
    loss = nn.CrossEntropyLoss()
    # 计算损失结果
    my_loss = loss(y_pred, y_true).numpy()
    print('loss:', my_loss)

输入数据

python 复制代码
y_true = torch.tensor([1, 2], dtype=torch.int64)
y_pred = torch.tensor([[0.2, 0.6, 0.2], [0.1, 0.8, 0.1]], dtype=torch.float32)
  • y_true:真实标签,包含两个样本,分别属于类别 1 和类别 2。
  • y_pred:预测的概率分布,包含两个样本,每个样本有三个类别的预测值。

Step 1: Softmax 变换

Softmax 函数将原始的预测值转换为概率分布。Softmax 的公式如下:

Softmax ( x i ) = e x i ∑ j e x j \text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} Softmax(xi)=∑jexjexi

对于第一个样本 y_pred = [0.2, 0.6, 0.2]

  1. 计算指数:

e 0.2 ≈ 1.221 , e 0.6 ≈ 1.822 , e 0.2 ≈ 1.221 e^{0.2} \approx 1.221, \quad e^{0.6} \approx 1.822, \quad e^{0.2} \approx 1.221 e0.2≈1.221,e0.6≈1.822,e0.2≈1.221

  1. 计算 Softmax 分母:

∑ j e x j = 1.221 + 1.822 + 1.221 = 4.264 \sum_{j} e^{x_j} = 1.221 + 1.822 + 1.221 = 4.264 j∑exj=1.221+1.822+1.221=4.264

  1. 计算 Softmax 分子并得到结果:

Softmax ( 0.2 ) = 1.221 4.264 ≈ 0.286 \text{Softmax}(0.2) = \frac{1.221}{4.264} \approx 0.286 Softmax(0.2)=4.2641.221≈0.286

Softmax ( 0.6 ) = 1.822 4.264 ≈ 0.427 \text{Softmax}(0.6) = \frac{1.822}{4.264} \approx 0.427 Softmax(0.6)=4.2641.822≈0.427

Softmax ( 0.2 ) = 1.221 4.264 ≈ 0.286 \text{Softmax}(0.2) = \frac{1.221}{4.264} \approx 0.286 Softmax(0.2)=4.2641.221≈0.286

Softmax 结果为 [[0.286, 0.427, 0.286]]

对于第二个样本 y_pred = [0.1, 0.8, 0.1]

  1. 计算指数:

e 0.1 ≈ 1.105 , e 0.8 ≈ 2.225 , e 0.1 ≈ 1.105 e^{0.1} \approx 1.105, \quad e^{0.8} \approx 2.225, \quad e^{0.1} \approx 1.105 e0.1≈1.105,e0.8≈2.225,e0.1≈1.105

  1. 计算 Softmax 分母:

∑ j e x j = 1.105 + 2.225 + 1.105 = 4.435 \sum_{j} e^{x_j} = 1.105 + 2.225 + 1.105 = 4.435 j∑exj=1.105+2.225+1.105=4.435

  1. 计算 Softmax 分子并得到结果:

Softmax ( 0.1 ) = 1.105 4.435 ≈ 0.249 \text{Softmax}(0.1) = \frac{1.105}{4.435} \approx 0.249 Softmax(0.1)=4.4351.105≈0.249

Softmax ( 0.8 ) = 2.225 4.435 ≈ 0.502 \text{Softmax}(0.8) = \frac{2.225}{4.435} \approx 0.502 Softmax(0.8)=4.4352.225≈0.502

Softmax ( 0.1 ) = 1.105 4.435 ≈ 0.249 \text{Softmax}(0.1) = \frac{1.105}{4.435} \approx 0.249 Softmax(0.1)=4.4351.105≈0.249

Softmax 结果为 [[0.249, 0.502, 0.249]]

Step 2: 计算交叉熵损失

交叉熵损失的公式为:

CrossEntropyLoss ( p , y ) = − ∑ i = 1 N y i log ⁡ ( p i ) \text{CrossEntropyLoss}(p, y) = -\sum_{i=1}^{N} y_i \log(p_i) CrossEntropyLoss(p,y)=−i=1∑Nyilog(pi)

对于第一个样本,真实标签为 1(y_true = 1),Softmax 后的预测概率分布为 [0.286, 0.427, 0.286]

CrossEntropyLoss = − [ 0 ⋅ log ⁡ ( 0.286 ) + 1 ⋅ log ⁡ ( 0.427 ) + 0 ⋅ log ⁡ ( 0.286 ) ] \text{CrossEntropyLoss} = - [0 \cdot \log(0.286) + 1 \cdot \log(0.427) + 0 \cdot \log(0.286)] CrossEntropyLoss=−[0⋅log(0.286)+1⋅log(0.427)+0⋅log(0.286)]

由于 (0 \cdot \log(0.286) = 0),忽略后我们得到:

log ⁡ ( 0.427 ) ≈ 0.851 \log(0.427) \approx 0.851 log(0.427)≈0.851

对于第二个样本,真实标签为 2(y_true = 2),Softmax 后的预测概率分布为 [0.249, 0.502, 0.249]

CrossEntropyLoss = − [ 0 ⋅ log ⁡ ( 0.249 ) + 0 ⋅ log ⁡ ( 0.502 ) + 1 ⋅ log ⁡ ( 0.249 ) ] \text{CrossEntropyLoss} = - [0 \cdot \log(0.249) + 0 \cdot \log(0.502) + 1 \cdot \log(0.249)] CrossEntropyLoss=−[0⋅log(0.249)+0⋅log(0.502)+1⋅log(0.249)]

由于 (0 \cdot \log(0.249) = 0) 和 (0 \cdot \log(0.502) = 0),忽略后我们得到:

log ⁡ ( 0.249 ) ≈ 1.390 \log(0.249) \approx 1.390 log(0.249)≈1.390

Step 3: 平均损失

计算平均损失:

平均损失 = 0.851 + 1.390 2 ≈ 2.241 2 ≈ 1.1205 \text{平均损失} = \frac{0.851 + 1.390}{2} \approx \frac{2.241}{2} \approx 1.1205 平均损失=20.851+1.390≈22.241≈1.1205

因此,最终的交叉熵损失 my_loss 约为 1.1205。

相关推荐
何双新1 小时前
Odoo AI 智能查询系统
前端·人工智能·python
生命是有光的3 小时前
【机器学习】机器学习算法
人工智能·机器学习
Blossom.1183 小时前
把 AI 塞进「自行车码表」——基于 MEMS 的 3D 地形预测码表
人工智能·python·深度学习·opencv·机器学习·计算机视觉·3d
小鹿的工作手帐6 小时前
有鹿机器人:为城市描绘清洁新图景的智能使者
人工智能·科技·机器人
TechubNews7 小时前
香港数字资产交易市场蓬勃发展,监管与创新并驾齐驱
人工智能·区块链
DogDaoDao8 小时前
用PyTorch实现多类图像分类:从原理到实际操作
图像处理·人工智能·pytorch·python·深度学习·分类·图像分类
小和尚同志8 小时前
450 star 的神级提示词管理工具 AI-Gist,让提示词不再吃灰
人工智能·aigc
这张生成的图像能检测吗9 小时前
(论文速读)Prompt Depth Anything:让深度估计进入“提示时代“
深度学习·计算机视觉·深度估计
金井PRATHAMA10 小时前
大脑的藏宝图——神经科学如何为自然语言处理(NLP)的深度语义理解绘制新航线
人工智能·自然语言处理
大学生毕业题目10 小时前
毕业项目推荐:28-基于yolov8/yolov5/yolo11的电塔危险物品检测识别系统(Python+卷积神经网络)
人工智能·python·yolo·cnn·pyqt·电塔·危险物品