Torch.gather

1.官方文档

2.使用要点

  • 输入index的shape等于输出value的shape
  • 输入index的索引值仅替换该index中对应dim的index值
  • 最终输出为替换index后在原tensor中的值

最终输出的shape和index的shape相同

根据dim的值 选择将index[i,j,k]这个结果替换input[i,j,k]里面对应的i or j or k ,并将结果存储到output[i,j,k]

3.实际应用

一维

python 复制代码
import torch
import torch.nn as nn
arr = torch.tensor([1, 2, 3])
index = torch.tensor([0, 1])
result = torch.gather(arr,0, index)
print(result)
"""
tensor([1, 2])
"""

二维

python 复制代码
import torch
arr = torch.tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
index = torch.tensor([[0, 1],
                  [1, 2]])
result = torch.gather(arr,1, index)
print(result)
"""
dim=0
tensor([[1, 5],
        [4, 8]])
dim=1
tensor([[1, 2],
        [5, 6]])
"""

三维

python 复制代码
import torch
# 创建一个较小的三维张量
tensor_3d = torch.tensor([
    [[1, 2],
     [3, 4]],
    [[5, 6],
     [7, 8]]
], dtype=torch.float32)
# 创建索引张量
index_3d = torch.tensor([
    [[0, 1],
     [1, 0]],
    [[1, 0],
     [0, 1]]
], dtype=torch.long)
# 在 dim = 0 上进行 gather 操作
result_dim0 = tensor_3d.gather(dim=0, index=index_3d)
print("在 dim = 0 上的 gather 结果:")
print(result_dim0)
# 在 dim = 1 上进行 gather 操作
result_dim1 = tensor_3d.gather(dim=1, index=index_3d)
print("在 dim = 1 上的 gather 结果:")
print(result_dim1)
# 在 dim = 2 上进行 gather 操作
result_dim2 = tensor_3d.gather(dim=2, index=index_3d)
print("在 dim = 2 上的 gather 结果:")
print(result_dim2)
"""
在 dim = 0 上的 gather 结果:
tensor([[[1., 6.],
         [7., 4.]],

        [[5., 2.],
         [3., 8.]]])
在 dim = 1 上的 gather 结果:
tensor([[[1., 4.],
         [3., 2.]],

        [[7., 6.],
         [5., 8.]]])
在 dim = 2 上的 gather 结果:
tensor([[[1., 2.],
         [4., 3.]],

        [[6., 5.],
         [7., 8.]]])
"""
相关推荐
liwulin05061 分钟前
【PYTHON-YOLOV8N】yoloface+pytorch+cnn进行面部表情识别
python·yolo·cnn
(●—●)橘子……17 分钟前
记力扣1471.数组中的k个最强值 练习理解
数据结构·python·学习·算法·leetcode
会挠头但不秃18 分钟前
深度学习(5)循环神经网络
人工智能·rnn·深度学习
_OP_CHEN21 分钟前
用极狐 CodeRider-Kilo 开发俄罗斯方块:AI 辅助编程的沉浸式体验
人工智能·vscode·python·ai编程·ai编程插件·coderider-kilo
Wpa.wk23 分钟前
自动化测试 - 文件上传 和 弹窗处理
开发语言·javascript·自动化测试·经验分享·爬虫·python·selenium
_OP_CHEN25 分钟前
【Python基础】(二)从 0 到 1 入门 Python 语法基础:从表达式到运算符的全面指南
开发语言·python
_Li.29 分钟前
机器学习-贝叶斯公式
人工智能·机器学习·概率论
我命由我1234534 分钟前
Python Flask 开发:在 Flask 中返回字符串时,浏览器将其作为 HTML 解析
服务器·开发语言·后端·python·flask·html·学习方法
拾忆,想起36 分钟前
设计模式:软件开发的可复用武功秘籍
开发语言·python·算法·微服务·设计模式·性能优化·服务发现
哥布林学者40 分钟前
吴恩达深度学习课程四:计算机视觉 第二周:经典网络结构 课后习题和代码实践
深度学习·ai