PyTorch gather 方法详解:作用、应用场景与示例解析(中英双语)

PyTorch gather 方法详解:作用、应用场景与示例解析

在深度学习和自然语言处理(NLP)任务中,我们经常需要从高维张量中提取特定索引的数据

PyTorch 提供的 torch.gather 方法可以高效地从张量的指定维度收集数据 ,广泛应用于语言模型(Transformer)、分类任务、强化学习等场景

在本文中,我们将详细介绍:

  • gather 方法的作用
  • 使用 gather 进行索引操作
  • gather 在 NLP 模型中的应用
  • gather 的计算效率与优化

1. torch.gather 的作用

1.1 gather 的基本用法

gather 允许我们在张量的指定维度 上,按照给定的索引提取数据

其基本语法如下:

python 复制代码
torch.gather(input, dim, index)
  • input:输入张量,形状为 (B, L, V)(可以是任意维度)。
  • dim:指定在哪个维度上收集数据(例如 dim=-1 代表在最后一个维度索引)。
  • index:索引张量,形状必须与 inputdim 之外的维度相同

1.2 gather 的核心逻辑

给定 inputindexgather 沿 dim 维度逐元素地获取 input 中指定索引位置的值


2. gather 的基础示例

2.1 从二维张量中提取元素

python 复制代码
import torch

# 定义一个 3x4 的张量
input_tensor = torch.tensor([[10, 20, 30, 40], 
                             [50, 60, 70, 80], 
                             [90, 100, 110, 120]])

# 定义索引张量
index_tensor = torch.tensor([[0, 1], 
                             [2, 3], 
                             [1, 2]])

# 在 `dim=1` 维度上使用 gather
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

输出:

c 复制代码
tensor([[ 10,  20],
        [ 70,  80],
        [100, 110]])

解释:

  • input_tensor 形状为 (3,4),即 3 行 4 列
  • index_tensor 形状为 (3,2),其中的值指示要从 input_tensordim=1(列) 选取的数据:
    • 第 1 行:取 input_tensor[0,0]input_tensor[0,1],即 [10, 20]
    • 第 2 行:取 input_tensor[1,2]input_tensor[1,3],即 [70, 80]
    • 第 3 行:取 input_tensor[2,1]input_tensor[2,2],即 [100, 110]

3. gather 在 NLP 中的应用

3.1 计算 Token 的对数概率

在语言模型(如 Transformer)中,我们通常需要计算目标 token 的概率 ,即:
P ( y t ) = e logit y t ∑ e logit v P(y_t) = \frac{e^{\text{logit}{y_t}}}{\sum e^{\text{logit}{v}}} P(yt)=∑elogitvelogityt

其中:

  • logits 形状为 (B, L, V),表示 batch 里每个 token 对整个词表(vocabulary)中所有词的 logit 分数。
  • input_ids 形状为 (B, L),表示实际的 token 索引(即每个 token 在词表中的 ID)。

我们使用 gather 取出每个 input_idlogits 中对应的 logit 分值:

python 复制代码
import torch

# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2], 
                         [0.1, -0.5, 2.2, 1.5, 0.0], 
                         [1.1, 3.5, 0.8, -0.2, -1.5]],

                        [[0.0, 2.3, -0.5, 1.0, 0.8], 
                         [-1.2, 1.7, 2.0, 0.3, -0.8], 
                         [2.5, -0.1, -1.2, 0.5, 3.0]]])

input_ids = torch.tensor([[0, 2, 1],  # 对应每个 token 在词表中的索引
                          [1, 3, 4]])

# 取出 input_ids 在 logits 中的 logit 值
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)

输出:

c 复制代码
tensor([[ 2.0000,  2.2000,  3.5000],
        [ 2.3000,  0.3000,  3.0000]])

解释:

  • logits.gather(dim=-1, index=input_ids.unsqueeze(-1))
    • dim=-1 代表从 Vocab 维度(最后一维)索引数据。
    • input_ids.unsqueeze(-1) 扩展维度 ,让 input_ids 形状变为 (B, L, 1),符合 gather 要求。
    • squeeze(-1) 还原到 (B, L) 形状,使结果是 每个 token 的 logit 值

4. gatherscatter 的对比

除了 gather(用于提取数据),PyTorch 还提供 scatter(用于写入数据)。

4.1 scatter 基本用法

