pytorch的张量数据结构以及各种操作函数的底层原理

Tensor 的两个组成部分:

  • Storage (存储):实际的数据被保存为一个连续的一维数组(通常是在 CPU 或 GPU 内存中)。

  • View (视图/元数据):描述如何解释这些物理数据。它包含:

    • shape (形状):张量维度

    • stride (步长):想要跳到下一个维度,在 Storage 中需要跳过多少个元素。

    • storage_offset (偏移量):第一个元素的storage中的位置。

多维张量的 stride(步长)解读

stride 是一个元组,表示在某个维度上移动 1 个位置,需要在底层一维 storage 中跳过多少个元素


核心公式

对于一个 n 维张量,元素 [i₁, i₂, ..., iₙ] 在 storage 中的位置:

text

复制代码
offset = storage_offset 
       + i₁ × stride[0] 
       + i₂ × stride[1] 
       + ... 
       + iₙ × stride[n-1]
python 复制代码
import torch

# 创建一个 3x4 的连续张量
a = torch.arange(12).reshape(3, 4)
print(a)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

print(a.shape)   # (3, 4)
print(a.stride)  # (4, 1)
print(a.storage_offset)  # 0

解读 stride = (4, 1):

维度 stride 值 含义
维度0(行) 4 向下移动 1 行,storage 索引跳过 4 个元素
维度1(列) 1 向右移动 1 列,storage 索引跳过 1 个元素

行优先

复制代码
storage 索引:  0   1   2   3   4   5   6   7   8   9  10  11
storage 数据: [0] [1] [2] [3] [4] [5] [6] [7] [8] [9][10][11]
              └───── row0 ─────┘ └───── row1 ─────┘ └───── row2 ─────┘
python 复制代码
b = a.t()  # 转置,变成 4x3
print(b.shape)   # (4, 3)
print(b.stride)  # (1, 4)  ← stride 交换了!
print(b)
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])

解读 stride = (1, 4):

维度 stride 含义
维度0(行) 1 向下移动 1 行,storage 索引跳过 1 个元素
维度1(列) 4 向右移动 1 列,storage 索引跳过 4 个元素

三维:

python 复制代码
c = torch.arange(24).reshape(2, 3, 4)
print(c.shape)   # (2, 3, 4)
print(c.stride)  # (12, 4, 1)

解读 stride = (12, 4, 1):

维度 stride 含义
维度0(深度/批次) 12 移动 1 个深度,跳过 12 个元素(3×4)
维度1(行) 4 移动 1 行,跳过 4 个元素(1行的长度)
维度2(列) 1 移动 1 列,跳过 1 个元素

storage 布局可视化:

text

复制代码
深度0(第1个 3x4 矩阵):
  行0: [0,  1,  2,  3]   ← 连续
  行1: [4,  5,  6,  7]
  行2: [8,  9, 10, 11]

深度1(第2个 3x4 矩阵):
  行0: [12, 13, 14, 15]
  行1: [16, 17, 18, 19]
  行2: [20, 21, 22, 23]

storage 索引: 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23

1.切片,改变offset:

python 复制代码
import torch

# 创建一个长度为 6 的张量
a = torch.arange(6)  # storage: [0, 1, 2, 3, 4, 5]
                     # a.shape = (6,)
                     # a.storage_offset = 0
                     # a.stride = (1,)

# 对 a 进行切片,取索引 2 到 4(不包括 4)
b = a[2:4]  # b 逻辑上是 [2, 3]

# 查看 b 的元数据
print(b.shape)              # (2,)
print(b.storage_offset)     # 2  ← 关键!b 从 storage 的第 2 个元素开始
print(b.stride)             # (1,)
print(b.data_ptr() == a.data_ptr())  # True,指向同一块内存
复制代码
Storage (a 和 b 共享):
索引:  0    1    2    3    4    5
数据: [0]  [1]  [2]  [3]  [4]  [5]
              ↑
              └─ b.storage_offset = 2
                 b 从这里开始,长度 2
                 b 逻辑视图: [2, 3]

切片本质上是改变storge_offset

2.连续性与view(stridei=stridei+1×shapei+1)

判断张量的连续性,就是判断stride

1.最后一个数字是否为1

2.从后往前,每个维度是否等于

3.size=0,空,特判

