pytorch索引操作函数介绍

PyTorch 提供了一系列强大的索引操作函数,用于高效地操作张量数据。以下是常用的 PyTorch 索引操作函数及其用途:


1. 基础索引操作

这些操作类似于 NumPy 的基本索引方式。

1.1 方括号索引 ([])

  • 支持 切片布尔索引高级索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print(x[0, 1]) # 取第 0 行第 1 列,输出 20
    print(x[:, 1]) # 取所有行的第 1 列,输出 tensor([20, 50])
    print(x[1]) # 取第 1 行,输出 tensor([40, 50, 60])

2. 索引操作函数

2.1 torch.index_select

按指定维度的索引提取元素。

  • 用法

    torch.index_select(input, dim, index, *, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须是一维张量。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([0, 2])
    result = torch.index_select(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [40, 60]])

2.2 torch.gather

根据索引张量,从输入张量中收集数据。

  • 用法

    torch.gather(input, dim, index, *, sparse_grad=False, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须与输入张量形状兼容。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 2], [1, 0]])
    result = torch.gather(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [50, 40]])

2.3 torch.scatter

将源张量的值按索引写入目标张量。

  • 用法

    torch.scatter(input, dim, index, src, *, reduce=None)

示例

x = torch.zeros(2, 3)
index = torch.tensor([[0, 1, 2], [0, 1, 2]])
src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
result = torch.scatter(x, dim=1, index=index, src=src)
print(result)  # 输出 tensor([[10., 20., 30.], [40., 50., 60.]])

2.4 torch.take

按照扁平化索引从张量中提取元素。

  • 用法

    torch.take(input, index)

示例

x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2, 5])
result = torch.take(x, index)
print(result)  # 输出 tensor([10, 30, 60])

torch.gather和 torch.scatter区别

特性 torch.gather torch.scatter
作用 从指定位置提取数据 向指定位置写入数据
索引意义 定义要从 input 中提取数据的位置 定义要向 input 中写入数据的位置
输出形状 index 形状相同 input 形状相同
操作方向 input 提取值 src 的值写入到 input
使用场景 数据提取(如选择性采样) 数据写入(如更新某些特定位置的值)

2.5 torch.masked_select

根据布尔掩码从张量中提取元素。

  • 用法

    torch.masked_select(input, mask, *, out=None)

示例

x = torch.tensor([[10, 20, 30], [40, 50, 60]])
mask = x > 30
result = torch.masked_select(x, mask)
print(result)  # 输出 tensor([40, 50, 60])

2.6 torch.nonzero

返回所有非零元素的索引。

  • 用法

    torch.nonzero(input, *, as_tuple=False)

示例

x = torch.tensor([[0, 1, 0], [2, 0, 3]])
result = torch.nonzero(x)
print(result)  # 输出 tensor([[0, 1], [1, 0], [1, 2]])

2.7 torch.where

根据条件选择值。

  • 用法

    torch.where(condition, x, y)

示例

x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
condition = x > 1
result = torch.where(condition, x, y)
print(result)  # 输出 tensor([10,  2,  3])

2.8 torch.advanced_indexing

  • PyTorch 支持 布尔张量索引整形张量索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 1], [1, 0]])
    result = x[index]
    print(result)

输出:

tensor([[[10, 20, 30],
         [40, 50, 60]],

        [[40, 50, 60],
         [10, 20, 30]]])

注:第0维度操作

3. 总结

函数 功能
torch.index_select 按指定维度和索引提取元素
torch.gather 根据索引张量从输入中收集元素
torch.scatter 按索引将源张量的值写入目标张量
torch.take 按扁平化索引从张量中提取元素
torch.masked_select 根据布尔掩码从张量中提取元素
torch.nonzero 找出非零元素的索引
torch.where 根据条件选择元素

这些索引操作函数可以帮助简化张量操作,提高代码效率。

相关推荐
AI小欧同学7 分钟前
【AIGC-ChatGPT进阶提示词指令】AI美食助手的设计与实现:Lisp风格系统提示词分析
人工智能·chatgpt·aigc
Wishell20159 分钟前
小白学Pytorch
pytorch
灵魂画师向阳24 分钟前
【CSDN首发】Stable Diffusion从零到精通学习路线分享
人工智能·学习·计算机视觉·ai作画·stable diffusion·midjourney
Elastic 中国社区官方博客27 分钟前
在不到 5 分钟的时间内将威胁情报 PDF 添加为 AI 助手的自定义知识
大数据·人工智能·安全·elasticsearch·搜索引擎·pdf·全文检索
香菜的开发日记28 分钟前
快速学习 pytest 基础知识
自动化测试·python·pytest
背太阳的牧羊人37 分钟前
grouped.get_group((‘B‘, ‘A‘))选择分组
python·pandas
埃菲尔铁塔_CV算法1 小时前
BOOST 在计算机视觉方面的应用及具体代码分析(二)
c++·人工智能·算法·机器学习·计算机视觉
m0_748233361 小时前
用JAVA实现人工智能:采用框架Spring AI Java
java·人工智能·spring
穆姬姗1 小时前
【Python】论文长截图、页面分割、水印去除、整合PDF
开发语言·python·pdf
刘大猫262 小时前
《docker基础篇:4.Docker镜像》包括是什么、分层的镜像、UnionFS(联合文件系统)、docker镜像的加载原理、为什么docker镜像要采用这种
人工智能·算法·计算机视觉