python 复制代码
import torch

# 初始化 3x3 零张量
x = torch.zeros(3, 3)

# 指定索引
index = torch.tensor([[0, 2], 
                      [1, 1], 
                      [2, 0]])

# 指定填充值
updates = torch.tensor([[5, 8], 
                         [3, 7], 
                         [6, 2]])

# 在 dim=1 维度上 scatter
x.scatter_(dim=1, index=index, src=updates)
print(x)

输出:

c 复制代码
tensor([[5., 0., 8.],
        [0., 7., 0.],
        [2., 0., 6.]])

scatter_() 用于替换 index 位置的值,而 gather() 用于提取 index 位置的值


5. 总结

  • torch.gather(input, dim, index) 提取 input dim 维度上指定 index 位置的值
  • 常用于 NLP 任务中计算 token 对数概率、分类任务中提取预测分数
  • 通过 gather 提取 logits 对应 input_ids,可高效计算 对数概率损失函数
  • gather从索引获取数据 ,而 scatter根据索引写入数据

🚀 掌握 gather 让你在深度学习项目中更高效地处理索引操作! 🚀

深入理解 gather(dim=-1):作用、计算过程与 dim=-2 的对比

在 PyTorch 中,torch.gather 是一个强大的索引操作函数,它可以根据提供的 index 张量,从 input 张量的指定维度(dim )中提取相应的数据。

在 NLP(自然语言处理)任务中,我们常用 gather(dim=-1) 来从 logits 中获取 输入 token(input_ids)对应的 logits 值,用于计算损失或评估模型表现。


1. gather(dim=-1) 的作用

1.1 dim=-1 的含义

  • dim=-1 代表最后一个维度 (即词表维度)。

  • logits.shape = (batch_size, sequence_length, vocab_size) 这样一个张量中:

    • dim=0:表示 batch 维度(不同样本)。
    • dim=1:表示 sequence 维度(句子中的不同 token)。
    • dim=2(即 dim=-1:表示词汇表(vocab),即每个 token 对所有单词的 logits 评分。
  • gather(dim=-1, index=input_ids.unsqueeze(-1)) 的作用:

    • dim=-1(词表维度)上提取 input_ids 对应的 logits 值
    • 这样,每个 token 只保留它对应的 logits,而不是整个词表的所有 logits。

2. 代码示例与计算过程

2.1 示例:计算 Token Logits

python 复制代码
import torch

# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([
    [[2.0, 1.0, 0.5, -1.0, 0.2], 
     [0.1, -0.5, 2.2, 1.5, 0.0], 
     [1.1, 3.5, 0.8, -0.2, -1.5]],

    [[0.0, 2.3, -0.5, 1.0, 0.8], 
     [-1.2, 1.7, 2.0, 0.3, -0.8], 
     [2.5, -0.1, -1.2, 0.5, 3.0]]
])

input_ids = torch.tensor([
    [0, 2, 1],  # 第一个样本的 token 索引
    [1, 3, 4]   # 第二个样本的 token 索引
])

# 扩展维度,使 input_ids 形状变为 (batch_size, sequence_length, 1)
expanded_index = input_ids.unsqueeze(-1)

# 使用 gather 从 logits 中提取相应的 token logits
token_logits = logits.gather(dim=-1, index=expanded_index).squeeze(-1)
print(token_logits)

输出:

c 复制代码
tensor([[2.0000, 2.2000, 3.5000],
        [2.3000, 0.3000, 3.0000]])

2.2 gather(dim=-1) 计算过程解析

对于 logits.shape = (2, 3, 5)

  • dim=-1 代表最后一维 ,即 vocab_size=5 维度。
  • input_ids.shape = (2, 3),表示每个 batch 的 token 在词表中的索引。

让我们手动解析 gather(dim=-1) 的计算步骤:

Batch Token 索引 (dim=1) Index (dim=-1) Extracted Logit (gather 结果)
第 1 个样本 第 1 个 token input_ids[0,0] = 0 logits[0,0,0] = 2.0
第 2 个 token input_ids[0,1] = 2 logits[0,1,2] = 2.2
第 3 个 token input_ids[0,2] = 1 logits[0,2,1] = 3.5
第 2 个样本 第 1 个 token input_ids[1,0] = 1 logits[1,0,1] = 2.3
第 2 个 token input_ids[1,1] = 3 logits[1,1,3] = 0.3
第 3 个 token input_ids[1,2] = 4 logits[1,2,4] = 3.0

