1.官方文档
2.使用要点
- 输入index的shape等于输出value的shape
- 输入index的索引值仅替换该index中对应dim的index值
- 最终输出为替换index后在原tensor中的值
最终输出的shape和index的shape相同
根据dim的值 选择将index[i,j,k]这个结果替换input[i,j,k]里面对应的i or j or k ,并将结果存储到output[i,j,k]
3.实际应用
一维
python
import torch
import torch.nn as nn
arr = torch.tensor([1, 2, 3])
index = torch.tensor([0, 1])
result = torch.gather(arr,0, index)
print(result)
"""
tensor([1, 2])
"""
二维
python
import torch
arr = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
index = torch.tensor([[0, 1],
[1, 2]])
result = torch.gather(arr,1, index)
print(result)
"""
dim=0
tensor([[1, 5],
[4, 8]])
dim=1
tensor([[1, 2],
[5, 6]])
"""
三维
python
import torch
# 创建一个较小的三维张量
tensor_3d = torch.tensor([
[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]
], dtype=torch.float32)
# 创建索引张量
index_3d = torch.tensor([
[[0, 1],
[1, 0]],
[[1, 0],
[0, 1]]
], dtype=torch.long)
# 在 dim = 0 上进行 gather 操作
result_dim0 = tensor_3d.gather(dim=0, index=index_3d)
print("在 dim = 0 上的 gather 结果:")
print(result_dim0)
# 在 dim = 1 上进行 gather 操作
result_dim1 = tensor_3d.gather(dim=1, index=index_3d)
print("在 dim = 1 上的 gather 结果:")
print(result_dim1)
# 在 dim = 2 上进行 gather 操作
result_dim2 = tensor_3d.gather(dim=2, index=index_3d)
print("在 dim = 2 上的 gather 结果:")
print(result_dim2)
"""
在 dim = 0 上的 gather 结果:
tensor([[[1., 6.],
[7., 4.]],
[[5., 2.],
[3., 8.]]])
在 dim = 1 上的 gather 结果:
tensor([[[1., 4.],
[3., 2.]],
[[7., 6.],
[5., 8.]]])
在 dim = 2 上的 gather 结果:
tensor([[[1., 2.],
[4., 3.]],
[[6., 5.],
[7., 8.]]])
"""