pytorch索引操作函数介绍

PyTorch 提供了一系列强大的索引操作函数,用于高效地操作张量数据。以下是常用的 PyTorch 索引操作函数及其用途:


1. 基础索引操作

这些操作类似于 NumPy 的基本索引方式。

1.1 方括号索引 ([])

  • 支持 切片布尔索引高级索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print(x[0, 1]) # 取第 0 行第 1 列,输出 20
    print(x[:, 1]) # 取所有行的第 1 列,输出 tensor([20, 50])
    print(x[1]) # 取第 1 行,输出 tensor([40, 50, 60])

2. 索引操作函数

2.1 torch.index_select

按指定维度的索引提取元素。

  • 用法

    torch.index_select(input, dim, index, *, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须是一维张量。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([0, 2])
    result = torch.index_select(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [40, 60]])

2.2 torch.gather

根据索引张量,从输入张量中收集数据。

  • 用法

    torch.gather(input, dim, index, *, sparse_grad=False, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须与输入张量形状兼容。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 2], [1, 0]])
    result = torch.gather(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [50, 40]])

2.3 torch.scatter

将源张量的值按索引写入目标张量。

  • 用法

    torch.scatter(input, dim, index, src, *, reduce=None)

示例

复制代码
x = torch.zeros(2, 3)
index = torch.tensor([[0, 1, 2], [0, 1, 2]])
src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
result = torch.scatter(x, dim=1, index=index, src=src)
print(result)  # 输出 tensor([[10., 20., 30.], [40., 50., 60.]])

2.4 torch.take

按照扁平化索引从张量中提取元素。

  • 用法

    torch.take(input, index)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2, 5])
result = torch.take(x, index)
print(result)  # 输出 tensor([10, 30, 60])

torch.gather和 torch.scatter区别

特性 torch.gather torch.scatter
作用 从指定位置提取数据 向指定位置写入数据
索引意义 定义要从 input 中提取数据的位置 定义要向 input 中写入数据的位置
输出形状 index 形状相同 input 形状相同
操作方向 input 提取值 src 的值写入到 input
使用场景 数据提取(如选择性采样) 数据写入(如更新某些特定位置的值)

2.5 torch.masked_select

根据布尔掩码从张量中提取元素。

  • 用法

    torch.masked_select(input, mask, *, out=None)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
mask = x > 30
result = torch.masked_select(x, mask)
print(result)  # 输出 tensor([40, 50, 60])

2.6 torch.nonzero

返回所有非零元素的索引。

  • 用法

    torch.nonzero(input, *, as_tuple=False)

示例

复制代码
x = torch.tensor([[0, 1, 0], [2, 0, 3]])
result = torch.nonzero(x)
print(result)  # 输出 tensor([[0, 1], [1, 0], [1, 2]])

2.7 torch.where

根据条件选择值。

  • 用法

    torch.where(condition, x, y)

示例

复制代码
x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
condition = x > 1
result = torch.where(condition, x, y)
print(result)  # 输出 tensor([10,  2,  3])

2.8 torch.advanced_indexing

  • PyTorch 支持 布尔张量索引整形张量索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 1], [1, 0]])
    result = x[index]
    print(result)

输出:

复制代码
tensor([[[10, 20, 30],
         [40, 50, 60]],

        [[40, 50, 60],
         [10, 20, 30]]])

注:第0维度操作

3. 总结

函数 功能
torch.index_select 按指定维度和索引提取元素
torch.gather 根据索引张量从输入中收集元素
torch.scatter 按索引将源张量的值写入目标张量
torch.take 按扁平化索引从张量中提取元素
torch.masked_select 根据布尔掩码从张量中提取元素
torch.nonzero 找出非零元素的索引
torch.where 根据条件选择元素

这些索引操作函数可以帮助简化张量操作,提高代码效率。

相关推荐
一 铭1 小时前
AI领域新趋势:从提示(Prompt)工程到上下文(Context)工程
人工智能·语言模型·大模型·llm·prompt
云泽野3 小时前
【Java|集合类】list遍历的6种方式
java·python·list
麻雀无能为力4 小时前
CAU数据挖掘实验 表分析数据插件
人工智能·数据挖掘·中国农业大学
时序之心4 小时前
时空数据挖掘五大革新方向详解篇!
人工智能·数据挖掘·论文·时间序列
IMPYLH5 小时前
Python 的内置函数 reversed
笔记·python
.30-06Springfield5 小时前
人工智能概念之七:集成学习思想(Bagging、Boosting、Stacking)
人工智能·算法·机器学习·集成学习
说私域6 小时前
基于开源AI智能名片链动2+1模式S2B2C商城小程序的超级文化符号构建路径研究
人工智能·小程序·开源
永洪科技6 小时前
永洪科技荣获商业智能品牌影响力奖,全力打造”AI+决策”引擎
大数据·人工智能·科技·数据分析·数据可视化·bi
shangyingying_16 小时前
关于小波降噪、小波增强、小波去雾的原理区分
人工智能·深度学习·计算机视觉
小赖同学啊7 小时前
物联网数据安全区块链服务
开发语言·python·区块链