python 复制代码
def is_contiguous(tensor):
    shape = tensor.shape
    stride = tensor.stride()
    
    # 1. 处理标量或空张量
    if tensor.ndim == 0:
        return True
    if 0 in shape:
        return True

    expected_stride = 1
    # 从后往前遍历
    for i in range(len(shape) - 1, -1, -1):
        # 关键点:如果维度长度为 1,这个维度的 stride 是多少都无所谓
        # 它不会改变物理内存的连续分布
        if shape[i] != 1:
            if stride[i] != expected_stride:
                return False
            expected_stride *= shape[i]
            
    return True

c = torch.arange(24).reshape(2, 3, 4)

print(c.shape) # (2, 3, 4)

print(c.stride) # (12, 4, 1)

要判断一个张量是否连续,可以简单判断:对于shapei > 1的维度i,进行判断stridei = stridei+1*shapei+1

c: [

c0->

\[1,3,4,2,

3,3,2,1,

5,4,2,1\]

c1->

\[1,3,4,2,

3,3,2,1,

5,4,2,1\]

]

比如在判断stride0 =? stride1*shape1

可以如此理解:stride0:跨到0维度的下一个元素需要跳过12个,比如从c0跳到c1

stride1:跨到1维度的下一个元素需要跳过4个,比如从c01到c02,

shape1:1维度上的元素数量是3(每个ci中有三个元素)

若是连续,两者势必相等

3.transpose与permute

本质:修改原数据中的shape与stride。

1.tranpose

python 复制代码
import torch

# 创建一个连续张量
x = torch.arange(12).reshape(3, 4)
print(f"原始张量 x:")
print(x)
print(f"shape: {x.shape}")      # (3, 4)
print(f"stride: {x.stride()}")  # (4, 1)
print(f"连续? {x.is_contiguous()}")  # True
print(f"storage 地址: {x.data_ptr()}")
print()

# 执行 transpose(交换两个维度)
y = x.transpose(0, 1)
print(f"转置后 y:")
print(y)
print(f"shape: {y.shape}")      # (4, 3) - shape 也交换了
print(f"stride: {y.stride()}")  # (1, 4) - stride 交换了!
print(f"连续? {y.is_contiguous()}")  # False
print(f"storage 地址: {y.data_ptr()}")  # 和 x 相同!
print(f"是否共享 storage: {y.data_ptr() == x.data_ptr()}")  # True

原始张量 x:

tensor(\[ 0, 1, 2, 3,

4, 5, 6, 7,

8, 9, 10, 11])

shape: (3, 4)

stride: (4, 1)

连续? True

storage 地址: 140234567890123

转置后 y:

tensor(\[ 0, 4, 8,

1, 5, 9,

2, 6, 10,

3, 7, 11])

shape: (4, 3)

stride: (1, 4) ← 交换了!

连续? False

storage 地址: 140234567890123 ← 相同地址!

是否共享 storage: True

permute:

python 复制代码
# 3D 张量示例
x = torch.arange(24).reshape(2, 3, 4)
print(f"原始 shape: {x.shape}")    # (2, 3, 4)
print(f"原始 stride: {x.stride()}") # (12, 4, 1)
print(f"连续? {x.is_contiguous()}") # True

# 重排维度:把维度 (0,1,2) 变成 (2,0,1)
y = x.permute(2, 0, 1)
print(f"\npermute 后 shape: {y.shape}")    # (4, 2, 3)
print(f"permute 后 stride: {y.stride()}") # (1, 12, 4)  ← 按 permute 规则重排
print(f"连续? {y.is_contiguous()}")        # False
print(f"共享 storage: {y.data_ptr() == x.data_ptr()}")  # True

4.cat,stack

cat 是在现有轨道上"接火车",而 stack 是给火车"加盖一层新轨道"。

以下是结合底层原理和实际应用的深度解析:

1. 核心区别:维度与内存视角

特性 torch.cat (拼接) torch.stack (堆叠)
核心逻辑 沿现有维度连接 沿新维度连接
维度变化 维度数不变,指定维度的长度增加 维度数 +1
形状要求 非拼接维度必须完全一致 所有维度必须完全一致
底层操作 数据拷贝与连续化 增加维度信息,重新索引

2. torch.cat:底层的"内存搬运工"

