pytorch索引操作函数介绍

PyTorch 提供了一系列强大的索引操作函数,用于高效地操作张量数据。以下是常用的 PyTorch 索引操作函数及其用途:


1. 基础索引操作

这些操作类似于 NumPy 的基本索引方式。

1.1 方括号索引 ([])

  • 支持 切片布尔索引高级索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print(x[0, 1]) # 取第 0 行第 1 列,输出 20
    print(x[:, 1]) # 取所有行的第 1 列,输出 tensor([20, 50])
    print(x[1]) # 取第 1 行,输出 tensor([40, 50, 60])

2. 索引操作函数

2.1 torch.index_select

按指定维度的索引提取元素。

  • 用法

    torch.index_select(input, dim, index, *, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须是一维张量。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([0, 2])
    result = torch.index_select(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [40, 60]])

2.2 torch.gather

根据索引张量,从输入张量中收集数据。

  • 用法

    torch.gather(input, dim, index, *, sparse_grad=False, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须与输入张量形状兼容。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 2], [1, 0]])
    result = torch.gather(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [50, 40]])

2.3 torch.scatter

将源张量的值按索引写入目标张量。

  • 用法

    torch.scatter(input, dim, index, src, *, reduce=None)

示例

复制代码
x = torch.zeros(2, 3)
index = torch.tensor([[0, 1, 2], [0, 1, 2]])
src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
result = torch.scatter(x, dim=1, index=index, src=src)
print(result)  # 输出 tensor([[10., 20., 30.], [40., 50., 60.]])

2.4 torch.take

按照扁平化索引从张量中提取元素。

  • 用法

    torch.take(input, index)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2, 5])
result = torch.take(x, index)
print(result)  # 输出 tensor([10, 30, 60])

torch.gather和 torch.scatter区别

特性 torch.gather torch.scatter
作用 从指定位置提取数据 向指定位置写入数据
索引意义 定义要从 input 中提取数据的位置 定义要向 input 中写入数据的位置
输出形状 index 形状相同 input 形状相同
操作方向 input 提取值 src 的值写入到 input
使用场景 数据提取(如选择性采样) 数据写入(如更新某些特定位置的值)

2.5 torch.masked_select

根据布尔掩码从张量中提取元素。

  • 用法

    torch.masked_select(input, mask, *, out=None)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
mask = x > 30
result = torch.masked_select(x, mask)
print(result)  # 输出 tensor([40, 50, 60])

2.6 torch.nonzero

返回所有非零元素的索引。

  • 用法

    torch.nonzero(input, *, as_tuple=False)

示例

复制代码
x = torch.tensor([[0, 1, 0], [2, 0, 3]])
result = torch.nonzero(x)
print(result)  # 输出 tensor([[0, 1], [1, 0], [1, 2]])

2.7 torch.where

根据条件选择值。

  • 用法

    torch.where(condition, x, y)

示例

复制代码
x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
condition = x > 1
result = torch.where(condition, x, y)
print(result)  # 输出 tensor([10,  2,  3])

2.8 torch.advanced_indexing

  • PyTorch 支持 布尔张量索引整形张量索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 1], [1, 0]])
    result = x[index]
    print(result)

输出:

复制代码
tensor([[[10, 20, 30],
         [40, 50, 60]],

        [[40, 50, 60],
         [10, 20, 30]]])

注:第0维度操作

3. 总结

函数 功能
torch.index_select 按指定维度和索引提取元素
torch.gather 根据索引张量从输入中收集元素
torch.scatter 按索引将源张量的值写入目标张量
torch.take 按扁平化索引从张量中提取元素
torch.masked_select 根据布尔掩码从张量中提取元素
torch.nonzero 找出非零元素的索引
torch.where 根据条件选择元素

这些索引操作函数可以帮助简化张量操作,提高代码效率。

相关推荐
胖祥12 小时前
onnx之优化器
人工智能·深度学习
AI服务老曹12 小时前
源码交付与低代码重构:企业级 AI 视频管理平台的二次开发实战
人工智能·低代码·重构
L-影12 小时前
下篇:一棵树能长成多少种样子?——AI中决策树的类型与作用,以及它凭什么活了六十年还没过气
人工智能·算法·决策树·ai
jovi_AI电报12 小时前
你还把 ChatGPT 当白月光,别人已经让它出来上班了
人工智能
Z.风止12 小时前
Large Model-learning(2)
开发语言·笔记·python·leetcode
蓝天守卫者联盟112 小时前
玩具喷涂废气治理厂家:行业现状、技术路径与选型指南
大数据·运维·人工智能·python
m0_7381207212 小时前
我的创作纪念日0328
java·网络·windows·python·web安全·php
智慧化智能化数字化方案12 小时前
架构进阶——解读企业数字化转型L1-L5数据架构设计方法论及案例【附全文阅读】
人工智能·企业数字化转型·l1-l5数据架构设计方法论
无代码专家12 小时前
通过轻流 AI OA 系统实现行政成本优化——生产管理落地方案
运维·人工智能·云计算
red1giant_star12 小时前
浅析文件类漏洞原理与分类——含payload合集与检测与防护思路
python·安全