pytorch中的gather函数的定义和作用是什么?

在PyTorch中,gather函数是一个用于从张量(tensor)中收集特定索引位置上的元素的函数。它主要用于高级索引和从张量中提取特定信息。

定义(python)

gather函数的基本定义如下:

|---|-------------------------------------------------|
| | torch.gather(input, dim, index, out=None) |

  • input (Tensor): 输入张量。
  • dim (int): 沿其收集元素的维度。
  • index (LongTensor): 索引张量,其形状与input在除了dim维度外的所有维度上都相同。
  • out (Tensor, optional): 输出张量。

作用

gather函数的作用是根据index张量中的索引值,从input张量中沿着指定的dim维度收集元素。这可以用于提取张量中特定位置的值。

举例讲解

假设我们有一个形状为(3, 3)的二维张量input,我们想要沿着第0个维度(即行的维度)收集元素。我们还需要一个索引张量index,它告诉我们从每一行中收集哪个元素。

|---|-------------------------------------------------------------|
| | import torch |
| | |
| | # 创建一个形状为 (3, 3) 的输入张量 |
| | input = torch.tensor([[1, 2, 3], |
| | [4, 5, 6], |
| | [7, 8, 9]]) |
| | |
| | # 创建一个索引张量,它告诉我们在每一行中收集哪个元素 |
| | # 例如,第0行收集第2个元素(值为3),第1行收集第0个元素(值为4),第2行收集第1个元素(值为8) |
| | index = torch.tensor([[2], |
| | [0], |
| | [1]]) |
| | |
| | # 使用 gather 函数 |
| | output = torch.gather(input, dim=0, index=index) |
| | |
| | print(output) |

输出将会是:

tensor:

|---|-------------|
| | [4], |
| | [8]]) |

在这个例子中,gather函数沿着第0个维度(行)收集元素。对于每一行,它都使用index张量中对应的索引值来确定要收集哪个元素。因此,输出张量中的每个元素都是input张量中特定行和列的元素的组合。

注意,index张量的形状是(3, 1),这与input张量在除了第0个维度外的所有维度上的形状相匹配。这是因为我们沿着第0个维度收集元素,所以其他维度的大小必须相同。

相关推荐
ZGi.ai4 分钟前
LangChain做了什么?企业场景中它和专用AI平台的定位区别
人工智能·开源框架·企业ai·- langchain·- ai应用开发
努力努力再努力wz5 分钟前
【Linux网络系列】深入理解 I/O 多路复用:从 select 痛点到 poll 高并发服务器落地,基于 Poll、智能指针与非阻塞 I/O与线程池手写一个高性能 HTTP 服务器!(附源码)
java·linux·运维·服务器·c语言·c++·python
努力努力再努力wz8 分钟前
【Linux网络系列】万字硬核解析网络层核心:IP协议到IP 分片重组、NAT技术及 RIP/OSPF 动态路由全景
java·linux·运维·服务器·数据结构·c++·python
tjc199010059 分钟前
golang如何使用t.Cleanup清理测试_golang t.Cleanup测试清理使用策略
jvm·数据库·python
SteveLaiTVT13 分钟前
从 Curl 开始:不用 SDK,通过 DeepSeek API 手写 Agent Runtime
人工智能
小糖学代码15 分钟前
LLM系列:2.pytorch入门:3.基本优化思想与最小二乘法
人工智能·python·算法·机器学习·ai·数据挖掘·最小二乘法
J_bean17 分钟前
大语言模型 API Token 消耗深度剖析
人工智能·ai·llm·大语言模型·token
醉卧考场君莫笑18 分钟前
规则与传统NLP之任务范式
人工智能·自然语言处理
214396521 分钟前
如何提升SQL数据更新的安全性_使用行级锁与悲观锁机制
jvm·数据库·python
叶子丶苏22 分钟前
第二节_机器学习基本知识点
人工智能·python·机器学习·数据科学