torch.cat 的本质是将多个张量在物理内存上(逻辑上)首尾相连。

  • 底层原理 : 它不会改变张量的本质结构,而是沿着你指定的轴(dim),将输入张量的数据块像"接龙"一样拼在一起。

    • 内存视角 :假设你有两个形状为 (2, 3) 的张量。如果你沿 dim=0 拼接,PyTorch 会申请一块新的连续内存,大小为 (4, 3),然后把第一个张量的数据复制进去,紧接着复制第二个张量的数据。
    • 约束来源:为什么要求非拼接维度必须一致?因为如果维度不对齐(比如一个是 3 列,一个是 4 列),它们在内存中就无法形成整齐的矩形块,破坏了张量的规则结构。
  • 代码直观理解

    复制代码
    import torch
    a = torch.tensor([, ]) # 形状 (2, 2)
    b = torch.tensor([])         # 形状 (1, 2)
    
    # 沿第0维(行)拼接,就像把 b 贴在 a 的下面
    result = torch.cat([a, b], dim=0)  
    # 结果形状: (3, 2) -> 维度没变,行数变多了

3. torch.stack:底层的"维度升维器"

torch.stack 的本质是创造一个新的维度,用来索引这些张量。

  • 底层原理: 它不仅仅是数据的组合,更是**元数据(Metadata)**的重构。

    • 内存视角stack 操作会在张量的形状信息(Shape/Stride)中插入一个新的维度。它相当于把多个张量"打包"进一个新的容器里。
    • 数据布局 :虽然底层数据在内存中可能依然是连续存储的,但在逻辑上,PyTorch 增加了一个"层"的概念。比如将两个 (3, 4) 的张量堆叠,结果变成 (2, 3, 4)。那个新增的 2 就是新维度的长度,代表"你有2个这样的张量"。
    • 约束来源 :因为要把它们整齐地码放在这个新维度里,所以所有输入张量的形状必须一模一样,否则无法对齐。
  • 代码直观理解

    复制代码
    import torch
    a = torch.tensor() # 形状 (2,)
    b = torch.tensor() # 形状 (2,)
    
    # 沿新维度堆叠,相当于把它们叠罗汉
    result = torch.stack([a, b], dim=0)
    # 结果形状: (2, 2) -> 维度从1维变成了2维
    # 结果: tensor([,
    #              ])

4. 深度对比与避坑指南

为了帮你彻底搞懂,我用一个表格总结它们的"脾气":

场景 应该用谁? 为什么?
合并数据集 cat 比如你有 100 条数据和另外 50 条数据,你想合并成一个 150 条的大列表。
构建批次 (Batch) stack 比如你有 3 张大小为 (3, 224, 224) 的图片,你想组成一个 Batch 输入模型,变成 (3, 3, 224, 224)
特征融合 cat 在神经网络层之间,经常把不同通道的特征图拼在一起(如 Inception 结构)。
常见报错 RuntimeError Cat 报错 :通常是因为除了拼接维度外,其他维度大小不一样(比如想拼两个矩阵,但一个宽3,一个宽4)。 Stack 报错:通常是因为输入的两个张量形状不完全相同。

总结

  • 如果你想**"变长"**(让数据更多),用 cat
  • 如果你想**"变厚"**(让结构更复杂,增加层级),用 stack

理解这一点,你在处理 PyTorch 张量形状变换(Reshape/View)时就会清晰很多,不再容易报 size mismatch 的错误了。

相关推荐
A.说学逗唱的Coke1 小时前
【大模型专题】向量数据库深度解析:从原理到实战,构建企业级 AI 知识检索底座
数据库·人工智能
果丁智能1 小时前
智能锁赋能网约房民宿数字化管控:身份核验+远程授权,筑牢安全防线、降本增效
网络·数据库·人工智能·安全·智能家居
V搜xhliang02461 小时前
AI智能体的数据安全与合规实践
人工智能·学习·数据分析·自动化·ai编程
大貔貅喝啤酒1 小时前
Python Requests库教程
自动化测试·python·requests库
PPIO派欧云1 小时前
PPIO登上贵州新闻联播,深化AI算力生态建设
人工智能
hai3152475432 小时前
一种通过空间几何转换进行软件编程计算的方式与现有计算的对比
人工智能·深度学习·数学建模·硬件架构·几何学·图论·拓扑学
猿饵块2 小时前
LibreOffice---文档制作
人工智能
硅谷秋水2 小时前
HARBOR:一个面向具身智体机器人强化学习的驾驭框架
人工智能·深度学习·机器学习·机器人
Mr..Jackey2 小时前
瑞佑 RUI Builder 图形化 UI 设计工具
arm开发·人工智能·单片机·ui·人机交互·ra8889·lcd控制芯片
copyer_xyf2 小时前
LangChain 调用 LLM
后端·python·agent