PyTorch中Tensor的存储结构

PyTorch中Tensor的存储结构

Tensor数据的类型

Tensor 中数据主要有下面两种类型:

  • meta data:元数据,也就是描述数据特征的数据,例如 shape、dtype、device、stride等等
  • raw data:数据本身,我们可以通过 tensor.data_ptr() 获取到数据存储的内存位置

参考下面案例

python 复制代码
def tensor_struct():
    #  meta_data / raw_data
    nd_array = np.array([[1, 2, 3], [4, 5, 6]])
    # tensor = torch.tensor(nd_array) # deep copy
    tensor = torch.from_numpy(nd_array)

    # raw data
    print(f"pytorch data: \n{tensor}")
    # print("pytorch raw data: \n", tensor.storage())
    print(f"numpy raw data_ptr: {nd_array.ctypes.data}")
    print(f"pytroch raw data_ptr: {tensor.data_ptr()}")  # raw_data

    print(f"numpy data id: {id(nd_array)}", )
    print(f"pytorch data id: {id(tensor)}")

    tensor2 = tensor.reshape(1, 6)
    # 观察可以看到 tensor 及 tensor2 的 id 是不同的, 但是 data_ptr 却相同
    # tensor2 的 row_data 没有变化, meta_data 发生了变化 -> tensor2 是 tensor 的一个 view
    print(f"tensor id: {id(tensor)}")
    print(f"tensor2 id: ", id(tensor2))
    print(f"tensor pointer addr: {tensor.data_ptr()}")
    print(f"tensor2 pointer addr: {tensor2.data_ptr()}")

视图

首先了解一下 Pytorch 中下面的两个概念:

  • stride() :获取张量(Tensor)的步幅信息。步幅(Stride)描述了张量在内存中相邻元素之间的距离(以元素个数为单位),对于多维张量而言,它是一个表示各维度间跳跃关系的元组
  • data_ptr():获取张量(Tensor)底层数据在内存中的起始地址。这个地址是一个整数值,通常表示为一个C语言指针类型(在Python环境中表现为Python整数)

参考下面案例

python 复制代码
# 理解 tensor 的步长
def stride_demo():
    tensor = torch.randn(2, 3, 5)
    # stride 就是 tensor 中某一个维度上, 相邻元素之间的步长(以元素个数为单位)
    # 对于 shape 为 2,3,5 的 tensor
    # 在第0维上, 两个元素之间的步长为 3*5 = 15
    # 在第1维上, 两个元素之间的步长为 5*1 = 5
    # 在第2维上, 由于是最后一个维度了, 两个相邻元素间步长就是1了
    tensor_stride = tensor.stride()
    print(f"tensor_stride: {tensor_stride}")
    print(f"tensor.stride(0): {tensor.stride(0)}")
    print(f"tensor.stride(1): {tensor.stride(1)}")
    print(f"tensor.stride(2): {tensor.stride(2)}")

实际上PyTorch获取指定索引位置的数据时,本质上是通过data_ptr()的位置获取多维数组的起始点,然后依据 stride() 计算指定维度走一步需要移动的位置,最终计算出当前索引的数据。

对于一个 shape 为 [2, 3, 5]的 tensor,那么它的 stride 应当为:

  • 第0维:stride[0] 应当为后面两维的乘积,也就是 5*3 = 15
  • 第1维:stride[1] 应当为后面一维的维度,也就是 5
  • 第1维:stride[2] 上面每一个数值都是连续的,也就是1

因此,stride也就是 [15, 5, 1]

连续型与破坏连续性

Tensor中的连续性

如果 Tensor 的 stride 满足前面的定义,那么在读取数据时可以认为是连续的,在做类似矩阵乘法时读取数据的效率就会比较高。

但是有一些操作是会破坏这种连续性的

参考下面案例

python 复制代码
def contiguous_demo():
    data0 = torch.randint(0, 10, (2, 5))
    data1 = data0.transpose(1, 0)
    data2 = data0.reshape(5, 2)
    print(f"data0: {data0}")
    # data1 和 data2 的 shape 相同, 但是对应位置上的值是不同的
    # data0: [ [3, 5, 5, 9, 2], [8, 7, 4, 9, 7] ]
    # data1: [ [3, 8], [5, 7], [5, 4], [9, 9], [2, 7] ]
    # data2: [ [3, 5], [5, 9], [2, 8], [7, 4], [9, 7] ]
    print(f"data1: {data1}")
    print(f"data2: {data2}")

    # data0、data1、data2 中 的data_ptr() 都是是相同的,说明 row_data 是没有变化的
    # transpose 以及 reshape 操作虽然数据不同,但转换以后 raw_data 是没有变化的
    print(f"data0 data_ptr: {data0.data_ptr()}")
    print(f"data1 data_ptr: {data1.data_ptr()}")
    print(f"data2 data_ptr: {data2.data_ptr()}")

    # transpose 以及 reshape 的区别在于两个操作以后 tensor 的 stride 发生了变化
    # 根据之前的例子对于一个 (5, 2) 的 tensor, stride 取值应当是 (2, 1)
    # 可以看到, reshape 以后是满足这个性质的
    # ------------------------ transpose 导致的不连续现象 -------------------------
    # tensor 在 transpose 操作之后, 读取数据的方式发生了改变, 不能像之前一样 "挨个" 读取数据
    # 从而发生了数据 "不连续" 的现象 !!!
    # 也就是说 transpose 操作本质上仍然是获取的是一个 view,但是会导致数据的不连续
    # ------------------------ transpose 导致的不连续现象 -------------------------
    print(f"data0 stride: {data0.stride()}")  # (5, 1)
    print(f"data1 stride: {data1.stride()}")  # (1, 5)
    print(f"data2 stride: {data2.stride()}")  # (2, 1)

    print(f"data0 is_contiguous: {data0.is_contiguous()}")  # True
    print(f"data1 is_contiguous: {data1.is_contiguous()}")  # False
    print(f"data2 is_contiguous: {data2.is_contiguous()}")  # True