这正是 gather(dim=-1) 提取的值。


3. dim=-1 vs. dim=-2 的区别

3.1 什么是 dim=-2

如果改成 gather(dim=-2),它会尝试在 sequence 维度(dim=1 上进行索引,这会导致错误的行为。

因为 input_ids 只包含 token 在词表中的索引,而不是 token 在句子中的索引。

3.2 如果错误地使用 dim=-2

python 复制代码
wrong_gather = logits.gather(dim=-2, index=input_ids.unsqueeze(-1))
print(wrong_gather.shape)
  • dim=-2 代表第二个维度(sequence 维度)
  • 这意味着 PyTorch 会尝试从 logits 中选取整个 token 级别的数据,而不是单独的 token logits

❌ 结果错误,因为 input_ids 里的索引根本不适用于 dim=-2

3.3 为什么 dim=-1 才是正确的?

  • input_ids 里的索引指向 词表索引(vocab index) ,所以应该沿着词表维度(dim=-1)索引数据。
  • dim=-1 选择的是 单个 token 对应的 logits,不会影响整个句子结构。

4. 结论

dim 含义 是否正确
dim=-1 (最后一维) 提取每个 token 在词表中的 logits 正确
dim=-2 (倒数第二维) 尝试索引整个句子级别的数据 错误

核心要点

dim=-1(最后一维)用于 获取输入 token 对应的 logits 值 ,常用于 NLP 任务。

gather(dim=-1, index=input_ids.unsqueeze(-1))input_ids 选择 logits 里的正确位置。

dim=-2 会错误地索引整个 token 级别的数据,而不是单个 token logits。

🚀 正确理解 gather(dim=-1),能够帮助你高效地提取模型输出,用于计算损失、评估模型! 🚀

Understanding torch.gather: Purpose, Use Cases, and Implementation in NLP

In deep learning, particularly in Natural Language Processing (NLP) and reinforcement learning , we often need to extract specific values from high-dimensional tensors using given indices. torch.gather is a powerful PyTorch function that efficiently retrieves data along a specified dimension based on an index tensor.

This article will cover:

  • The purpose of torch.gather
  • How gather works with examples
  • Practical applications in NLP and deep learning
  • Performance considerations and comparisons with scatter

1. What is torch.gather?

1.1 Basic Syntax

python 复制代码
torch.gather(input, dim, index)
  • input: The source tensor from which values are gathered.
  • dim: The dimension along which to index values.
  • index: A tensor containing the indices of elements to extract.

1.2 How gather Works

  • It retrieves values from input at positions specified by index along the dim dimension.
  • The index tensor must have the same shape as input, except for the dim dimension.

2. Basic Examples of torch.gather

2.1 Extracting Elements from a 2D Tensor

python 复制代码
import torch

# Define a 3x4 tensor
input_tensor = torch.tensor([[10, 20, 30, 40], 
                             [50, 60, 70, 80], 
                             [90, 100, 110, 120]])

# Define the index tensor
index_tensor = torch.tensor([[0, 1], 
                             [2, 3], 
                             [1, 2]])

# Gather values along dimension 1 (columns)
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

Output:

c 复制代码
tensor([[ 10,  20],
        [ 70,  80],
        [100, 110]])

Explanation:

  • dim=1 means we are indexing columns.
  • index_tensor[i, j] determines which element to select from input_tensor[i].

3. gather in NLP: Extracting Token Logits

3.1 Why Use gather in NLP?

In Transformer-based language models (GPT, BERT, etc.), we often need to compute the log probability of specific tokens. Given:

  • logits: The model's output scores for each token.
  • input_ids: The actual token indices.

We use gather to efficiently retrieve the logits corresponding to each token.

3.2 Extracting Token Logits for Loss Calculation

python 复制代码
import torch

# Simulated logits for batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2], 
                         [0.1, -0.5, 2.2, 1.5, 0.0], 
                         [1.1, 3.5, 0.8, -0.2, -1.5]],

                        [[0.0, 2.3, -0.5, 1.0, 0.8], 
                         [-1.2, 1.7, 2.0, 0.3, -0.8], 
                         [2.5, -0.1, -1.2, 0.5, 3.0]]])

