【PyTorch】5.张量索引操作

目录

[1. 简单行、列索引](#1. 简单行、列索引)

[2. 列表索引](#2. 列表索引)

[3. 范围索引](#3. 范围索引)

[4. 布尔索引](#4. 布尔索引)

[5. 多维索引](#5. 多维索引)


个人主页:Icomi

在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。

大家好,我是一颗米,前面咱们已经深入学习了张量的拼接操作,这让我们在搭建神经网络时能够更灵活地组合数据,就像掌握了一套精妙的 "组合拳",让模型构建更加得心应手。然而,在我们使用张量的过程中,还有一项关键技能,如同在茂密的数据丛林中开辟道路的利器,那就是张量的花式索引操作。

当我们在操作张量时,就好比在管理一个庞大而复杂的仓库,里面存放着各种各样的数据。经常会遇到这样的情况,我们需要从这个 "数据仓库" 里精准地获取特定的数据,或者对某些数据进行修改。这时候,张量的花式索引操作就派上用场了。

想象一下,你置身于一个堆满了各种规格货物(数据)的巨大仓库,**普通的索引方式就像只能按顺序依次寻找货物,效率低下。而花式索引则像是给了你一把神奇的钥匙,**能够让你直接定位到那些分散在不同位置的特定货物,快速地获取或者修改它们。

掌握张量的花式索引操作,是我们必须具备的一项能力。它就像一把万能钥匙,能让我们在处理张量数据时更加高效、精准。无论是构建复杂的神经网络模型,还是对模型输出的数据进行后处理,花式索引都能帮助我们更灵活地操作张量。

接下来,我们就一起深入探索张量的花式索引操作,看看这把神奇的 "钥匙" 究竟有哪些神奇的用法,如何帮助我们在数据的海洋中畅行无阻。

1. 简单行、列索引

准备数据

python 复制代码
import torch

data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)

结果

python 复制代码
tensor([[0, 7, 6, 5, 9],
        [6, 8, 3, 1, 0],
        [6, 3, 8, 7, 3],
        [4, 9, 5, 3, 1]])
--------------------------------------------------
python 复制代码
# 1. 简单行、列索引
def test01():

    print(data[0])
    print(data[:, 0])
    print('-' * 50)

if __name__ == '__main__':
    test01()

程序输出结果

python 复制代码
tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------

2. 列表索引

python 复制代码
# 2. 列表索引
def test02():

    # 返回 (0, 1)、(1, 2) 两个位置的元素
    print(data[[0, 1], [1, 2]])
    print('-' * 50)

    # 返回 0、1 行的 1、2 列共4个元素
    print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
    test02()

输出结果

python 复制代码
tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
        [8, 3]])

3. 范围索引

python 复制代码
# 3. 范围索引
def test03():
    # 前3行的前2列数据
    print(data[:3, :2])
    # 第2行到最后的前2列数据
    print(data[2:, :2])
if __name__ == '__main__':
    test03()

结果

python 复制代码
tensor([[0, 7],
        [6, 8],
        [6, 3]])
tensor([[6, 3],
        [4, 9]])

4. 布尔索引

python 复制代码
# 布尔索引
def test():

    # 第三列大于5的行数据
    print(data[data[:, 2] > 5])
    # 第二行大于5的列数据
    print(data[:, data[1] > 5])
if __name__ == '__main__':
    test04()

输出结果

python 复制代码
tensor([[0, 7, 6, 5, 9],
        [6, 3, 8, 7, 3]])
tensor([[0, 7],
        [6, 8],
        [6, 3],
        [4, 9]])

5. 多维索引

python 复制代码
# 多维索引
def test05():

    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 50)

    print(data[0, :, :])
    print(data[:, 0, :])
    print(data[:, :, 0])


if __name__ == '__main__':
    test05()

输出结果

python 复制代码
tensor([[[2, 4, 1, 2, 3],
         [5, 5, 1, 5, 0],
         [1, 4, 5, 3, 8],
         [7, 1, 1, 9, 9]],

        [[9, 7, 5, 3, 1],
         [8, 8, 6, 0, 1],
         [6, 9, 0, 2, 1],
         [9, 7, 0, 4, 0]],

        [[0, 7, 3, 5, 6],
         [2, 4, 6, 4, 3],
         [2, 0, 3, 7, 9],
         [9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
        [5, 5, 1, 5, 0],
        [1, 4, 5, 3, 8],
        [7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
        [9, 7, 5, 3, 1],
        [0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
        [9, 8, 6, 9],
        [0, 2, 2, 9]])
相关推荐
GIS数据转换器5 分钟前
当三维地理信息遇上气象预警:电网安全如何实现“先知先觉”?
人工智能·科技·安全·gis·智慧城市·交互
网易易盾5 分钟前
AIGC时代的内容安全:AI检测技术如何应对新型风险挑战?
人工智能·安全·aigc
工头阿乐9 分钟前
PyTorch中的nn.Embedding应用详解
人工智能·pytorch·embedding
alpszero12 分钟前
YOLO11解决方案之物体模糊探索
人工智能·python·opencv·计算机视觉·yolo11
Alessio Micheli15 分钟前
基于几何布朗运动的股价预测模型构建与分析
线性代数·机器学习·概率论
vlln19 分钟前
适应性神经树:当深度学习遇上决策树的“生长法则”
人工智能·深度学习·算法·决策树·机器学习
伊织code26 分钟前
PyTorch API 6 - 编译、fft、fx、函数转换、调试、符号追踪
pytorch·python·ai·api·-·6
奋斗者1号27 分钟前
机器学习之决策树与决策森林:机器学习中的强大工具
人工智能·决策树·机器学习
struggle202529 分钟前
continue通过我们的开源 IDE 扩展和模型、规则、提示、文档和其他构建块中心,创建、共享和使用自定义 AI 代码助手
javascript·ide·python·typescript·开源
多巴胺与内啡肽.37 分钟前
OpenCV进阶操作:风格迁移以及DNN模块解析
人工智能·opencv·dnn