torch.gather 用法笔记

1. 基本作用

torch.gather 的作用是:

复制代码
从 input 的指定维度 dim 上,按照 index 给出的索引位置取值。

基本语法:

复制代码
output = torch.gather(input, dim, index)

基本公式 三维举例 dim=1

output[b][k][d] = input[b][index[b][k][d]][d]

其中:

复制代码
input:原始张量
dim:指定在哪个维度上取值
index:索引张量
output:取出的结果

2. 核心规则

torch.gather 有一个非常重要的规则:

复制代码
output 的形状和 index 的形状相同。

也就是说:

复制代码
output.shape == index.shape

index 不只是告诉从哪里取值,它还决定了最终输出张量的形状。


3. 二维例子

假设:

复制代码
import torch

input = torch.tensor([
    [10, 20, 30],
    [40, 50, 60]
])

index = torch.tensor([
    [0, 2],
    [1, 0]
])

output = torch.gather(input, dim=1, index=index)

因为 dim=1,所以是在列方向取值。

取值过程:

复制代码
output[0][0] = input[0][index[0][0]] = input[0][0] = 10
output[0][1] = input[0][index[0][1]] = input[0][2] = 30

output[1][0] = input[1][index[1][0]] = input[1][1] = 50
output[1][1] = input[1][index[1][1]] = input[1][0] = 40

最终结果:

复制代码
tensor([
    [10, 30],
    [50, 40]
])

4. 三维例子

假设:

复制代码
input.shape = [B, N, D]

含义是:

复制代码
B:batch size,样本数量
N:每个样本中的 patch 数量
D:每个 patch 的特征维度

如果:

复制代码
index.shape = [B, K, D]

并且:

复制代码
output = torch.gather(input, dim=1, index=index)

那么:

复制代码
output.shape = [B, K, D]

因为 dim=1,所以是在 N 这个维度上取值。

核心公式是:

复制代码
output[b][k][d] = input[b][index[b][k][d]][d]

解释:

复制代码
B 维保持对应
D 维保持对应
只有 N 维根据 index[b][k][d] 指定的位置取值

5. 结合 patch 选择代码理解

常见代码:

复制代码
_, indices = torch.topk(attention_weights, k, dim=1)

selected_patches = torch.gather(
    patches,
    1,
    indices.unsqueeze(-1).expand(-1, -1, D)
)

假设:

复制代码
patches.shape = [B, N, D]
attention_weights.shape = [B, N]
indices.shape = [B, K]

其中:

复制代码
B:样本数量
N:patch 数量
D:每个 patch 的特征维度
K:要选出的 patch 数量

torch.topk 得到的是每个样本中分数最高的 K 个 patch 索引:

复制代码
indices.shape = [B, K]

但是 patches 是三维张量:

复制代码
patches.shape = [B, N, D]

所以需要先扩展索引:

复制代码
indices.unsqueeze(-1)

形状变为:

复制代码
[B, K, 1]

再使用:

复制代码
expand(-1, -1, D)

形状变为:

复制代码
[B, K, D]

这样才能和 patches 的三维结构对应起来。


6. 为什么要 expand 到 D 维

因为每个 patch 不是一个数,而是一个 D 维特征向量。

如果某个 patch 的索引是:

复制代码
indices[b][k] = 3

扩展后变成:

复制代码
index[b][k] = [3, 3, 3, ..., 3]

长度是 D

于是:

复制代码
output[b][k][0] = input[b][3][0]
output[b][k][1] = input[b][3][1]
output[b][k][2] = input[b][3][2]
...
output[b][k][D-1] = input[b][3][D-1]

也就是把第 3 个 patch 的完整 D 维特征全部取出来。


7. 最终效果

对于代码:

复制代码
selected_patches = torch.gather(
    patches,
    1,
    indices.unsqueeze(-1).expand(-1, -1, D)
)

它的作用是:

复制代码
从每个样本的 N 个 patch 中,
根据 top-k 得到的索引,
选出 K 个重要 patch,
并保留每个 patch 的完整 D 维特征。

形状变化:

复制代码
patches:          [B, N, D]
indices:          [B, K]
expanded index:   [B, K, D]
selected_patches: [B, K, D]

8. 记忆方法

可以这样记:

复制代码
gather = 按照 index,从 input 的某个 dim 维度上取值。

如果:

复制代码
output = torch.gather(input, dim=1, index=index)

那么就是:

复制代码
在第 1 维上取值;
其他维度保持对应关系;
output 的形状等于 index 的形状。

对于三维张量:

复制代码
input.shape = [B, N, D]
index.shape = [B, K, D]
dim = 1

核心公式:

复制代码
output[b][k][d] = input[b][index[b][k][d]][d]
相关推荐
程序员小嬛2 小时前
2026年因果推断与多目标优化结合的前沿思路
人工智能·深度学习·神经网络·transformer·论文笔记
杨超越luckly2 小时前
Agent应用指南:利用GET请求获取赛力斯汽车门店位置信息
python·html·汽车·可视化·门店
花月C2 小时前
Agent上下文三级压缩
python·prompt·ai编程
专注搞钱2 小时前
用Python写了个SPC自动分析工具,效率提升10倍
开发语言·python
yijianace2 小时前
Python爬虫实战:ThreadPoolExecutor多线程采集书籍信息与图片下载
开发语言·爬虫·python
人邮异步社区2 小时前
请问如何系统地学习深度学习所需的数学基础?
人工智能·深度学习·学习
mightbxg2 小时前
【学习一下】余弦相似度+Sigmoid+交叉熵组合
深度学习·学习·机器学习
郝亚军2 小时前
win11安装python3.12.7和pycharm
ide·python·pycharm
资深流水灯工程师2 小时前
PyCharm 虚拟环境完整配置指南(PySide6 开发专用)
ide·python·pycharm