pytorch的张量数据结构以及各种操作函数的底层原理

Tensor 的两个组成部分:

  • Storage (存储):实际的数据被保存为一个连续的一维数组(通常是在 CPU 或 GPU 内存中)。

  • View (视图/元数据):描述如何解释这些物理数据。它包含:

    • shape (形状):张量维度

    • stride (步长):想要跳到下一个维度,在 Storage 中需要跳过多少个元素。

    • storage_offset (偏移量):第一个元素的storage中的位置。

多维张量的 stride(步长)解读

stride 是一个元组,表示在某个维度上移动 1 个位置,需要在底层一维 storage 中跳过多少个元素


核心公式

对于一个 n 维张量,元素 [i₁, i₂, ..., iₙ] 在 storage 中的位置:

text

复制代码
offset = storage_offset 
       + i₁ × stride[0] 
       + i₂ × stride[1] 
       + ... 
       + iₙ × stride[n-1]
python 复制代码
import torch

# 创建一个 3x4 的连续张量
a = torch.arange(12).reshape(3, 4)
print(a)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

print(a.shape)   # (3, 4)
print(a.stride)  # (4, 1)
print(a.storage_offset)  # 0

解读 stride = (4, 1):

维度 stride 值 含义
维度0(行) 4 向下移动 1 行,storage 索引跳过 4 个元素
维度1(列) 1 向右移动 1 列,storage 索引跳过 1 个元素

行优先

复制代码
storage 索引:  0   1   2   3   4   5   6   7   8   9  10  11
storage 数据: [0] [1] [2] [3] [4] [5] [6] [7] [8] [9][10][11]
              └───── row0 ─────┘ └───── row1 ─────┘ └───── row2 ─────┘
python 复制代码
b = a.t()  # 转置,变成 4x3
print(b.shape)   # (4, 3)
print(b.stride)  # (1, 4)  ← stride 交换了!
print(b)
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])

解读 stride = (1, 4):

维度 stride 含义
维度0(行) 1 向下移动 1 行,storage 索引跳过 1 个元素
维度1(列) 4 向右移动 1 列,storage 索引跳过 4 个元素

三维:

python 复制代码
c = torch.arange(24).reshape(2, 3, 4)
print(c.shape)   # (2, 3, 4)
print(c.stride)  # (12, 4, 1)

解读 stride = (12, 4, 1):

维度 stride 含义
维度0(深度/批次) 12 移动 1 个深度,跳过 12 个元素(3×4)
维度1(行) 4 移动 1 行,跳过 4 个元素(1行的长度)
维度2(列) 1 移动 1 列,跳过 1 个元素

storage 布局可视化:

text

复制代码
深度0(第1个 3x4 矩阵):
  行0: [0,  1,  2,  3]   ← 连续
  行1: [4,  5,  6,  7]
  行2: [8,  9, 10, 11]

深度1(第2个 3x4 矩阵):
  行0: [12, 13, 14, 15]
  行1: [16, 17, 18, 19]
  行2: [20, 21, 22, 23]

storage 索引: 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23

1.切片,改变offset:

python 复制代码
import torch

# 创建一个长度为 6 的张量
a = torch.arange(6)  # storage: [0, 1, 2, 3, 4, 5]
                     # a.shape = (6,)
                     # a.storage_offset = 0
                     # a.stride = (1,)

# 对 a 进行切片,取索引 2 到 4(不包括 4)
b = a[2:4]  # b 逻辑上是 [2, 3]

# 查看 b 的元数据
print(b.shape)              # (2,)
print(b.storage_offset)     # 2  ← 关键!b 从 storage 的第 2 个元素开始
print(b.stride)             # (1,)
print(b.data_ptr() == a.data_ptr())  # True,指向同一块内存
复制代码
Storage (a 和 b 共享):
索引:  0    1    2    3    4    5
数据: [0]  [1]  [2]  [3]  [4]  [5]
              ↑
              └─ b.storage_offset = 2
                 b 从这里开始,长度 2
                 b 逻辑视图: [2, 3]

切片本质上是改变storge_offset

2.连续性与view(stride[i]=stride[i+1]×shape[i+1])

判断张量的连续性,就是判断stride

1.最后一个数字是否为1

2.从后往前,每个维度是否等于

3.size=0,空,特判

