【pytorch】pytorch中的高级索引

这里只介绍pytorch的高级索引,是一些奇怪的切片索引

基本版

a[[0, 2], [1, 2]] 等价 a[0, 1] 和 a[2, 2],相当于索引张量的第一行的第二列和第三行的第三列元素;

a[[1, 0, 2], [0]] 等价 a[1, 0] 和 a[0, 0] 和 a[2, 0],相当于索引张量的第二行第一列的元素、张量第一行和第一列的元素以及张量第三行和第一列的元素

python 复制代码
import torch
a = torch.arange(9).view([3, 3])

print(a)
b = a[[0, 2], [1, 2]]

print(b)

c = a[[1, 0, 2], [0]]

print(c)


# ---------output----------
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
# tensor([1, 8])
# tensor([3, 0, 6])

# 这里参考了:https://zhuanlan.zhihu.com/p/509591863

高级索引的原则:索引中有: 就代表着改维度全部取,在哪个维度放置索引,就代表想取哪个维度的内容

扩展A:

python 复制代码
import torch
a = torch.arange(30).view([2, 5, 3])  # 假如a代表N, x, y
print(a)

b = torch.tensor([[0, 2],
                  [1, 2]])
print(a[:, b[0, :], :])

# ------output-------
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5],
#          [ 6,  7,  8],
#          [ 9, 10, 11],
#          [12, 13, 14]],
#
#         [[15, 16, 17],
#          [18, 19, 20],
#          [21, 22, 23],
#          [24, 25, 26],
#          [27, 28, 29]]])
# tensor([[[ 0,  1,  2],
#          [ 6,  7,  8]],
#
#         [[15, 16, 17],
#          [21, 22, 23]]])

# 上述代码含义是对a的所有N, 按b中的第一行取出所有的行

扩展B:

python 复制代码
a = torch.arange(30).view([2, 5, 3])
print(a)

b = torch.tensor([[0, 2],
                  [1, 2]])
print(a[:, b[:, 0], b[:, 1]])

# ----------output--------------
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5],
#          [ 6,  7,  8],
#          [ 9, 10, 11],
#          [12, 13, 14]],
# 
#         [[15, 16, 17],
#          [18, 19, 20],
#          [21, 22, 23],
#          [24, 25, 26],
#          [27, 28, 29]]])
# tensor([[ 2,  5],
#         [17, 20]])

# 上述代码含义是对a的所有batch, 按b中的元素取出a中的x, y; 取N次

扩展C: (最抽象的一次)

python 复制代码
a = torch.arange(30).view([2, 5, 3])
print(a)

b = torch.tensor([[0, 2],
                  [1, 2]])
print(a[:, b, :])

# ------output-------
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5],
#          [ 6,  7,  8],
#          [ 9, 10, 11],
#          [12, 13, 14]],
# 
#         [[15, 16, 17],
#          [18, 19, 20],
#          [21, 22, 23],
#          [24, 25, 26],
#          [27, 28, 29]]])
# tensor([[[[ 0,  1,  2],
#           [ 6,  7,  8]],
# 
#          [[ 3,  4,  5],
#           [ 6,  7,  8]]],
# 
# 
#         [[[15, 16, 17],
#           [21, 22, 23]],
# 
#          [[18, 19, 20],
#           [21, 22, 23]]]])

# 上述代码含义是对a的所有batch, 按b中的元素取出a中行; 取N * b[0]次

torch.gather函数

本来想使用torch.gather函数完成上述功能,实验后发现并不直观,还是用高级索引吧。这里放个torch.gather函数单独的内容吧。

python 复制代码
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)


index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

#-------------output------------
# tensor([[ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])
# tensor([[9, 7, 5]])
# tensor([[5, 4, 3]])

# torch.gather的理解
# index=[ [x1,x2,x2],
# [y1,y2,y2],
# [z1,z2,z3] ]
# 
# 如果dim=0
# 填入方式
# [ [(x1,0),(x2,1),(x3,2)]
# [(y1,0),(y2,1),(y3,2)]
# [(z1,0),(z2,1),(z3,2)] ]
#
# 如果dim=1
# [ [(0,x1),(0,x2),(0,x3)]
# [(1,y1),(1,y2),(1,y3)]
# [(2,z1),(2,z2),(2,z3)] ]

# 参考: https://zhuanlan.zhihu.com/p/352877584
相关推荐
Luminbox紫创测控1 小时前
汽车自动驾驶的太阳光模拟应用研究
人工智能·自动驾驶·汽车
吴佳浩6 小时前
大模型量化部署终极指南:让700亿参数的AI跑进你的显卡
人工智能·python·gpu
跨境卫士苏苏7 小时前
亚马逊AI广告革命:告别“猜心”,迎接“共创”时代
大数据·人工智能·算法·亚马逊·防关联
珠海西格电力7 小时前
零碳园区工业厂房光伏一体化(BIPV)基础规划
大数据·运维·人工智能·智慧城市·能源
diegoXie7 小时前
Python / R 向量顺序分割与跨步分割
开发语言·python·r语言
七牛云行业应用7 小时前
解决OSError: No space left... 给DeepSeek Agent装上无限云硬盘
python·架构设计·七牛云·deepseek·agent开发
土星云SaturnCloud7 小时前
不止是替代:从机械风扇的可靠性困局,看服务器散热技术新范式
服务器·网络·人工智能·ai
小马爱打代码7 小时前
Spring AI:搭建自定义 MCP Server:获取 QQ 信息
java·人工智能·spring
你们补药再卷啦8 小时前
ai(三)环境资源管理
人工智能·语言模型·电脑
BoBoZz198 小时前
CutWithScalars根据标量利用vtkContourFilter得到等值线
python·vtk·图形渲染·图形处理