Pytorch中的CrossEntropyLoss

CrossEntropyLoss 的输入要求

在 PyTorch 中,CrossEntropyLoss 有以下要求:

  1. 预测值(logits) 的形状为 (N, C, ...),其中:

• N 是样本数(或批次大小)。

• C 是类别数。

• ... 是额外的维度(例如序列长度、图像的高度和宽度等)。

  1. 标签(targets) 的形状为 (N, ...),表示每个样本对应的分类标签。标签是整数索引,范围为 [0, C-1]。

其中重点为:PyTorch 的 CrossEntropyLoss要求输入张量的第二个维度必须是类别的个数,无论是 1D 数据、序列数据还是高维数据,这个要求都是一致的。第二维度始终对应分类任务中的类别数 (num_classes),这是 CrossEntropyLoss 的固定设计。

为什么第二维度必须是类别数?

CrossEntropyLoss 的计算方式基于每个样本的预测概率分布和真实类别标签:

  1. 对于每个样本或位置,CrossEntropyLoss 期望提供一个类别分布的 logits(未经过 softmax 的分值),这个分布存储在输入张量的第二维度。

  2. 损失函数会沿着第二维度(类别维度)计算每个样本的交叉熵损失。

换句话说,第二维度的每个值代表每个类别的 logits,这些 logits 会通过内部的 log_softmax 转换成对数概率,用于交叉熵计算。

相关推荐
Hy行者勇哥25 分钟前
多源数据抽取与推送模块架构设计
人工智能·个人开发
星空的资源小屋31 分钟前
Text Grab,一款OCR 截图文字识别工具
python·django·ocr·scikit-learn
寒秋丶32 分钟前
Milvus:Json字段详解(十)
数据库·人工智能·python·ai·milvus·向量数据库·rag
长桥夜波1 小时前
机器学习日报07
人工智能·机器学习
长桥夜波1 小时前
机器学习日报11
人工智能·机器学习
一个处女座的程序猿4 小时前
LLMs之SLMs:《Small Language Models are the Future of Agentic AI》的翻译与解读
人工智能·自然语言处理·小语言模型·slms
自由随风飘4 小时前
python 题目练习1~5
开发语言·python
fl1768317 小时前
基于python的天气预报系统设计和可视化数据分析源码+报告
开发语言·python·数据分析
档案宝档案管理7 小时前
档案宝:企业合同档案管理的“安全保险箱”与“效率加速器”
大数据·数据库·人工智能·安全·档案·档案管理
闲人编程7 小时前
Python与区块链:如何用Web3.py与以太坊交互
python·安全·区块链·web3.py·以太坊·codecapsule