python 复制代码
def is_contiguous(tensor):
    shape = tensor.shape
    stride = tensor.stride()
    
    # 1. 处理标量或空张量
    if tensor.ndim == 0:
        return True
    if 0 in shape:
        return True

    expected_stride = 1
    # 从后往前遍历
    for i in range(len(shape) - 1, -1, -1):
        # 关键点:如果维度长度为 1,这个维度的 stride 是多少都无所谓
        # 它不会改变物理内存的连续分布
        if shape[i] != 1:
            if stride[i] != expected_stride:
                return False
            expected_stride *= shape[i]
            
    return True

c = torch.arange(24).reshape(2, 3, 4)

print(c.shape) # (2, 3, 4)

print(c.stride) # (12, 4, 1)

要判断一个张量是否连续,可以简单判断:对于shape[i] > 1的维度i,进行判断stride[i] = stride[i+1]*shape[i+1]

c: [

c[0]->

\[1,3,4,2\], \[3,3,2,1\], \[5,4,2,1\]

c[1]->

\[1,3,4,2\], \[3,3,2,1\], \[5,4,2,1\]

]

比如在判断stride[0] =? stride[1]*shape[1]

可以如此理解:stride[0]:跨到0维度的下一个元素需要跳过12个,比如从c[0]跳到c[1]

stride[1]:跨到1维度的下一个元素需要跳过4个,比如从c[0][1]到c[0][2],

shape[1]:1维度上的元素数量是3(每个c[i]中有三个元素)

若是连续,两者势必相等

3.transpose与permute

本质:修改原数据中的shape与stride。

1.tranpose

python 复制代码
import torch

# 创建一个连续张量
x = torch.arange(12).reshape(3, 4)
print(f"原始张量 x:")
print(x)
print(f"shape: {x.shape}")      # (3, 4)
print(f"stride: {x.stride()}")  # (4, 1)
print(f"连续? {x.is_contiguous()}")  # True
print(f"storage 地址: {x.data_ptr()}")
print()

# 执行 transpose(交换两个维度)
y = x.transpose(0, 1)
print(f"转置后 y:")
print(y)
print(f"shape: {y.shape}")      # (4, 3) - shape 也交换了
print(f"stride: {y.stride()}")  # (1, 4) - stride 交换了!
print(f"连续? {y.is_contiguous()}")  # False
print(f"storage 地址: {y.data_ptr()}")  # 和 x 相同!
print(f"是否共享 storage: {y.data_ptr() == x.data_ptr()}")  # True

原始张量 x:

