pytorch索引操作函数介绍

PyTorch 提供了一系列强大的索引操作函数,用于高效地操作张量数据。以下是常用的 PyTorch 索引操作函数及其用途:


1. 基础索引操作

这些操作类似于 NumPy 的基本索引方式。

1.1 方括号索引 ([])

  • 支持 切片布尔索引高级索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print(x[0, 1]) # 取第 0 行第 1 列,输出 20
    print(x[:, 1]) # 取所有行的第 1 列,输出 tensor([20, 50])
    print(x[1]) # 取第 1 行,输出 tensor([40, 50, 60])

2. 索引操作函数

2.1 torch.index_select

按指定维度的索引提取元素。

  • 用法

    torch.index_select(input, dim, index, *, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须是一维张量。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([0, 2])
    result = torch.index_select(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [40, 60]])

2.2 torch.gather

根据索引张量,从输入张量中收集数据。

  • 用法

    torch.gather(input, dim, index, *, sparse_grad=False, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须与输入张量形状兼容。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 2], [1, 0]])
    result = torch.gather(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [50, 40]])

2.3 torch.scatter

将源张量的值按索引写入目标张量。

  • 用法

    torch.scatter(input, dim, index, src, *, reduce=None)

示例

复制代码
x = torch.zeros(2, 3)
index = torch.tensor([[0, 1, 2], [0, 1, 2]])
src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
result = torch.scatter(x, dim=1, index=index, src=src)
print(result)  # 输出 tensor([[10., 20., 30.], [40., 50., 60.]])

2.4 torch.take

按照扁平化索引从张量中提取元素。

  • 用法

    torch.take(input, index)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2, 5])
result = torch.take(x, index)
print(result)  # 输出 tensor([10, 30, 60])

torch.gather和 torch.scatter区别

特性 torch.gather torch.scatter
作用 从指定位置提取数据 向指定位置写入数据
索引意义 定义要从 input 中提取数据的位置 定义要向 input 中写入数据的位置
输出形状 index 形状相同 input 形状相同
操作方向 input 提取值 src 的值写入到 input
使用场景 数据提取(如选择性采样) 数据写入(如更新某些特定位置的值)

2.5 torch.masked_select

根据布尔掩码从张量中提取元素。

  • 用法

    torch.masked_select(input, mask, *, out=None)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
mask = x > 30
result = torch.masked_select(x, mask)
print(result)  # 输出 tensor([40, 50, 60])

2.6 torch.nonzero

返回所有非零元素的索引。

  • 用法

    torch.nonzero(input, *, as_tuple=False)

示例

复制代码
x = torch.tensor([[0, 1, 0], [2, 0, 3]])
result = torch.nonzero(x)
print(result)  # 输出 tensor([[0, 1], [1, 0], [1, 2]])

2.7 torch.where

根据条件选择值。

  • 用法

    torch.where(condition, x, y)

示例

复制代码
x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
condition = x > 1
result = torch.where(condition, x, y)
print(result)  # 输出 tensor([10,  2,  3])

2.8 torch.advanced_indexing

  • PyTorch 支持 布尔张量索引整形张量索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 1], [1, 0]])
    result = x[index]
    print(result)

输出:

复制代码
tensor([[[10, 20, 30],
         [40, 50, 60]],

        [[40, 50, 60],
         [10, 20, 30]]])

注:第0维度操作

3. 总结

函数 功能
torch.index_select 按指定维度和索引提取元素
torch.gather 根据索引张量从输入中收集元素
torch.scatter 按索引将源张量的值写入目标张量
torch.take 按扁平化索引从张量中提取元素
torch.masked_select 根据布尔掩码从张量中提取元素
torch.nonzero 找出非零元素的索引
torch.where 根据条件选择元素

这些索引操作函数可以帮助简化张量操作,提高代码效率。

相关推荐
zzywxc78713 分钟前
AI 正在深度重构软件开发的底层逻辑和全生命周期,从技术演进、流程重构和未来趋势三个维度进行系统性分析
java·大数据·开发语言·人工智能·spring
超龄超能程序猿16 分钟前
(1)机器学习小白入门 YOLOv:从概念到实践
人工智能·机器学习
大熊背25 分钟前
图像处理专业书籍以及网络资源总结
人工智能·算法·microsoft
3gying30 分钟前
chromedriver
python
江理不变情32 分钟前
图像质量对比感悟
c++·人工智能
DES 仿真实践家1 小时前
【Day 11-N22】Python类(3)——Python的继承性、多继承、方法重写
开发语言·笔记·python
张较瘦_3 小时前
[论文阅读] 人工智能 + 软件工程 | 需求获取访谈中LLM生成跟进问题研究:来龙去脉与创新突破
论文阅读·人工智能
一 铭3 小时前
AI领域新趋势:从提示(Prompt)工程到上下文(Context)工程
人工智能·语言模型·大模型·llm·prompt
云泽野6 小时前
【Java|集合类】list遍历的6种方式
java·python·list
麻雀无能为力7 小时前
CAU数据挖掘实验 表分析数据插件
人工智能·数据挖掘·中国农业大学