【PyTorch】5.张量索引操作

目录

[1. 简单行、列索引](#1. 简单行、列索引)

[2. 列表索引](#2. 列表索引)

[3. 范围索引](#3. 范围索引)

[4. 布尔索引](#4. 布尔索引)

[5. 多维索引](#5. 多维索引)


个人主页:Icomi

在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。

大家好,我是一颗米,前面咱们已经深入学习了张量的拼接操作,这让我们在搭建神经网络时能够更灵活地组合数据,就像掌握了一套精妙的 "组合拳",让模型构建更加得心应手。然而,在我们使用张量的过程中,还有一项关键技能,如同在茂密的数据丛林中开辟道路的利器,那就是张量的花式索引操作。

当我们在操作张量时,就好比在管理一个庞大而复杂的仓库,里面存放着各种各样的数据。经常会遇到这样的情况,我们需要从这个 "数据仓库" 里精准地获取特定的数据,或者对某些数据进行修改。这时候,张量的花式索引操作就派上用场了。

想象一下,你置身于一个堆满了各种规格货物(数据)的巨大仓库,**普通的索引方式就像只能按顺序依次寻找货物,效率低下。而花式索引则像是给了你一把神奇的钥匙,**能够让你直接定位到那些分散在不同位置的特定货物,快速地获取或者修改它们。

掌握张量的花式索引操作,是我们必须具备的一项能力。它就像一把万能钥匙,能让我们在处理张量数据时更加高效、精准。无论是构建复杂的神经网络模型,还是对模型输出的数据进行后处理,花式索引都能帮助我们更灵活地操作张量。

接下来,我们就一起深入探索张量的花式索引操作,看看这把神奇的 "钥匙" 究竟有哪些神奇的用法,如何帮助我们在数据的海洋中畅行无阻。

1. 简单行、列索引

准备数据

python 复制代码
import torch

data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)

结果

python 复制代码
tensor([[0, 7, 6, 5, 9],
        [6, 8, 3, 1, 0],
        [6, 3, 8, 7, 3],
        [4, 9, 5, 3, 1]])
--------------------------------------------------
python 复制代码
# 1. 简单行、列索引
def test01():

    print(data[0])
    print(data[:, 0])
    print('-' * 50)

if __name__ == '__main__':
    test01()

程序输出结果

python 复制代码
tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------

2. 列表索引

python 复制代码
# 2. 列表索引
def test02():

    # 返回 (0, 1)、(1, 2) 两个位置的元素
    print(data[[0, 1], [1, 2]])
    print('-' * 50)

    # 返回 0、1 行的 1、2 列共4个元素
    print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
    test02()

输出结果

python 复制代码
tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
        [8, 3]])

3. 范围索引

python 复制代码
# 3. 范围索引
def test03():
    # 前3行的前2列数据
    print(data[:3, :2])
    # 第2行到最后的前2列数据
    print(data[2:, :2])
if __name__ == '__main__':
    test03()

结果

python 复制代码
tensor([[0, 7],
        [6, 8],
        [6, 3]])
tensor([[6, 3],
        [4, 9]])

4. 布尔索引

python 复制代码
# 布尔索引
def test():

    # 第三列大于5的行数据
    print(data[data[:, 2] > 5])
    # 第二行大于5的列数据
    print(data[:, data[1] > 5])
if __name__ == '__main__':
    test04()

输出结果

python 复制代码
tensor([[0, 7, 6, 5, 9],
        [6, 3, 8, 7, 3]])
tensor([[0, 7],
        [6, 8],
        [6, 3],
        [4, 9]])

5. 多维索引

python 复制代码
# 多维索引
def test05():

    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 50)

    print(data[0, :, :])
    print(data[:, 0, :])
    print(data[:, :, 0])


if __name__ == '__main__':
    test05()

输出结果

python 复制代码
tensor([[[2, 4, 1, 2, 3],
         [5, 5, 1, 5, 0],
         [1, 4, 5, 3, 8],
         [7, 1, 1, 9, 9]],

        [[9, 7, 5, 3, 1],
         [8, 8, 6, 0, 1],
         [6, 9, 0, 2, 1],
         [9, 7, 0, 4, 0]],

        [[0, 7, 3, 5, 6],
         [2, 4, 6, 4, 3],
         [2, 0, 3, 7, 9],
         [9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
        [5, 5, 1, 5, 0],
        [1, 4, 5, 3, 8],
        [7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
        [9, 7, 5, 3, 1],
        [0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
        [9, 8, 6, 9],
        [0, 2, 2, 9]])
相关推荐
纠结哥_Shrek7 分钟前
pytorch使用SVM实现文本分类
pytorch·支持向量机·分类
英国翰思教育13 分钟前
留学毕业论文如何利用不同问题设计问卷
人工智能·深度学习·学习·算法·学习方法·论文笔记
西猫雷婶22 分钟前
python学opencv|读取图像(四十七)使用cv2.bitwise_not()函数实现图像按位取反运算
开发语言·python·opencv
gaoenyang7605251 小时前
探索高效图像识别:基于OpenCV的形状匹配利器
人工智能·opencv·计算机视觉
背太阳的牧羊人1 小时前
分词器的词表大小以及如果分词器的词表比模型的词表大,那么模型的嵌入矩阵需要被调整以适应新的词表大小。
开发语言·人工智能·python·深度学习·矩阵
人工智能教学实践2 小时前
无人机红外热成像:应急消防的“透视眼”
人工智能·目标跟踪
QQ_7781329742 小时前
ChatGPT高效处理图片技巧使用详解
人工智能·chatgpt
码界筑梦坊2 小时前
基于Django的豆瓣影视剧推荐系统的设计与实现
后端·python·django·毕业设计
fmdpenny3 小时前
前后分离Vue3+Django 之简单的登入
后端·python·django
MYT_flyflyfly3 小时前
计算机视觉之三维重建-单视几何
人工智能·数码相机·计算机视觉