tensor([[ 0, 1, 2, 3],

4, 5, 6, 7\], \[ 8, 9, 10, 11\]\]) shape: (3, 4) stride: (4, 1) 连续? True storage 地址: 140234567890123 转置后 y: tensor(\[\[ 0, 4, 8\], \[ 1, 5, 9\], \[ 2, 6, 10\], \[ 3, 7, 11\]\]) shape: (4, 3) stride: (1, 4) ← 交换了! 连续? False storage 地址: 140234567890123 ← 相同地址! 是否共享 storage: True #### permute: ```python # 3D 张量示例 x = torch.arange(24).reshape(2, 3, 4) print(f"原始 shape: {x.shape}") # (2, 3, 4) print(f"原始 stride: {x.stride()}") # (12, 4, 1) print(f"连续? {x.is_contiguous()}") # True # 重排维度:把维度 (0,1,2) 变成 (2,0,1) y = x.permute(2, 0, 1) print(f"\npermute 后 shape: {y.shape}") # (4, 2, 3) print(f"permute 后 stride: {y.stride()}") # (1, 12, 4) ← 按 permute 规则重排 print(f"连续? {y.is_contiguous()}") # False print(f"共享 storage: {y.data_ptr() == x.data_ptr()}") # True ``` ### 4.cat,stack **`cat` 是在现有轨道上"接火车",而 `stack` 是给火车"加盖一层新轨道"。** 以下是结合底层原理和实际应用的深度解析: #### 1. 核心区别:维度与内存视角 | 特性 | torch.cat (拼接) | torch.stack (堆叠) | |----------|---------------------|------------------| | **核心逻辑** | **沿现有维度**连接 | **沿新维度**连接 | | **维度变化** | 维度数**不变**,指定维度的长度增加 | 维度数 **+1** | | **形状要求** | **非拼接维度**必须完全一致 | **所有维度**必须完全一致 | | **底层操作** | 数据拷贝与连续化 | 增加维度信息,重新索引 | *** ** * ** *** #### 2. torch.cat:底层的"内存搬运工" `torch.cat` 的本质是将多个张量在**物理内存**上(逻辑上)首尾相连。 * **底层原理** : 它不会改变张量的本质结构,而是沿着你指定的轴(`dim`),将输入张量的数据块像"接龙"一样拼在一起。 * **内存视角** :假设你有两个形状为 `(2, 3)` 的张量。如果你沿 `dim=0` 拼接,PyTorch 会申请一块新的连续内存,大小为 `(4, 3)`,然后把第一个张量的数据复制进去,紧接着复制第二个张量的数据。 * **约束来源**:为什么要求非拼接维度必须一致?因为如果维度不对齐(比如一个是 3 列,一个是 4 列),它们在内存中就无法形成整齐的矩形块,破坏了张量的规则结构。 * **代码直观理解**: import torch a = torch.tensor([, ]) # 形状 (2, 2) b = torch.tensor([]) # 形状 (1, 2) # 沿第0维(行)拼接,就像把 b 贴在 a 的下面 result = torch.cat([a, b], dim=0) # 结果形状: (3, 2) -> 维度没变,行数变多了 *** ** * ** *** #### 3. torch.stack:底层的"维度升维器" `torch.stack` 的本质是**创造一个新的维度**,用来索引这些张量。 * **底层原理**: 它不仅仅是数据的组合,更是\*\*元数据(Metadata)\*\*的重构。 * **内存视角** :`stack` 操作会在张量的形状信息(Shape/Stride)中插入一个新的维度。它相当于把多个张量"打包"进一个新的容器里。 * **数据布局** :虽然底层数据在内存中可能依然是连续存储的,但在逻辑上,PyTorch 增加了一个"层"的概念。比如将两个 `(3, 4)` 的张量堆叠,结果变成 `(2, 3, 4)`。那个新增的 `2` 就是新维度的长度,代表"你有2个这样的张量"。 * **约束来源** :因为要把它们整齐地码放在这个新维度里,所以**所有**输入张量的形状必须一模一样,否则无法对齐。 * **代码直观理解**: import torch a = torch.tensor() # 形状 (2,) b = torch.tensor() # 形状 (2,) # 沿新维度堆叠,相当于把它们叠罗汉 result = torch.stack([a, b], dim=0) # 结果形状: (2, 2) -> 维度从1维变成了2维 # 结果: tensor([, # ]) *** ** * ** *** #### 4. 深度对比与避坑指南 为了帮你彻底搞懂,我用一个表格总结它们的"脾气": | 场景 | 应该用谁? | 为什么? | |------------------|----------------|--------------------------------------------------------------------------------------------| | **合并数据集** | `cat` | 比如你有 100 条数据和另外 50 条数据,你想合并成一个 150 条的大列表。 | | **构建批次 (Batch)** | `stack` | 比如你有 3 张大小为 `(3, 224, 224)` 的图片,你想组成一个 Batch 输入模型,变成 `(3, 3, 224, 224)`。 | | **特征融合** | `cat` | 在神经网络层之间,经常把不同通道的特征图拼在一起(如 Inception 结构)。 | | **常见报错** | `RuntimeError` | **Cat 报错** :通常是因为除了拼接维度外,其他维度大小不一样(比如想拼两个矩阵,但一个宽3,一个宽4)。 **Stack 报错**:通常是因为输入的两个张量形状不完全相同。 | #### 总结 * 如果你想\*\*"变长"\*\*(让数据更多),用 **`cat`**。 * 如果你想\*\*"变厚"\*\*(让结构更复杂,增加层级),用 **`stack`**。 理解这一点,你在处理 PyTorch 张量形状变换(Reshape/View)时就会清晰很多,不再容易报 `size mismatch` 的错误了。

相关推荐
盘古开天16662 小时前
Gemma4本地部署,零成本打造私有 AI 助手
人工智能·本地部署·智能体·gemma4·ai私有助理
浔川python社2 小时前
张雪机车:以热爱为轮,让中国摩托驰骋世界之巅
python
zl_dfq2 小时前
Python学习5 之【字符串】
python·学习
夜影风2 小时前
算力租赁产业链全景分析:解构AI时代的“算力电厂”
人工智能·算力租赁
MediaTea2 小时前
AI 术语通俗词典:矩阵乘法
人工智能·线性代数·矩阵
NHuan^_^2 小时前
SpringBoot3 整合 SpringAI 实现ai助手(记忆)
java·人工智能·spring boot
Binary_ey2 小时前
光刻技术第22期 | 贝叶斯压缩感知光源优化的优化技术及对比分析
人工智能·深度学习·机器学习
奔跑草-2 小时前
【AI日报】每日AI最新消息2026-04-07
人工智能·大模型·github·开源软件
rainy雨2 小时前
免费且好用的精益工具在哪里?2026年精益工具清单整理
大数据·人工智能·信息可视化·数据挖掘·数据分析·精益工程