目录

pytorch索引操作函数介绍

PyTorch 提供了一系列强大的索引操作函数,用于高效地操作张量数据。以下是常用的 PyTorch 索引操作函数及其用途:


1. 基础索引操作

这些操作类似于 NumPy 的基本索引方式。

1.1 方括号索引 ([])

  • 支持 切片布尔索引高级索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print(x[0, 1]) # 取第 0 行第 1 列,输出 20
    print(x[:, 1]) # 取所有行的第 1 列,输出 tensor([20, 50])
    print(x[1]) # 取第 1 行,输出 tensor([40, 50, 60])

2. 索引操作函数

2.1 torch.index_select

按指定维度的索引提取元素。

  • 用法

    torch.index_select(input, dim, index, *, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须是一维张量。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([0, 2])
    result = torch.index_select(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [40, 60]])

2.2 torch.gather

根据索引张量,从输入张量中收集数据。

  • 用法

    torch.gather(input, dim, index, *, sparse_grad=False, out=None)

  • 参数

    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须与输入张量形状兼容。
  • 示例

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 2], [1, 0]])
    result = torch.gather(x, dim=1, index=index)
    print(result) # 输出 tensor([[10, 30], [50, 40]])

2.3 torch.scatter

将源张量的值按索引写入目标张量。

  • 用法

    torch.scatter(input, dim, index, src, *, reduce=None)

示例

复制代码
x = torch.zeros(2, 3)
index = torch.tensor([[0, 1, 2], [0, 1, 2]])
src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
result = torch.scatter(x, dim=1, index=index, src=src)
print(result)  # 输出 tensor([[10., 20., 30.], [40., 50., 60.]])

2.4 torch.take

按照扁平化索引从张量中提取元素。

  • 用法

    torch.take(input, index)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2, 5])
result = torch.take(x, index)
print(result)  # 输出 tensor([10, 30, 60])

torch.gather和 torch.scatter区别

特性 torch.gather torch.scatter
作用 从指定位置提取数据 向指定位置写入数据
索引意义 定义要从 input 中提取数据的位置 定义要向 input 中写入数据的位置
输出形状 index 形状相同 input 形状相同
操作方向 input 提取值 src 的值写入到 input
使用场景 数据提取(如选择性采样) 数据写入(如更新某些特定位置的值)

2.5 torch.masked_select

根据布尔掩码从张量中提取元素。

  • 用法

    torch.masked_select(input, mask, *, out=None)

示例

复制代码
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
mask = x > 30
result = torch.masked_select(x, mask)
print(result)  # 输出 tensor([40, 50, 60])

2.6 torch.nonzero

返回所有非零元素的索引。

  • 用法

    torch.nonzero(input, *, as_tuple=False)

示例

复制代码
x = torch.tensor([[0, 1, 0], [2, 0, 3]])
result = torch.nonzero(x)
print(result)  # 输出 tensor([[0, 1], [1, 0], [1, 2]])

2.7 torch.where

根据条件选择值。

  • 用法

    torch.where(condition, x, y)

示例

复制代码
x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
condition = x > 1
result = torch.where(condition, x, y)
print(result)  # 输出 tensor([10,  2,  3])

2.8 torch.advanced_indexing

  • PyTorch 支持 布尔张量索引整形张量索引

  • 示例:

    x = torch.tensor([[10, 20, 30], [40, 50, 60]])
    index = torch.tensor([[0, 1], [1, 0]])
    result = x[index]
    print(result)

输出:

复制代码
tensor([[[10, 20, 30],
         [40, 50, 60]],

        [[40, 50, 60],
         [10, 20, 30]]])

注:第0维度操作

3. 总结

函数 功能
torch.index_select 按指定维度和索引提取元素
torch.gather 根据索引张量从输入中收集元素
torch.scatter 按索引将源张量的值写入目标张量
torch.take 按扁平化索引从张量中提取元素
torch.masked_select 根据布尔掩码从张量中提取元素
torch.nonzero 找出非零元素的索引
torch.where 根据条件选择元素

这些索引操作函数可以帮助简化张量操作,提高代码效率。

本文是转载文章,点击查看原文
如有侵权,请联系 xyy@jishuzhan.net 删除
相关推荐
蹦蹦跳跳真可爱58936 分钟前
Python----计算机视觉处理(Opencv:二值化,阈值法,反阈值法,截断阈值法,OTSU阈值法)
人工智能·python·opencv·计算机视觉
愚戏师2 小时前
Python :数据模型
开发语言·python
袁袁袁袁满2 小时前
Blackbox.Ai体验:AI编程插件如何提升开发效率
人工智能·ai编程·ai插件·chatgpt-4o·deepseek-r1满血版·免费大模型·gemini pro
摸鱼仙人~3 小时前
预训练微调类型分类
人工智能·自然语言处理·分类
申耀的科技观察3 小时前
【观察】拓展大模型应用交付领域“新赛道”,亚信科技为高质量发展“加速度”...
大数据·人工智能·科技
哈喽小疯车3 小时前
Python 实现大文件的高并发下载
python
lboyj4 小时前
新能源汽车电控系统的大尺寸PCB需求:猎板PCB的技术突围
大数据·网络·人工智能
HABuo4 小时前
【YOLOv8】YOLOv8改进系列(5)----替换主干网络之EfficientFormerV2
人工智能·深度学习·yolo·目标检测·计算机视觉
訾博ZiBo5 小时前
AI日报 - 2025年3月16日
人工智能
(initial)5 小时前
大型语言模型与强化学习的融合:迈向通用人工智能的新范式——基于基础复现的实验平台构建
人工智能·强化学习