Pytorch使用教学4-张量的索引

1 张量的符号索引

张量也是有序序列,我们可以根据每个元素在系统内的顺序位置,来找出特定的元素,也就是索引。

1.1 一维张量的索引

一维张量由零维张量构成

一维张量索引与Python中的索引一样是是从左到右,从0开始的,遵循格式为[start: end: step]

Python 复制代码
t1 = torch.arange(1, 11)
t1
# tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

# 取出索引位置是0的元素
t1[0]
# tensor(1)

:张量索引出的结果是零维张量,而不是单独的数。要转化成单独的数还需使用上节介绍的item()方法。

可理解为构成一维张量的是零维张量,而不是单独的数。

张量的step必须大于0

Python 复制代码
# 索引3-10号元素,左闭右开,默认step为1
t1[2: 8]
# tensor([3, 4, 5, 6, 7, 8])

# step=3,隔3个数取一个,左闭右开
t1[2: 8: 2]
# tensor([3, 5, 7])

Python中,step可以为负数,例如:

Python 复制代码
li = [1, 2, 3]
# 列表倒叙排列,取所有数值,从后往前取
li[ ::-1]
# [3, 2, 1]

但在张量中,step必须大于1,否则就会报错。

Python 复制代码
t1 = torch.arange(1, 11)
t1[ ::-1]
# ValueError: step must be greater than zero

1.2 二维张量的索引

二维张量的索引逻辑和一维张量的索引逻辑相同,二维张量可以视为两个一维张量组合而成。

Python 复制代码
t2 = torch.arange(1, 17).reshape(4, 4)
t2
#tensor([[ 1,  2,  3,  4],
#        [ 5,  6,  7,  8],
#        [ 9, 10, 11, 12],
#        [13, 14, 15, 16]])

t2[0,1]也可用t2[0][1]的表示。

Python 复制代码
# 表示索引第一行、第二个(第二列的)元素
t2[0, 1]
# tensor(2)

t2[0][1]
# tensor(2)

但是t2[::2, ::2]t2[::2][ ::2]的索引结果就不同:

Python 复制代码
t2[::2, ::2]
# tensor([[ 1,  3],
#        [ 9, 11]])

t2[::2][::2]
# tensor([[1, 2, 3, 4]])

t2[::2, ::2]二维索引使用逗号隔开时,可以理解为全局索引,取第一行和第三行的第一列和第三列的元素。

t2[::2][::2]二维索引在两个中括号中时,可以理解为先取了第一行和第三行,构成一个新的二维张量,然后在此基础上又间隔2并对所有张量进行索引。

Python 复制代码
tt = t2[::2]
# tensor([[ 1,  2,  3,  4],
#         [ 9, 10, 11, 12]])
tt[::2]
# tensor([[1, 2, 3, 4]])

1.3 三维张量的索引

设三维张量的shapex、y、z,则可理解为它是由x个二维张量构成,每个二维张量由y个一维张量构成,每个一维张量由z个元素构成。

Python 复制代码
t3 = torch.arange(1, 28).reshape(3, 3, 3)
t3
# tensor([[[ 1,  2,  3],
#         [ 4,  5,  6],
#         [ 7,  8,  9]],

#         [[10, 11, 12],
#         [13, 14, 15],
#         [16, 17, 18]],

#         [[19, 20, 21],
#         [22, 23, 24],
#         [25, 26, 27]]])

# 索引第二个矩阵中的第二行、第二个元素
t3[1, 1, 1]
# tensor(14)

# 索引第二个矩阵,行和列都是每隔两个取一个
t3[1, ::2, ::2]
# tensor([[10, 12],
#         [16, 18]])

高维张量的思路与低维一样,就是围绕张量的"形状"进行索引。

2 张量的函数索引

2.1 一维张量的函数索引

PyTorch中,我们还可以使用index_select函数指定index来对张量进行索引,index的类型必须为Tensor

index_select(dim, index)表示在张量的哪个维度进行索引,索引的位值是多少。

Python 复制代码
t1 = torch.arange(1, 11)
indices = torch.tensor([1, 2])
# tensor([1, 2])
t1.index_select(0, indices)
# tensor([2, 3])

