Pytorch中的CrossEntropyLoss

CrossEntropyLoss 的输入要求

在 PyTorch 中,CrossEntropyLoss 有以下要求:

  1. 预测值(logits) 的形状为 (N, C, ...),其中:

• N 是样本数(或批次大小)。

• C 是类别数。

• ... 是额外的维度(例如序列长度、图像的高度和宽度等)。

  1. 标签(targets) 的形状为 (N, ...),表示每个样本对应的分类标签。标签是整数索引,范围为 [0, C-1]。

其中重点为:PyTorch 的 CrossEntropyLoss要求输入张量的第二个维度必须是类别的个数,无论是 1D 数据、序列数据还是高维数据,这个要求都是一致的。第二维度始终对应分类任务中的类别数 (num_classes),这是 CrossEntropyLoss 的固定设计。

为什么第二维度必须是类别数?

CrossEntropyLoss 的计算方式基于每个样本的预测概率分布和真实类别标签:

  1. 对于每个样本或位置,CrossEntropyLoss 期望提供一个类别分布的 logits(未经过 softmax 的分值),这个分布存储在输入张量的第二维度。

  2. 损失函数会沿着第二维度(类别维度)计算每个样本的交叉熵损失。

换句话说,第二维度的每个值代表每个类别的 logits,这些 logits 会通过内部的 log_softmax 转换成对数概率,用于交叉熵计算。

相关推荐
汽车仪器仪表相关领域1 分钟前
半自动精准检测,降本增效之选——NHD-1050半自动远、近光检测仪项目实战分享
服务器·人工智能·功能测试·安全·可用性测试
码农很忙1 分钟前
2026年GEO服务商深度探析:AI时代品牌“算法战”的突围路径
人工智能
min1811234562 分钟前
产品开发跨职能流程图在线生成工具
人工智能·microsoft·信息可视化·架构·机器人·流程图
hit56实验室5 分钟前
如何调整vad参数
人工智能
柠檬叶子C6 分钟前
【Python】解决 No module named ‘imp‘ 问题 | Python3 中废弃的 imp 模块
开发语言·python
我想吃烤肉肉7 分钟前
wait_until=“domcontentloaded“ 解释
开发语言·前端·javascript·爬虫·python
退休钓鱼选手10 分钟前
BehaviorTree行为树-机器人及自动驾驶
人工智能·自动驾驶
xiao5kou4chang6kai410 分钟前
贯通LLM应用→数据分析→自动化编程→文献及知识管理→科研写作与绘图→构建本地LLM、Agent→多模型圆桌会议→N8N自动化工作流深度应用
人工智能·自动化·llm·科研绘图·n8n
weixin1997010801610 分钟前
废旧物资 item_search - 按关键字搜索商品列表接口对接全攻略:从入门到精通
数据库·python
海棠AI实验室11 分钟前
第二章 从脚本到工程:进阶学习的 5 个方法论(可维护性/可复现/可评估/可扩展/可交付)
python·数据