1. 基本作用
torch.gather 的作用是:
从 input 的指定维度 dim 上,按照 index 给出的索引位置取值。
基本语法:
output = torch.gather(input, dim, index)
基本公式 三维举例 dim=1
output[b][k][d] = input[b][index[b][k][d]][d]
其中:
input:原始张量
dim:指定在哪个维度上取值
index:索引张量
output:取出的结果
2. 核心规则
torch.gather 有一个非常重要的规则:
output 的形状和 index 的形状相同。
也就是说:
output.shape == index.shape
index 不只是告诉从哪里取值,它还决定了最终输出张量的形状。
3. 二维例子
假设:
import torch
input = torch.tensor([
[10, 20, 30],
[40, 50, 60]
])
index = torch.tensor([
[0, 2],
[1, 0]
])
output = torch.gather(input, dim=1, index=index)
因为 dim=1,所以是在列方向取值。
取值过程:
output[0][0] = input[0][index[0][0]] = input[0][0] = 10
output[0][1] = input[0][index[0][1]] = input[0][2] = 30
output[1][0] = input[1][index[1][0]] = input[1][1] = 50
output[1][1] = input[1][index[1][1]] = input[1][0] = 40
最终结果:
tensor([
[10, 30],
[50, 40]
])
4. 三维例子
假设:
input.shape = [B, N, D]
含义是:
B:batch size,样本数量
N:每个样本中的 patch 数量
D:每个 patch 的特征维度
如果:
index.shape = [B, K, D]
并且:
output = torch.gather(input, dim=1, index=index)
那么:
output.shape = [B, K, D]
因为 dim=1,所以是在 N 这个维度上取值。
核心公式是:
output[b][k][d] = input[b][index[b][k][d]][d]
解释:
B 维保持对应
D 维保持对应
只有 N 维根据 index[b][k][d] 指定的位置取值
5. 结合 patch 选择代码理解
常见代码:
_, indices = torch.topk(attention_weights, k, dim=1)
selected_patches = torch.gather(
patches,
1,
indices.unsqueeze(-1).expand(-1, -1, D)
)
假设:
patches.shape = [B, N, D]
attention_weights.shape = [B, N]
indices.shape = [B, K]
其中:
B:样本数量
N:patch 数量
D:每个 patch 的特征维度
K:要选出的 patch 数量
torch.topk 得到的是每个样本中分数最高的 K 个 patch 索引:
indices.shape = [B, K]
但是 patches 是三维张量:
patches.shape = [B, N, D]
所以需要先扩展索引:
indices.unsqueeze(-1)
形状变为:
[B, K, 1]
再使用:
expand(-1, -1, D)
形状变为:
[B, K, D]
这样才能和 patches 的三维结构对应起来。
6. 为什么要 expand 到 D 维
因为每个 patch 不是一个数,而是一个 D 维特征向量。
如果某个 patch 的索引是:
indices[b][k] = 3
扩展后变成:
index[b][k] = [3, 3, 3, ..., 3]
长度是 D。
于是:
output[b][k][0] = input[b][3][0]
output[b][k][1] = input[b][3][1]
output[b][k][2] = input[b][3][2]
...
output[b][k][D-1] = input[b][3][D-1]
也就是把第 3 个 patch 的完整 D 维特征全部取出来。
7. 最终效果
对于代码:
selected_patches = torch.gather(
patches,
1,
indices.unsqueeze(-1).expand(-1, -1, D)
)
它的作用是:
从每个样本的 N 个 patch 中,
根据 top-k 得到的索引,
选出 K 个重要 patch,
并保留每个 patch 的完整 D 维特征。
形状变化:
patches: [B, N, D]
indices: [B, K]
expanded index: [B, K, D]
selected_patches: [B, K, D]
8. 记忆方法
可以这样记:
gather = 按照 index,从 input 的某个 dim 维度上取值。
如果:
output = torch.gather(input, dim=1, index=index)
那么就是:
在第 1 维上取值;
其他维度保持对应关系;
output 的形状等于 index 的形状。
对于三维张量:
input.shape = [B, N, D]
index.shape = [B, K, D]
dim = 1
核心公式:
output[b][k][d] = input[b][index[b][k][d]][d]