torch.gather 用法笔记

1. 基本作用

torch.gather 的作用是:

复制代码
从 input 的指定维度 dim 上,按照 index 给出的索引位置取值。

基本语法:

复制代码
output = torch.gather(input, dim, index)

基本公式 三维举例 dim=1

output[b][k][d] = input[b][index[b][k][d]][d]

其中:

复制代码
input:原始张量
dim:指定在哪个维度上取值
index:索引张量
output:取出的结果

2. 核心规则

torch.gather 有一个非常重要的规则:

复制代码
output 的形状和 index 的形状相同。

也就是说:

复制代码
output.shape == index.shape

index 不只是告诉从哪里取值,它还决定了最终输出张量的形状。


3. 二维例子

假设:

复制代码
import torch

input = torch.tensor([
    [10, 20, 30],
    [40, 50, 60]
])

index = torch.tensor([
    [0, 2],
    [1, 0]
])

output = torch.gather(input, dim=1, index=index)

因为 dim=1,所以是在列方向取值。

取值过程:

复制代码
output[0][0] = input[0][index[0][0]] = input[0][0] = 10
output[0][1] = input[0][index[0][1]] = input[0][2] = 30

output[1][0] = input[1][index[1][0]] = input[1][1] = 50
output[1][1] = input[1][index[1][1]] = input[1][0] = 40

最终结果:

复制代码
tensor([
    [10, 30],
    [50, 40]
])

4. 三维例子

假设:

复制代码
input.shape = [B, N, D]

含义是:

复制代码
B:batch size,样本数量
N:每个样本中的 patch 数量
D:每个 patch 的特征维度

如果:

复制代码
index.shape = [B, K, D]

并且:

复制代码
output = torch.gather(input, dim=1, index=index)

那么:

复制代码
output.shape = [B, K, D]

因为 dim=1,所以是在 N 这个维度上取值。

核心公式是:

复制代码
output[b][k][d] = input[b][index[b][k][d]][d]

解释:

复制代码
B 维保持对应
D 维保持对应
只有 N 维根据 index[b][k][d] 指定的位置取值

5. 结合 patch 选择代码理解

常见代码:

复制代码
_, indices = torch.topk(attention_weights, k, dim=1)

selected_patches = torch.gather(
    patches,
    1,
    indices.unsqueeze(-1).expand(-1, -1, D)
)

假设:

复制代码
patches.shape = [B, N, D]
attention_weights.shape = [B, N]
indices.shape = [B, K]

其中:

复制代码
B:样本数量
N:patch 数量
D:每个 patch 的特征维度
K:要选出的 patch 数量

torch.topk 得到的是每个样本中分数最高的 K 个 patch 索引:

复制代码
indices.shape = [B, K]

但是 patches 是三维张量:

复制代码
patches.shape = [B, N, D]

所以需要先扩展索引:

复制代码
indices.unsqueeze(-1)

形状变为:

复制代码
[B, K, 1]

再使用:

复制代码
expand(-1, -1, D)

形状变为:

复制代码
[B, K, D]

这样才能和 patches 的三维结构对应起来。


6. 为什么要 expand 到 D 维

因为每个 patch 不是一个数,而是一个 D 维特征向量。

如果某个 patch 的索引是:

复制代码
indices[b][k] = 3

扩展后变成:

复制代码
index[b][k] = [3, 3, 3, ..., 3]

长度是 D

于是:

复制代码
output[b][k][0] = input[b][3][0]
output[b][k][1] = input[b][3][1]
output[b][k][2] = input[b][3][2]
...
output[b][k][D-1] = input[b][3][D-1]

也就是把第 3 个 patch 的完整 D 维特征全部取出来。


7. 最终效果

对于代码:

复制代码
selected_patches = torch.gather(
    patches,
    1,
    indices.unsqueeze(-1).expand(-1, -1, D)
)

它的作用是:

复制代码
从每个样本的 N 个 patch 中,
根据 top-k 得到的索引,
选出 K 个重要 patch,
并保留每个 patch 的完整 D 维特征。

形状变化:

复制代码
patches:          [B, N, D]
indices:          [B, K]
expanded index:   [B, K, D]
selected_patches: [B, K, D]

8. 记忆方法

可以这样记:

复制代码
gather = 按照 index,从 input 的某个 dim 维度上取值。

如果:

复制代码
output = torch.gather(input, dim=1, index=index)

那么就是:

复制代码
在第 1 维上取值;
其他维度保持对应关系;
output 的形状等于 index 的形状。

对于三维张量:

复制代码
input.shape = [B, N, D]
index.shape = [B, K, D]
dim = 1

核心公式:

复制代码
output[b][k][d] = input[b][index[b][k][d]][d]
相关推荐
Lyn_Li19 分钟前
Kaggle Top 5 | 198只股票、200条数据的金融预测——BattleFin高分方案从零复现
python·kaggle·比赛复盘·金融预测
Lihua奏4 小时前
从单核到多核:CPU为什么不能再只靠提频变快
深度学习
拾年2755 小时前
大模型的"聪明"从哪来?聊聊 AI 数据集的那些事儿
人工智能·深度学习·机器学习
小九九的爸爸5 小时前
前端想要入门Agent开发,要具备哪些Python基础?
python·agent·ai编程
阿耶同学6 小时前
手把手教你用 LangGraph 搭建三层嵌套 Agent 架构
python·程序员
花酒锄作田1 天前
Pydantic校验配置文件
python
hboot1 天前
AI工程师第四课 - 深度学习入门
pytorch·python·神经网络
ZhengEnCi1 天前
P2M-Matplotlib折线图完全指南-从数据可视化到趋势分析的Python绘图利器
python·matlab·数据可视化
ZhengEnCi1 天前
P2L-Matplotlib饼图完全指南-从数据可视化到图表定制的Python绘图利器
python·matlab
曲幽1 天前
你的REST接口还在“过度投喂”数据吗?——FastAPI + GraphQL实战避坑指南
python·fastapi·web·graphql·route·cors·rest·strawberry