【PyTorch】5.张量索引操作

目录

[1. 简单行、列索引](#1. 简单行、列索引)

[2. 列表索引](#2. 列表索引)

[3. 范围索引](#3. 范围索引)

[4. 布尔索引](#4. 布尔索引)

[5. 多维索引](#5. 多维索引)


个人主页:Icomi

在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。

大家好,我是一颗米,前面咱们已经深入学习了张量的拼接操作,这让我们在搭建神经网络时能够更灵活地组合数据,就像掌握了一套精妙的 "组合拳",让模型构建更加得心应手。然而,在我们使用张量的过程中,还有一项关键技能,如同在茂密的数据丛林中开辟道路的利器,那就是张量的花式索引操作。

当我们在操作张量时,就好比在管理一个庞大而复杂的仓库,里面存放着各种各样的数据。经常会遇到这样的情况,我们需要从这个 "数据仓库" 里精准地获取特定的数据,或者对某些数据进行修改。这时候,张量的花式索引操作就派上用场了。

想象一下,你置身于一个堆满了各种规格货物(数据)的巨大仓库,**普通的索引方式就像只能按顺序依次寻找货物,效率低下。而花式索引则像是给了你一把神奇的钥匙,**能够让你直接定位到那些分散在不同位置的特定货物,快速地获取或者修改它们。

掌握张量的花式索引操作,是我们必须具备的一项能力。它就像一把万能钥匙,能让我们在处理张量数据时更加高效、精准。无论是构建复杂的神经网络模型,还是对模型输出的数据进行后处理,花式索引都能帮助我们更灵活地操作张量。

接下来,我们就一起深入探索张量的花式索引操作,看看这把神奇的 "钥匙" 究竟有哪些神奇的用法,如何帮助我们在数据的海洋中畅行无阻。

1. 简单行、列索引

准备数据

python 复制代码
import torch

data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)

结果

python 复制代码
tensor([[0, 7, 6, 5, 9],
        [6, 8, 3, 1, 0],
        [6, 3, 8, 7, 3],
        [4, 9, 5, 3, 1]])
--------------------------------------------------
python 复制代码
# 1. 简单行、列索引
def test01():

    print(data[0])
    print(data[:, 0])
    print('-' * 50)

if __name__ == '__main__':
    test01()

程序输出结果

python 复制代码
tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------

2. 列表索引

python 复制代码
# 2. 列表索引
def test02():

    # 返回 (0, 1)、(1, 2) 两个位置的元素
    print(data[[0, 1], [1, 2]])
    print('-' * 50)

    # 返回 0、1 行的 1、2 列共4个元素
    print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
    test02()

输出结果

python 复制代码
tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
        [8, 3]])

3. 范围索引

python 复制代码
# 3. 范围索引
def test03():
    # 前3行的前2列数据
    print(data[:3, :2])
    # 第2行到最后的前2列数据
    print(data[2:, :2])
if __name__ == '__main__':
    test03()

结果

python 复制代码
tensor([[0, 7],
        [6, 8],
        [6, 3]])
tensor([[6, 3],
        [4, 9]])

4. 布尔索引

python 复制代码
# 布尔索引
def test():

    # 第三列大于5的行数据
    print(data[data[:, 2] > 5])
    # 第二行大于5的列数据
    print(data[:, data[1] > 5])
if __name__ == '__main__':
    test04()

输出结果

python 复制代码
tensor([[0, 7, 6, 5, 9],
        [6, 3, 8, 7, 3]])
tensor([[0, 7],
        [6, 8],
        [6, 3],
        [4, 9]])

5. 多维索引

python 复制代码
# 多维索引
def test05():

    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 50)

    print(data[0, :, :])
    print(data[:, 0, :])
    print(data[:, :, 0])


if __name__ == '__main__':
    test05()

输出结果

python 复制代码
tensor([[[2, 4, 1, 2, 3],
         [5, 5, 1, 5, 0],
         [1, 4, 5, 3, 8],
         [7, 1, 1, 9, 9]],

        [[9, 7, 5, 3, 1],
         [8, 8, 6, 0, 1],
         [6, 9, 0, 2, 1],
         [9, 7, 0, 4, 0]],

        [[0, 7, 3, 5, 6],
         [2, 4, 6, 4, 3],
         [2, 0, 3, 7, 9],
         [9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
        [5, 5, 1, 5, 0],
        [1, 4, 5, 3, 8],
        [7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
        [9, 7, 5, 3, 1],
        [0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
        [9, 8, 6, 9],
        [0, 2, 2, 9]])
相关推荐
秀儿还能再秀26 分钟前
淘宝母婴购物数据可视化分析(基于脱敏公开数据集)
python·数据分析·学习笔记·数据可视化
邵奈一32 分钟前
运行OpenManus项目(使用Conda)
人工智能·大模型·agent·agi
计算机老学长42 分钟前
基于Python的商品销量的数据分析及推荐系统
开发语言·python·数据分析
是理不是里_1 小时前
深度学习与普通神经网络有何区别?
人工智能·深度学习·神经网络
曲幽1 小时前
DeepSeek大语言模型下几个常用术语
人工智能·ai·语言模型·自然语言处理·ollama·deepseek
千益1 小时前
玩转python:系统设计模式在Python项目中的应用
python·设计模式
&白帝&1 小时前
Java @PathVariable获取路径参数
java·开发语言·python
AORO_BEIDOU2 小时前
科普|卫星电话有哪些应用场景?
网络·人工智能·安全·智能手机·信息与通信
dreamczf2 小时前
基于Linux系统的边缘智能终端(RK3568+EtherCAT+PCIe+4G+5G)
linux·人工智能·物联网·5g