python
y_hat[[0, 1], y]
看起来简短,其实包含了 PyTorch 的高级索引(advanced indexing) 机制,
常用于分类任务中,从模型输出中取出正确类别对应的预测概率(或分数)。
我们来一步步拆解 👇
🧩 一、先看变量定义
python
y = torch.tensor([0, 2])
y_hat = torch.tensor([
[0.1, 0.3, 0.6],
[0.3, 0.2, 0.5]
])
可以理解为:
样本编号 | 模型预测概率(3类) | 正确类别 |
---|---|---|
第 0 个样本 | [0.1, 0.3, 0.6] |
0 |
第 1 个样本 | [0.3, 0.2, 0.5] |
2 |
🧮 二、代码含义
python
y_hat[[0, 1], y]
这是一个行列双索引操作,等价于:
python
torch.tensor([y_hat[0, y[0]], y_hat[1, y[1]]])
也就是:
= [y_hat[0, 0], y_hat[1, 2]]
📊 三、一步步替换求值
从矩阵中取出对应元素:
索引 | 取出的元素 | 值 |
---|---|---|
y_hat[0, 0] | 第 0 行第 0 列 | 0.1 |
y_hat[1, 2] | 第 1 行第 2 列 | 0.5 |
所以:
python
y_hat[[0, 1], y] = tensor([0.1, 0.5])
🧠 四、这个操作在深度学习中的用途
这个技巧非常常见!尤其是在 交叉熵损失 (cross-entropy) 计算中。
假设 y_hat
是每个样本对所有类别的预测概率分布,y
是正确标签,
那这行代码可以提取出:
每个样本 "对应正确类别" 的预测概率。
例如:
python
y_hat[[0, 1], y]
# → tensor([预测第0个样本属于真实类别的概率,
# 预测第1个样本属于真实类别的概率])
然后就可以算交叉熵:
python
loss = -torch.log(y_hat[[0, 1], y])
✅ 五、总结一句话
代码 | 含义 |
---|---|
y_hat[[0, 1], y] |
按行索引 [0, 1] 和对应列索引 y 同时取元素 |
= [y_hat[0, y[0]], y_hat[1, y[1]]] |
等价的展开写法 |
结果 | 取出每个样本的正确类别预测值(常用于分类任务) |
💡 口诀记忆:
y_hat[行索引, 列索引]
会"对齐"地从每一行中取出对应列的元素。
常用来拿到"每个样本在真实类别上的预测概率"。
具体理解
🧩 数据回顾
python
y = torch.tensor([0, 2])
y_hat = torch.tensor([
[0.1, 0.3, 0.6],
[0.3, 0.2, 0.5]
])
🎨 图示说明
类别0 类别1 类别2
样本0 → 0.1 0.3 0.6
样本1 → 0.3 0.2 0.5
以及标签:
y = [0, 2]
表示:
- 样本 0 的真实类别是 0
- 样本 1 的真实类别是 2
🔍 执行这句代码:
python
y_hat[[0, 1], y]
等价于:
取第 0 行的第 y[0]=0 列 → 0.1
取第 1 行的第 y[1]=2 列 → 0.5
✅ 可视化标注结果
类别0 类别1 类别2
样本0 → [0.1]* 0.3 0.6
样本1 → 0.3 0.2 [0.5]*
星号 * 表示被选中的元素
最终输出:
tensor([0.1, 0.5])
🧠 用途回顾
在分类任务里:
python
loss = -torch.log(y_hat[[0, 1], y])
就是取出模型对真实标签类别的预测概率,再取负对数计算交叉熵损失。