Pytorch中gather()函数详解和实战示例

在 PyTorch 中,torch.gather() 是一个非常实用的张量操作函数,主要用于根据索引从输入张量中选择特定位置的值。它常用于注意力机制、序列处理等场景。


函数定义

python 复制代码
torch.gather(input, dim, index) → Tensor
  • input:待提取数据的张量。
  • dim:在哪个维度上进行索引选择。
  • index:一个与 input 在除了 dim 维度外相同形状的张量,其值指定了从 input 中提取的索引位置。
  • 返回值:从 input 的指定维度 dim 上根据 index 提取出的新张量。

形象理解

举个简单的例子:

示例 1:二维张量,按列(dim=1)提取

python 复制代码
import torch

input = torch.tensor([[10, 20, 30],
                      [40, 50, 60]])
index = torch.tensor([[2, 1, 0],
                      [0, 1, 2]])

output = torch.gather(input, dim=1, index=index)
print(output)

解释

  • 对于第一行:从 [10, 20, 30] 中提取位置 [2,1,0],结果是 [30, 20, 10]
  • 对于第二行:从 [40, 50, 60] 中提取位置 [0,1,2],结果是 [40, 50, 60]

输出

复制代码
tensor([[30, 20, 10],
        [40, 50, 60]])

示例 2:按行(dim=0)提取

python 复制代码
input = torch.tensor([[1, 2],
                      [3, 4],
                      [5, 6]])

index = torch.tensor([[0, 1],
                      [1, 2],
                      [2, 0]])

output = torch.gather(input, dim=0, index=index)
print(output)

解释

  • 每个位置从第 dim=0 维度提取对应的元素。例如:

    • 第 (0,0) 位置:从 [1,3,5] 中取第 0 行,值为 1
    • 第 (1,0) 位置:从 [1,3,5] 中取第 1 行,值为 3
    • 第 (2,1) 位置:从 [2,4,6] 中取第 0 行,值为 2

输出

复制代码
tensor([[1, 4],
        [3, 6],
        [5, 2]])

应用场景

  1. 注意力机制中的权重选择
  2. 序列解码中的 beam search
  3. 从嵌套表示中根据索引获取嵌套内容

实战场景举例

假设有一个 batch 的 BERT 输出,想从每个句子中提取第 N 个 token(如 [CLS]、某个关键词)的表示向量


假设数据

python 复制代码
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

sentences = ["I love World", "Transformers are powerful"]
inputs = tokenizer(sentences, padding=True, return_tensors="pt")

# 获取 BERT 输出
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

print(last_hidden_state.shape)
# torch.Size([2, 5, 768])  假设 padding 后为长度 5,hidden size 为 768

场景 1:提取每个句子的第一个 token(通常是 [CLS])

python 复制代码
cls_embeddings = last_hidden_state[:, 0, :]  # shape: (batch_size, hidden_size)

这个可以直接使用切片完成,不需要 gather


场景 2:提取每个句子中 指定位置的 token 表示(如"love"或"are")

假设我们事先知道每个句子中感兴趣 token 的位置:
python 复制代码
# 每个句子中我们想要提取的 token 索引
# 假设我们想提取第 2 个 token
token_indices = torch.tensor([2, 1])  # shape: (batch_size,)

使用 gather 抽取对应 token 的向量:

python 复制代码
# last_hidden_state: (batch_size, seq_len, hidden_size)
batch_size, seq_len, hidden_size = last_hidden_state.size()

# 将 token_indices 转成 index 用于 gather: shape (batch_size, 1, 1)
token_indices = token_indices.view(-1, 1, 1).expand(-1, 1, hidden_size)  # (batch_size, 1, hidden_size)

# gather on dim=1(seq_len)
token_embeddings = torch.gather(last_hidden_state, dim=1, index=token_indices)  # (batch_size, 1, hidden_size)

# squeeze 掉中间的维度
token_embeddings = token_embeddings.squeeze(1)  # (batch_size, hidden_size)

print(token_embeddings.shape)

小结

操作需求 用法
取所有句子的第一个 token output[:, 0, :]
取所有句子的第 N 个 token output[:, N, :]
取每个句子的指定 token(不同位置) torch.gather()(如上所示)

注意事项

  • index 必须与 input 的 shape 一致,除了在指定的 dim 维度上的大小。
  • index 的值必须小于 inputdim 维度上的长度。

相关推荐
宸津-代码粉碎机1 小时前
LLM 模型部署难题的技术突破:从轻量化到分布式推理的全栈解决方案
java·大数据·人工智能·分布式·python
都叫我大帅哥1 小时前
当数据流经LangChain时,RunnablePassthrough如何成为“最懒却最聪明”的快递员?
python·langchain
都叫我大帅哥1 小时前
机器学习界的“钢铁侠”:支持向量机(SVM)全方位指南
python·机器学习
柴 基4 小时前
Jupyter Notebook 使用指南
ide·python·jupyter
Python×CATIA工业智造5 小时前
Pycaita二次开发基础代码解析:几何体重命名与参数提取技术
python·pycharm·pycatia
你的电影很有趣5 小时前
lesson30:Python迭代三剑客:可迭代对象、迭代器与生成器深度解析
开发语言·python
乌恩大侠6 小时前
自动驾驶的未来:多模态传感器钻机
人工智能·机器学习·自动驾驶
光锥智能7 小时前
AI办公的效率革命,金山办公从未被颠覆
人工智能
GetcharZp7 小时前
爆肝整理!带你快速上手LangChain,轻松集成DeepSeek,打造自己的AI应用
人工智能·llm·deepseek
成成成成成成果7 小时前
揭秘动态测试:软件质量的实战防线
python·功能测试·测试工具·测试用例·可用性测试