在 PyTorch 中,torch.gather()
是一个非常实用的张量操作函数,主要用于根据索引从输入张量中选择特定位置的值。它常用于注意力机制、序列处理等场景。
函数定义
python
torch.gather(input, dim, index) → Tensor
input
:待提取数据的张量。dim
:在哪个维度上进行索引选择。index
:一个与input
在除了dim
维度外相同形状的张量,其值指定了从input
中提取的索引位置。- 返回值:从
input
的指定维度dim
上根据index
提取出的新张量。
形象理解
举个简单的例子:
示例 1:二维张量,按列(dim=1)提取
python
import torch
input = torch.tensor([[10, 20, 30],
[40, 50, 60]])
index = torch.tensor([[2, 1, 0],
[0, 1, 2]])
output = torch.gather(input, dim=1, index=index)
print(output)
解释:
- 对于第一行:从
[10, 20, 30]
中提取位置[2,1,0]
,结果是[30, 20, 10]
- 对于第二行:从
[40, 50, 60]
中提取位置[0,1,2]
,结果是[40, 50, 60]
输出:
tensor([[30, 20, 10],
[40, 50, 60]])
示例 2:按行(dim=0)提取
python
input = torch.tensor([[1, 2],
[3, 4],
[5, 6]])
index = torch.tensor([[0, 1],
[1, 2],
[2, 0]])
output = torch.gather(input, dim=0, index=index)
print(output)
解释:
-
每个位置从第
dim=0
维度提取对应的元素。例如:- 第 (0,0) 位置:从 [1,3,5] 中取第 0 行,值为 1
- 第 (1,0) 位置:从 [1,3,5] 中取第 1 行,值为 3
- 第 (2,1) 位置:从 [2,4,6] 中取第 0 行,值为 2
输出:
tensor([[1, 4],
[3, 6],
[5, 2]])
应用场景
- 注意力机制中的权重选择
- 序列解码中的 beam search
- 从嵌套表示中根据索引获取嵌套内容
实战场景举例
假设有一个 batch 的 BERT 输出,想从每个句子中提取第 N 个 token(如 [CLS]、某个关键词)的表示向量。
假设数据
python
import torch
from transformers import BertModel, BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
sentences = ["I love World", "Transformers are powerful"]
inputs = tokenizer(sentences, padding=True, return_tensors="pt")
# 获取 BERT 输出
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state # (batch_size, seq_len, hidden_size)
print(last_hidden_state.shape)
# torch.Size([2, 5, 768]) 假设 padding 后为长度 5,hidden size 为 768
场景 1:提取每个句子的第一个 token(通常是 [CLS])
python
cls_embeddings = last_hidden_state[:, 0, :] # shape: (batch_size, hidden_size)
这个可以直接使用切片完成,不需要 gather
。
场景 2:提取每个句子中 指定位置的 token 表示(如"love"或"are")
假设我们事先知道每个句子中感兴趣 token 的位置:
python
# 每个句子中我们想要提取的 token 索引
# 假设我们想提取第 2 个 token
token_indices = torch.tensor([2, 1]) # shape: (batch_size,)
使用 gather
抽取对应 token 的向量:
python
# last_hidden_state: (batch_size, seq_len, hidden_size)
batch_size, seq_len, hidden_size = last_hidden_state.size()
# 将 token_indices 转成 index 用于 gather: shape (batch_size, 1, 1)
token_indices = token_indices.view(-1, 1, 1).expand(-1, 1, hidden_size) # (batch_size, 1, hidden_size)
# gather on dim=1(seq_len)
token_embeddings = torch.gather(last_hidden_state, dim=1, index=token_indices) # (batch_size, 1, hidden_size)
# squeeze 掉中间的维度
token_embeddings = token_embeddings.squeeze(1) # (batch_size, hidden_size)
print(token_embeddings.shape)
小结
操作需求 | 用法 |
---|---|
取所有句子的第一个 token | output[:, 0, :] |
取所有句子的第 N 个 token |
output[:, N, :] |
取每个句子的指定 token(不同位置) | torch.gather() (如上所示) |
注意事项
index
必须与input
的 shape 一致,除了在指定的dim
维度上的大小。index
的值必须小于input
在dim
维度上的长度。