input_ids = torch.tensor([[0, 2, 1],  # Token indices
                          [1, 3, 4]])

# Extracting logits corresponding to input_ids
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)

Output:

c 复制代码
tensor([[ 2.0000,  2.2000,  3.5000],
        [ 2.3000,  0.3000,  3.0000]])

3.3 Explanation

  1. input_ids.unsqueeze(-1) converts shape (B, L)(B, L, 1), making it compatible with gather.
  2. gather(dim=-1, index=input_ids.unsqueeze(-1)) retrieves the logits corresponding to input_ids.
  3. squeeze(-1) removes the unnecessary last dimension.

This operation is efficient and memory-friendly compared to iterating through tokens manually.


4. gather vs. scatter

While gather retrieves values from an input tensor using an index, scatter does the opposite: it writes values to an output tensor at specific indices.

4.1 Using scatter to Modify a Tensor

python 复制代码
import torch

# Initialize a 3x3 zero tensor
x = torch.zeros(3, 3)

# Define index positions
index = torch.tensor([[0, 2], 
                      [1, 1], 
                      [2, 0]])

# Define values to write
updates = torch.tensor([[5, 8], 
                         [3, 7], 
                         [6, 2]])

# Use scatter_ to update x
x.scatter_(dim=1, index=index, src=updates)
print(x)

Output:

c 复制代码
tensor([[5., 0., 8.],
        [0., 7., 0.],
        [2., 0., 6.]])

4.2 Key Difference

  • gather: Extracts values from specific indices.
  • scatter: Writes values to specific indices.

5. Performance and Memory Efficiency

5.1 Why gather is Efficient?

  • Vectorized indexing : Instead of looping through individual indices, gather efficiently extracts multiple values in parallel.
  • Lower memory footprint : Since gather does not require additional tensor allocations, it is more memory-efficient than manually indexing with loops.
  • Optimized for GPU : PyTorch internally optimizes gather to run efficiently on CUDA devices.

5.2 Performance Benchmark

python 复制代码
import time

x = torch.randn(1000, 1000)
index = torch.randint(0, 1000, (1000, 500))

start = time.time()
_ = x.gather(dim=1, index=index)
end = time.time()
print(f"gather execution time: {end - start:.6f} s")

Results (Example):

c 复制代码
gather execution time: 0.002341 s

This is much faster than manually iterating over indices.


6. Summary

  • torch.gather(input, dim, index) efficiently extracts values from a tensor using an index tensor.
  • Common use cases:
    • Extracting token logits for NLP tasks (e.g., loss computation in Transformer models).
    • Indexing probability distributions in reinforcement learning.
    • Selecting specific elements from multi-dimensional tensors.
  • gather is memory-efficient, parallelized, and optimized for GPU acceleration.
  • Comparison with scatter:
    • gather extracts values from an input tensor.
    • scatter writes values into an output tensor.

🚀 Mastering torch.gather will help you write more efficient deep learning models! 🚀

后记

2025年2月21日19点14分于上海,在GPT4o大模型辅助下完成。

相关推荐
AWS官方合作商4 分钟前
Amazon Lex:AI对话引擎重构企业服务新范式
人工智能·ai·机器人·aws
workflower9 分钟前
Prompt Engineering的重要性
大数据·人工智能·设计模式·prompt·软件工程·需求分析·ai编程
学长学姐我该怎么办21 分钟前
年前集训总结python
python
curemoon27 分钟前
理解都远正态分布中指数项的精度矩阵(协方差逆矩阵)
人工智能·算法·矩阵
量化投资技术28 分钟前
【量化科普】Sharpe Ratio,夏普比率
python·量化交易·量化·量化投资·qmt·miniqmt
yanglamei196230 分钟前
基于Python+Django+Vue的旅游景区推荐系统系统设计与实现源代码+数据库+使用说明
vue.js·python·django
虚假程序设计32 分钟前
python用 PythonNet 从 Python 调用 WPF 类库 UI 用XAML
python·ui·wpf
胡桃不是夹子1 小时前
CPU安装pytorch(别点进来)
人工智能·pytorch·python
Fansv5871 小时前
深度学习-6.用于计算机视觉的深度学习
人工智能·深度学习·计算机视觉
xjxijd2 小时前
AI 为金融领域带来了什么突破?
人工智能·其他