pytorch张量的高级索引取值原理解读
代码:
import torch
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
x1 = x[[[0, 1], [1, 0]]]
x2 = x[torch.tensor([[0, 1], [1, 0]])]
print(f"x1:{x1}")
print(f"x2:{x2}")
输出:
x1:tensor([20, 40])
x2:tensor([[[10, 20, 30],
[40, 50, 60]],
[[40, 50, 60],
[10, 20, 30]]])
代码解读:
**张量 x
**是一个 2x3 的张量:
x1
的取值
x1 = x[[[0, 1], [1, 0]]]
-
索引机制 : 这里的索引
[[0, 1], [1, 0]]
是 高级整数索引。- 它取的是第 1 维的具体位置。
-
步骤:
x[[0, 1], [1, 0]]
等价于以下操作:x[0, 1]
-> 20x[1, 0]
-> 40
因此:
x1 = [20, 40]
注:x[[[0, 1], [1, 0]]] 结果同 x[[0, 1], [1, 0]]
x2
的取值
x2 = x[torch.tensor([[0, 1], [1, 0]])]
### 复杂索引,在0维和1维度都取
#x3 = x[torch.tensor([[0, 1], [1, 0]]),torch.tensor([[0, 1], [1, 0]])]
#print(f"x3:{x3}")
#x 3:tensor([[10, 50],
# [50, 10]])
#print(f"x3.shape:{x3.shape}") # x3.shape:torch.Size([2, 2])
-
索引机制 : 这里的索引
torch.tensor([[0, 1], [1, 0]])
是 多维整形张量索引。- 这种索引会在第 0 维上按张量的形状进行广播。
-
广播行为:
- 索引张量的形状是
(2, 2)
。 - PyTorch 会沿第 0 维取出对应的行,并按照索引结果重新排列。
- 索引张量的形状是
-
步骤:
x[0]
->[10, 20, 30]
x[1]
->[40, 50, 60]
根据索引张量
[[0, 1], [1, 0]]
,结果排列为:[[[10, 20, 30], # 对应索引 (0, 0)
[40, 50, 60]], # 对应索引 (0, 1)[[40, 50, 60], # 对应索引 (1, 0)
[10, 20, 30]]] # 对应索引 (1, 1)
总结:
x1
使用的是高级整数索引,按指定的具体位置取值(减少维度)。x2
使用的是多维张量索引,按张量形状广播,生成一个更高维的结果(不减少维度)。