【pyTorch】关于PyTorch的高级索引机制理解

python 复制代码
y_hat[[0, 1], y]

看起来简短,其实包含了 PyTorch 的高级索引(advanced indexing) 机制,

常用于分类任务中,从模型输出中取出正确类别对应的预测概率(或分数)

我们来一步步拆解 👇


🧩 一、先看变量定义

python 复制代码
y = torch.tensor([0, 2])
y_hat = torch.tensor([
    [0.1, 0.3, 0.6],
    [0.3, 0.2, 0.5]
])

可以理解为:

样本编号 模型预测概率(3类) 正确类别
第 0 个样本 [0.1, 0.3, 0.6] 0
第 1 个样本 [0.3, 0.2, 0.5] 2

🧮 二、代码含义

python 复制代码
y_hat[[0, 1], y]

这是一个行列双索引操作,等价于:

python 复制代码
torch.tensor([y_hat[0, y[0]], y_hat[1, y[1]]])

也就是:

复制代码
= [y_hat[0, 0], y_hat[1, 2]]

📊 三、一步步替换求值

从矩阵中取出对应元素:

索引 取出的元素
y_hat[0, 0] 第 0 行第 0 列 0.1
y_hat[1, 2] 第 1 行第 2 列 0.5

所以:

python 复制代码
y_hat[[0, 1], y] = tensor([0.1, 0.5])

🧠 四、这个操作在深度学习中的用途

这个技巧非常常见!尤其是在 交叉熵损失 (cross-entropy) 计算中。

假设 y_hat 是每个样本对所有类别的预测概率分布,y 是正确标签,

那这行代码可以提取出:

每个样本 "对应正确类别" 的预测概率。

例如:

python 复制代码
y_hat[[0, 1], y]
# → tensor([预测第0个样本属于真实类别的概率, 
#           预测第1个样本属于真实类别的概率])

然后就可以算交叉熵:

python 复制代码
loss = -torch.log(y_hat[[0, 1], y])

✅ 五、总结一句话

代码 含义
y_hat[[0, 1], y] 按行索引 [0, 1] 和对应列索引 y 同时取元素
= [y_hat[0, y[0]], y_hat[1, y[1]]] 等价的展开写法
结果 取出每个样本的正确类别预测值(常用于分类任务)

💡 口诀记忆:

y_hat[行索引, 列索引]

会"对齐"地从每一行中取出对应列的元素。

常用来拿到"每个样本在真实类别上的预测概率"。


具体理解

🧩 数据回顾

python 复制代码
y = torch.tensor([0, 2])
y_hat = torch.tensor([
    [0.1, 0.3, 0.6],
    [0.3, 0.2, 0.5]
])

🎨 图示说明

复制代码
          类别0    类别1    类别2
样本0 →   0.1      0.3      0.6
样本1 →   0.3      0.2      0.5

以及标签:

复制代码
y = [0, 2]

表示:

  • 样本 0 的真实类别是 0
  • 样本 1 的真实类别是 2

🔍 执行这句代码:

python 复制代码
y_hat[[0, 1], y]

等价于:

复制代码
取第 0 行的第 y[0]=0 列 → 0.1
取第 1 行的第 y[1]=2 列 → 0.5

✅ 可视化标注结果

复制代码
          类别0    类别1    类别2
样本0 →  [0.1]*    0.3      0.6
样本1 →   0.3      0.2    [0.5]*

星号 * 表示被选中的元素

最终输出:

复制代码
tensor([0.1, 0.5])

🧠 用途回顾

在分类任务里:

python 复制代码
loss = -torch.log(y_hat[[0, 1], y])

就是取出模型对真实标签类别的预测概率,再取负对数计算交叉熵损失。

相关推荐
延凡科技2 小时前
无人机低空智能巡飞巡检平台:全域感知与智能决策的低空作业中枢
大数据·人工智能·科技·安全·无人机·能源
2501_941329722 小时前
YOLOv8-SEAMHead改进实战:书籍检测与识别系统优化方案
人工智能·yolo·目标跟踪
晓翔仔3 小时前
【深度实战】Agentic AI 安全攻防指南:基于 CSA 红队测试手册的 12 类风险完整解析
人工智能·安全·ai·ai安全
百家方案4 小时前
2026年数据治理整体解决方案 - 全1066页下载
大数据·人工智能·数据治理
北京耐用通信4 小时前
工业自动化中耐达讯自动化Profibus光纤链路模块连接RFID读写器的应用
人工智能·科技·物联网·自动化·信息与通信
Hgfdsaqwr4 小时前
Django全栈开发入门:构建一个博客系统
jvm·数据库·python
开发者小天5 小时前
python中For Loop的用法
java·服务器·python
老百姓懂点AI5 小时前
[RAG实战] 向量数据库选型与优化:智能体来了(西南总部)AI agent指挥官的长短期记忆架构设计
python
小韩博5 小时前
一篇文章讲清AI核心概念之(LLM、Agent、MCP、Skills) -- 从解决问题的角度来说明
人工智能
沃达德软件6 小时前
人工智能治安管控系统
图像处理·人工智能·深度学习·目标检测·计算机视觉·目标跟踪·视觉检测