对于t1这个一维向量来说,由于只有一个维度,第二个参数取值为0,就代表在第一个维度上进行索引,索引的位置是1和2。

:这里取出的是位置,而不是取出[1:2]区间内左闭右开的元素。

2.2 二维张量的函数索引

Python 复制代码
t2 = torch.arange(12).reshape(4, 3)
t2
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])

t2.shape
# torch.Size([4, 3])

indices = torch.tensor([1, 2])
t2.index_select(0,indices)
# tensor([[3, 4, 5],
#         [6, 7, 8]])

此时dim参数取值为0,代表在shape的第一个维度上进行索引。

Python 复制代码
t2 = torch.arange(12).reshape(4, 3)
indices = torch.tensor([1, 1])
t2.index_select(1, indices)
# tensor([[ 1,  1],
#        [ 4,  4],
#        [ 7,  7],
#        [10, 10]])

此时dim参数取值为1,代表在shape的第二个维度上进行索引。index参数的值为[1,1],就代表取出第二个维度上为1的元素2次。

下面可以再次理解:

Python 复制代码
t2 = torch.arange(12).reshape(4, 3)
t2
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])

t2.shape
# torch.Size([4, 3])

indices = torch.tensor([2, 2, 2])
t2.index_select(1, indices)
# tensor([[ 2,  2,  2],
#         [ 5,  5,  5],
#         [ 8,  8,  8],
#         [11, 11, 11]])

取出第二个维度上为2的元素3次。

高维张量函数索引的思路与低维一样,都是在shape的维度上进行操作。

PyTorch中很多函数都采用的是第几维的思路,后面会介绍给大家,大家还需勤加练习,适应这种思路。同时使用函数式索引,在习惯后对代码可读性会有很大提升。

Pytorch张量操作大全:

Pytorch使用教学1-Tensor的创建
Pytorch使用教学2-Tensor的维度
Pytorch使用教学3-特殊张量的创建与类型转化
Pytorch使用教学4-张量的索引
Pytorch使用教学5-视图view与reshape的区别
Pytorch使用教学6-张量的分割与合并
Pytorch使用教学7-张量的广播
Pytorch使用教学8-张量的科学运算
Pytorch使用教学9-张量的线性代数运算
Pytorch使用教学10-张量操作方法大总结

有关Pytorch建模相关的AI干货请扫码关注公众号「AI有温度」阅读获取

相关推荐
查理零世30 分钟前
保姆级讲解 python之zip()方法实现矩阵行列转置
python·算法·矩阵
刀客12341 分钟前
python3+TensorFlow 2.x(四)反向传播
人工智能·python·tensorflow
SpikeKing1 小时前
LLM - 大模型 ScallingLaws 的设计 100B 预训练方案(PLM) 教程(5)
人工智能·llm·预训练·scalinglaws·100b·deepnorm·egs
小枫@码1 小时前
免费GPU算力,不花钱部署DeepSeek-R1
人工智能·语言模型
liruiqiang051 小时前
机器学习 - 初学者需要弄懂的一些线性代数的概念
人工智能·线性代数·机器学习·线性回归
Icomi_1 小时前
【外文原版书阅读】《机器学习前置知识》1.线性代数的重要性,初识向量以及向量加法
c语言·c++·人工智能·深度学习·神经网络·机器学习·计算机视觉
微学AI1 小时前
GPU算力平台|在GPU算力平台部署可图大模型Kolors的应用实战教程
人工智能·大模型·llm·gpu算力
西猫雷婶1 小时前
python学opencv|读取图像(四十六)使用cv2.bitwise_or()函数实现图像按位或运算
人工智能·opencv·计算机视觉
IT古董1 小时前
【深度学习】常见模型-生成对抗网络(Generative Adversarial Network, GAN)
人工智能·深度学习·生成对抗网络
Jackilina_Stone1 小时前
【论文阅读笔记】“万字”关于深度学习的图像和视频阴影检测、去除和生成的综述笔记 | 2024.9.3
论文阅读·人工智能·笔记·深度学习·ai