目录
[1. 简单行、列索引](#1. 简单行、列索引)
[2. 列表索引](#2. 列表索引)
[3. 范围索引](#3. 范围索引)
[4. 布尔索引](#4. 布尔索引)
[5. 多维索引](#5. 多维索引)
个人主页:Icomi
在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。
大家好,我是一颗米,前面咱们已经深入学习了张量的拼接操作,这让我们在搭建神经网络时能够更灵活地组合数据,就像掌握了一套精妙的 "组合拳",让模型构建更加得心应手。然而,在我们使用张量的过程中,还有一项关键技能,如同在茂密的数据丛林中开辟道路的利器,那就是张量的花式索引操作。
当我们在操作张量时,就好比在管理一个庞大而复杂的仓库,里面存放着各种各样的数据。经常会遇到这样的情况,我们需要从这个 "数据仓库" 里精准地获取特定的数据,或者对某些数据进行修改。这时候,张量的花式索引操作就派上用场了。
想象一下,你置身于一个堆满了各种规格货物(数据)的巨大仓库,**普通的索引方式就像只能按顺序依次寻找货物,效率低下。而花式索引则像是给了你一把神奇的钥匙,**能够让你直接定位到那些分散在不同位置的特定货物,快速地获取或者修改它们。
掌握张量的花式索引操作,是我们必须具备的一项能力。它就像一把万能钥匙,能让我们在处理张量数据时更加高效、精准。无论是构建复杂的神经网络模型,还是对模型输出的数据进行后处理,花式索引都能帮助我们更灵活地操作张量。
接下来,我们就一起深入探索张量的花式索引操作,看看这把神奇的 "钥匙" 究竟有哪些神奇的用法,如何帮助我们在数据的海洋中畅行无阻。
1. 简单行、列索引
准备数据
python
import torch
data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)
结果
python
tensor([[0, 7, 6, 5, 9],
[6, 8, 3, 1, 0],
[6, 3, 8, 7, 3],
[4, 9, 5, 3, 1]])
--------------------------------------------------
python
# 1. 简单行、列索引
def test01():
print(data[0])
print(data[:, 0])
print('-' * 50)
if __name__ == '__main__':
test01()
程序输出结果
python
tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------
2. 列表索引
python
# 2. 列表索引
def test02():
# 返回 (0, 1)、(1, 2) 两个位置的元素
print(data[[0, 1], [1, 2]])
print('-' * 50)
# 返回 0、1 行的 1、2 列共4个元素
print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
test02()
输出结果
python
tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
[8, 3]])
3. 范围索引
python
# 3. 范围索引
def test03():
# 前3行的前2列数据
print(data[:3, :2])
# 第2行到最后的前2列数据
print(data[2:, :2])
if __name__ == '__main__':
test03()
结果
python
tensor([[0, 7],
[6, 8],
[6, 3]])
tensor([[6, 3],
[4, 9]])
4. 布尔索引
python
# 布尔索引
def test():
# 第三列大于5的行数据
print(data[data[:, 2] > 5])
# 第二行大于5的列数据
print(data[:, data[1] > 5])
if __name__ == '__main__':
test04()
输出结果
python
tensor([[0, 7, 6, 5, 9],
[6, 3, 8, 7, 3]])
tensor([[0, 7],
[6, 8],
[6, 3],
[4, 9]])
5. 多维索引
python
# 多维索引
def test05():
data = torch.randint(0, 10, [3, 4, 5])
print(data)
print('-' * 50)
print(data[0, :, :])
print(data[:, 0, :])
print(data[:, :, 0])
if __name__ == '__main__':
test05()
输出结果
python
tensor([[[2, 4, 1, 2, 3],
[5, 5, 1, 5, 0],
[1, 4, 5, 3, 8],
[7, 1, 1, 9, 9]],
[[9, 7, 5, 3, 1],
[8, 8, 6, 0, 1],
[6, 9, 0, 2, 1],
[9, 7, 0, 4, 0]],
[[0, 7, 3, 5, 6],
[2, 4, 6, 4, 3],
[2, 0, 3, 7, 9],
[9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
[5, 5, 1, 5, 0],
[1, 4, 5, 3, 8],
[7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
[9, 7, 5, 3, 1],
[0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
[9, 8, 6, 9],
[0, 2, 2, 9]])