【pytorch】torch.gather()函数

dim=0时

python 复制代码
index=[ [x1,x2,x2],
		[y1,y2,y2],
		[z1,z2,z3] ]

如果dim=0
填入方式为:
index=[ [(x1,0),(x2,1),(x3,2)]
		[(y1,0),(y2,1),(y3,2)]
		[(z1,0),(z2,1),(z3,2)] ]
python 复制代码
input = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12]
] # shape(3,4)
input = torch.tensor(input)
length = torch.LongTensor([
    [2,2,2,2],
    [1,1,1,1],
    [0,0,0,0],
    [0,1,2,0]
])# shape(4,4)
out = torch.gather(input, dim=0, index=length)
print(out)
python 复制代码
tensor([[9, 10, 11, 12],
        [5, 6, 7, 8],
        [1, 2, 3, 4],
        [1, 6, 11, 4]])
python 复制代码
#### dim=0后,根据new_index对input进行索引
new_index=[ [(2,0),(2,1),(2,2),(2,3)],
			[(1,0),(1,1),(1,2),(1,3)],
			[(0,0),(0,1),(0,2),(0,3)],
			[(0,0),(1,1),(2,2),(0,3)] ]
			
可以观察到第四行,行索引变为0,所以当gather函数里的index超过input的唯独时,会从0重新计数。

dim=1时

python 复制代码
input = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12]
] # shape(3,4)
input = torch.tensor(input)
length = torch.LongTensor([
    [2,2,2,2],
    [1,1,1,1],
    [0,1,2,0]
]) # shape(3,4)
out = torch.gather(input, dim=1, index=length)
print(out)
python 复制代码
tensor([[3, 3, 3, 3],
        [6, 6, 6, 6],
        [9, 10, 11, 9]])
python 复制代码
new_index = [
	[(0,2),(0,2),(0,2),(0,2)],
	[(1,1),(1,1),(1,1),(1,1)],
	[(2,0),(2,1),(2,2)(2,0)]
]
相关推荐
程序猿阿伟几秒前
《无需额外付费的OpenClaw Agent部署指南》
人工智能
DS随心转APP4 分钟前
AI导出鸭:AI 文档排版与一键导出实战指南
人工智能·ai·chatgpt·deepseek·ai导出鸭
geneculture5 分钟前
语(暨各级各类字组)对接外来的词和句以及本土的言和语:言和语的关系及双重形式化彻底解决问题
人工智能·语言学·融智学应用场景·哲学与科学统一性·融智时代(杂志)
凯丨6 分钟前
agentmemory on NAS 完整部署文档(Tailscale + DeepSeek 压缩 + 局域网 viewer)
人工智能
YsyaaabB6 分钟前
LangChain作业二---多语言翻译Prompt
开发语言·python·langchain
weixin_446260857 分钟前
Vortex:高效可编程稀疏注意力机制用于大模型推理服务
人工智能
AI科技星7 分钟前
精细结构常数α的多维度物理比值特性及空间螺旋模型研究
人工智能·线性代数·架构·概率论·学习方法
zhangfeng11338 分钟前
头部AI公司模以OpenAI、DeepSeek为代表型版本迭代训练策略深度解析:重新训练 vs. 增量训练(前瞻性技术推演
人工智能
HappyAcmen8 分钟前
2.PDF长文档完整读取
python·pdf·rag
装不满的克莱因瓶8 分钟前
掌握感知器的学习原理
人工智能·python·神经网络·算法·ai·卷积神经网络