pytorch中torch.gather()简单理解

1.作用

从输入张量中按照指定维度进行索引采集操作,返回值是一个新的张量,形状与 index 张量相同,根据指定的索引从输入张量中采集对应的元素。

2.问题

该函数的主要问题主要在dim维度上,dim=0 表示沿着第一个维度(行)进行索引采集,而 dim=1 表示沿着第二个维度(列)进行索引采集。
简单讲:dim=0,将在行上进行采集,行数不变,在列上取值,如下图中的例子torch.gather(input, dim=0, index=index),当dim=0时,[0,1]中0对应第一行第一列也就是1,1对应着第2行第2列(1在index的坐标为(1,2),dim=0,所以不用看index的行坐标的,只管纵坐标,也就是第2列。而此时值为1代表值input的行【需要+1,下标是从0开始的】,也就是第2行,值为4)也就是4;[1,0]中的1代表第2行第1列也就是3,0也就是第2列的第一行数据(此时index的0的坐标为(2,2),因为dim=0,也就不用看横坐标,也就是第二列。所以此时的0代表源input的行坐标,也就是第一行)也就是2

python 复制代码
input = torch.tensor([[1, 2], [3, 4], [5, 6]])
index = torch.tensor([[0, 1], [1, 0]])

result = torch.gather(input, dim=0, index=index)
result_colum = torch.gather(input, dim=1, index=index)
print("result:",result)
print("result_colum:",result_colum)

结果如下:

相关推荐
姚瑞南1 分钟前
【AI 风向标】四种深度学习算法(CNN、RNN、GAN、RL)的通俗解释
人工智能·深度学习·算法
渡我白衣23 分钟前
深度学习入门(一)——从神经元到损失函数,一步步理解前向传播(上)
人工智能·深度学习·学习
补三补四23 分钟前
SMOTE 算法详解:解决不平衡数据问题的有效工具
人工智能·算法
为java加瓦24 分钟前
前端学AI:如何写好提示词(prompt)
前端·人工智能·prompt
一车小面包25 分钟前
对注意力机制的直观理解
人工智能·深度学习·机器学习
逝水年华QAQ26 分钟前
什么是Edge TTS?
人工智能
ARM+FPGA+AI工业主板定制专家33 分钟前
基于NVIDIA ORIN+FPGA+AI自动驾驶硬件在环注入测试
人工智能·fpga开发·机器人·自动驾驶
AI小云38 分钟前
【Python与AI基础】Python编程基础:模块和包
人工智能·python
用户51914958484540 分钟前
Paytium WordPress插件存储型XSS漏洞深度分析
人工智能·aigc
weixin_433417671 小时前
PyTorch&TensorFlow
人工智能·pytorch·tensorflow