理解pytorch系列:整型索引是怎么实现的

整型索引的匹配规则

在PyTorch中使用整型索引时,需要遵循一些基本规则来确定如何从原始张量中选择数据。整型索引可以是Python中的列表或者数组、NumPy数组,或者是PyTorch的LongTensor。整型索引允许在任何维度上进行复杂的数据选取操作,例如选择特定的行、列或者任意的元素。

以下是PyTorch中整型索引的匹配规则:

  1. 单维度索引:如果你对一个维度使用整型索引(比如通过传递一个整数列表),你将根据列表中的每个整数值得到被索引维度上对应的切片。索引列表中的每个整数指定要选择的数据在该维度上的位置。

    示例:

    python 复制代码
    import torch
    
    x = torch.arange(12).view(3, 4)
    index = torch.tensor([0, 2])
    selected = x[index]  # 选取第一行和第三行
  2. 跨维度的整型索引:如果你使用多个整型索引列表分别对应于多个维度,你将在这些维度上得到一个表格化的选取。每组索引列表定义在对应维度上的位置,交叉点上的元素被选中。

    示例:

    python 复制代码
    rows = torch.tensor([0, 2])
    columns = torch.tensor([1, 3])
    selected = x[rows, columns]  # (0,1) 和 (2,3) 位置上的元素被选中
  3. 张量索引:你也可以使用一个整型张量作为索引。如果索引张量是一个压平(flatten)的一维张量,则选择的是按照索引张量指示的线性索引的元素。如果索引张量是多维的,则返回的张量形状将匹配索引张量的形状。

    示例:

    python 复制代码
    row_indices = torch.tensor([0, 1, 2])
    col_indices = torch.tensor([1, 1, 1])
    selected = x[row_indices, col_indices]  # 选择三个元素,它们处于不同的行但相同列
  4. 广播规则:整型索引同样也受到广播规则的影响。这意味着如果你在不同的维度上使用了不同长度的索引列表,PyTorch会尝试将它们广播到一个共同的形状,然后执行索引操作。

在使用整型索引时,返回的张量总是一个复制,而不是原始数据的视图。这意味着对返回的张量所做的修改不会影响原始张量。在执行整型索引操作时,维度的顺序是非常重要的,因为它们决定了哪些数据将会被选择。

上述内容是PyTorch整型索引的一些基础规则和用例,当然,PyTorch提供的索引能力还包括更高级和复杂的用法,如使用掩码张量或组合不同类型的索引。

整型索引的底层逻辑

PyTorch中的整型索引(也称为高级索引或花式索引)允许使用整数数组来选择数据。整型索引可以在多个维度上非连续地选择数据,并且索引数组不需要与被索引数组的形状相匹配。

对于整型索引的实现,当你提供整数数组或整数张量给PyTorch张量时,底层实现会在C++层处理索引操作。以下是大致的实现步骤:

  1. 分析索引指令:PyTorch检测你提供的索引,并将其识别为整型索引操作。

  2. 内存分配:基于索引操作,PyTorch会分配一个新的内存空间来存储索引后得到的张量。

  3. 数据拷贝:PyTorch会遍历索引张量中的每个元素,并且在原始张量中查找对应位置的元素,然后将这些找到的元素复制到步骤2中分配的内存空间中。

  4. 返回新张量:将拷贝填充后的内存空间包装成一个新的PyTorch张量对象,然后返回该张量。

整型索引操作是一个相对"昂贵"的操作,因为它通常涉及数据的复制而不是简单的视图或内存共享。这意呀着整型索引得到的结果通常都是一个新的张量,与原始张量不共享数据。

以下是在Python中使用PyTorch进行整型索引的简单示例:

python 复制代码
import torch

# 创建一个2维张量
data = torch.tensor([[1, 2], [3, 4], [5, 6]])

# 使用整型索引
indices = torch.tensor([0, 2])
selected_data = data[indices]

print(selected_data)  # 输出张量:[[1, 2], [5, 6]]

在这个示例中,indices张量包含了想要选择的行的索引。在内部,selected_data的创建涉及到遍历indices数组,并且从data张量中抽取出相应的行,然后将这些行组合成一个新的张量。

实际上的C++实现比这个过程要复杂,因为PyTorch的内核会优化这些操作,可能会并行化遍历索引以加快速度。此外,它还需要处理各种边缘情况,并确保在并发环境下的安全性和效率。如果你想了解底层的实现细节,可以查看PyTorch的开源代码,特别是在其GitHub仓库中与张量索引相关的部分。

相关推荐
大耳朵爱学习1 小时前
掌握Transformer之注意力为什么有效
人工智能·深度学习·自然语言处理·大模型·llm·transformer·大语言模型
qq_15321452641 小时前
【2023工业异常检测文献】SimpleNet
图像处理·人工智能·深度学习·神经网络·机器学习·计算机视觉·视觉检测
真正的能量来自内心1 小时前
如何删除EXCELL文件中的空行?
经验分享
B站计算机毕业设计超人4 小时前
计算机毕业设计Python+Flask微博情感分析 微博舆情预测 微博爬虫 微博大数据 舆情分析系统 大数据毕业设计 NLP文本分类 机器学习 深度学习 AI
爬虫·python·深度学习·算法·机器学习·自然语言处理·数据可视化
羊小猪~~4 小时前
深度学习基础案例5--VGG16人脸识别(体验学习的痛苦与乐趣)
人工智能·python·深度学习·学习·算法·机器学习·cnn
重生之我在20年代敲代码5 小时前
strncpy函数的使用和模拟实现
c语言·开发语言·c++·经验分享·笔记
开MINI的工科男6 小时前
深蓝学院-- 量产自动驾驶中的规划控制算法 小鹏
人工智能·机器学习·自动驾驶
AI大模型知识分享7 小时前
Prompt最佳实践|如何用参考文本让ChatGPT答案更精准?
人工智能·深度学习·机器学习·chatgpt·prompt·gpt-3
小言从不摸鱼9 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
铁匠匠匠9 小时前
从零开始学数据结构系列之第六章《排序简介》
c语言·数据结构·经验分享·笔记·学习·开源·课程设计