可以看到 transpose 操作会与原始的 tensor 共享同一份 raw_data,但是会使得原来读取最后一个维度数据时发生不连续的现象,因此使得数据变得 "不连续" 了。

常见的破坏连续性的算子

主要有 transpose、permute、T 等等

参考下面案例

python 复制代码
def discontinuous_operator():
    data0 = torch.randint(0, 10, (2, 3, 4))
    # transpose 指定交换 第0轴 和 第1轴
    data1 = data0.transpose(0, 1)
    # permute 指的是: 原来第0轴 -> 第2轴, 原来第1轴 -> 第0轴, 原来第2轴 -> 第1轴
    data2 = data0.permute(2, 0, 1)
    data3 = data0.T

    print(f"data0.shape: {data0.shape}")  # [2, 3, 4]
    print(f"data1.shape: {data1.shape}")  # [3, 2, 4]
    print(f"data2.shape: {data2.shape}")  # [4, 2, 3]
    print(f"data3.shape: {data3.shape}")  # [4, 3, 2]

    print(f"data0 stride: {data0.stride()}")  # (12, 4, 1)
    print(f"data1 stride: {data1.stride()}")  # (4, 12, 1)
    print(f"data2 stride: {data2.stride()}")  # (1, 12, 4)
    print(f"data3 stride: {data3.stride()}")  # (1, 4, 12)
contiguous() 方法

既然有些算子会破坏Tensor的连续性,那么有没有什么方法可以避免呢?

我们可以使用 Tensor 中提供的 contiguous()方法使得 Tensor 变为连续的,本质上也就是新开辟了一个数据存储空间,然后把原来的数据挪到新空间下。

参考下面案例

python 复制代码
def contiguous_method():
    data0 = torch.randint(0, 10, (2, 5))
    # 这时候 data1 只是 data0 的一个 view
    data1 = data0.transpose(0, 1)
    # 此时创建了一个新的数据空间, data1 已经不是 data0 的一个 view了, 两者的 raw_data 已经不同了
    data1 = data1.contiguous()

    print(f"data1 shape: {data1.shape}")
    print(f"data1 stride: {data1.stride()}")

    # 可以看到此时 data0 与 data1 的 data_ptr 已经不同了
    print(f"data0 data_ptr: {data0.data_ptr()}")
    print(f"data1 data_ptr: {data1.data_ptr()}")

我们可以看到,对于一个不连续的 Tensor 调用 contiguous()方法后,Tensor重新变为连续的了,但是 raw_data 也发生了改变。

reshape vs view

在大部分情况下,reshape 和 view 的作用都是相同的,但是在处理不连续的 Tensor 时,两个算子处理上有所差异:

  • view:直接报错 _view size is not compatible with input tensor's size and stride_
  • reshape:会新开辟一个空间存储,将原有数据copy到新的存储空间当中。

参考下面案例

python 复制代码
def view_discontinuous():
    data0 = torch.randint(0, 10, (2, 5))
    data1 = data0.transpose(0, 1)
    # 直接报错: view size is not compatible with input tensor's size and stride
    data2 = data1.view(2, 5)
    print(f"data2: {data2}")


def reshape_discontinuous():
    data0 = torch.randint(0, 10, (2, 5))
    data1 = data0.transpose(0, 1)
    # 此时程序可以跑通
    data2 = data1.reshape(2, 5)

    print(f"data0: {data0}")
    print(f"data1: {data1}")
    print(f"data2: {data2}")

    # 可以看到 data0 和 data1 共享一份 raw_data, 但是 data2 的 raw_data 发生了改变
    # 也就是说: reshape 一个不连续的 tensor, 会新创建一个空间, 将原来的数据 copy 到新的空间
    print(f"data0 data_ptr: {data0.data_ptr()}")
    print(f"data1 data_ptr: {data1.data_ptr()}")
    print(f"data2 data_ptr: {data2.data_ptr()}")
相关推荐
陈广亮22 分钟前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬31 分钟前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia1 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区1 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两4 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪4 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232554 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
程序员打怪兽4 小时前
详解Visual Transformer (ViT)网络模型
深度学习
王鑫星4 小时前
SWE-bench 首次突破 80%:Claude Opus 4.5 发布,Anthropic 的野心不止于写代码
人工智能