Tensor 的两个组成部分:
-
Storage (存储):实际的数据被保存为一个连续的一维数组(通常是在 CPU 或 GPU 内存中)。
-
View (视图/元数据):描述如何解释这些物理数据。它包含:
-
shape (形状):张量维度
-
stride (步长):想要跳到下一个维度,在 Storage 中需要跳过多少个元素。
-
storage_offset (偏移量):第一个元素的storage中的位置。
-
多维张量的 stride(步长)解读
stride 是一个元组,表示在某个维度上移动 1 个位置,需要在底层一维 storage 中跳过多少个元素。
核心公式
对于一个 n 维张量,元素 [i₁, i₂, ..., iₙ] 在 storage 中的位置:
text
offset = storage_offset
+ i₁ × stride[0]
+ i₂ × stride[1]
+ ...
+ iₙ × stride[n-1]
python
import torch
# 创建一个 3x4 的连续张量
a = torch.arange(12).reshape(3, 4)
print(a)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
print(a.shape) # (3, 4)
print(a.stride) # (4, 1)
print(a.storage_offset) # 0
解读 stride = (4, 1):
| 维度 | stride 值 | 含义 |
|---|---|---|
| 维度0(行) | 4 | 向下移动 1 行,storage 索引跳过 4 个元素 |
| 维度1(列) | 1 | 向右移动 1 列,storage 索引跳过 1 个元素 |
行优先
storage 索引: 0 1 2 3 4 5 6 7 8 9 10 11
storage 数据: [0] [1] [2] [3] [4] [5] [6] [7] [8] [9][10][11]
└───── row0 ─────┘ └───── row1 ─────┘ └───── row2 ─────┘
python
b = a.t() # 转置,变成 4x3
print(b.shape) # (4, 3)
print(b.stride) # (1, 4) ← stride 交换了!
print(b)
# tensor([[ 0, 4, 8],
# [ 1, 5, 9],
# [ 2, 6, 10],
# [ 3, 7, 11]])
解读 stride = (1, 4):
| 维度 | stride | 含义 |
|---|---|---|
| 维度0(行) | 1 | 向下移动 1 行,storage 索引跳过 1 个元素 |
| 维度1(列) | 4 | 向右移动 1 列,storage 索引跳过 4 个元素 |
三维:
python
c = torch.arange(24).reshape(2, 3, 4)
print(c.shape) # (2, 3, 4)
print(c.stride) # (12, 4, 1)
解读 stride = (12, 4, 1):
| 维度 | stride | 含义 |
|---|---|---|
| 维度0(深度/批次) | 12 | 移动 1 个深度,跳过 12 个元素(3×4) |
| 维度1(行) | 4 | 移动 1 行,跳过 4 个元素(1行的长度) |
| 维度2(列) | 1 | 移动 1 列,跳过 1 个元素 |
storage 布局可视化:
text
深度0(第1个 3x4 矩阵):
行0: [0, 1, 2, 3] ← 连续
行1: [4, 5, 6, 7]
行2: [8, 9, 10, 11]
深度1(第2个 3x4 矩阵):
行0: [12, 13, 14, 15]
行1: [16, 17, 18, 19]
行2: [20, 21, 22, 23]
storage 索引: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
1.切片,改变offset:
python
import torch
# 创建一个长度为 6 的张量
a = torch.arange(6) # storage: [0, 1, 2, 3, 4, 5]
# a.shape = (6,)
# a.storage_offset = 0
# a.stride = (1,)
# 对 a 进行切片,取索引 2 到 4(不包括 4)
b = a[2:4] # b 逻辑上是 [2, 3]
# 查看 b 的元数据
print(b.shape) # (2,)
print(b.storage_offset) # 2 ← 关键!b 从 storage 的第 2 个元素开始
print(b.stride) # (1,)
print(b.data_ptr() == a.data_ptr()) # True,指向同一块内存
Storage (a 和 b 共享):
索引: 0 1 2 3 4 5
数据: [0] [1] [2] [3] [4] [5]
↑
└─ b.storage_offset = 2
b 从这里开始,长度 2
b 逻辑视图: [2, 3]
切片本质上是改变storge_offset
2.连续性与view(stride[i]=stride[i+1]×shape[i+1])
判断张量的连续性,就是判断stride
1.最后一个数字是否为1
2.从后往前,每个维度是否等于
3.size=0,空,特判
python
def is_contiguous(tensor):
shape = tensor.shape
stride = tensor.stride()
# 1. 处理标量或空张量
if tensor.ndim == 0:
return True
if 0 in shape:
return True
expected_stride = 1
# 从后往前遍历
for i in range(len(shape) - 1, -1, -1):
# 关键点:如果维度长度为 1,这个维度的 stride 是多少都无所谓
# 它不会改变物理内存的连续分布
if shape[i] != 1:
if stride[i] != expected_stride:
return False
expected_stride *= shape[i]
return True
c = torch.arange(24).reshape(2, 3, 4)
print(c.shape) # (2, 3, 4)
print(c.stride) # (12, 4, 1)
要判断一个张量是否连续,可以简单判断:对于shape[i] > 1的维度i,进行判断stride[i] = stride[i+1]*shape[i+1]
c: [
c[0]->
\[1,3,4,2\], \[3,3,2,1\], \[5,4,2,1\]
c[1]->
\[1,3,4,2\], \[3,3,2,1\], \[5,4,2,1\]
]
比如在判断stride[0] =? stride[1]*shape[1]
可以如此理解:stride[0]:跨到0维度的下一个元素需要跳过12个,比如从c[0]跳到c[1]
stride[1]:跨到1维度的下一个元素需要跳过4个,比如从c[0][1]到c[0][2],
shape[1]:1维度上的元素数量是3(每个c[i]中有三个元素)
若是连续,两者势必相等
3.transpose与permute
本质:修改原数据中的shape与stride。
1.tranpose
python
import torch
# 创建一个连续张量
x = torch.arange(12).reshape(3, 4)
print(f"原始张量 x:")
print(x)
print(f"shape: {x.shape}") # (3, 4)
print(f"stride: {x.stride()}") # (4, 1)
print(f"连续? {x.is_contiguous()}") # True
print(f"storage 地址: {x.data_ptr()}")
print()
# 执行 transpose(交换两个维度)
y = x.transpose(0, 1)
print(f"转置后 y:")
print(y)
print(f"shape: {y.shape}") # (4, 3) - shape 也交换了
print(f"stride: {y.stride()}") # (1, 4) - stride 交换了!
print(f"连续? {y.is_contiguous()}") # False
print(f"storage 地址: {y.data_ptr()}") # 和 x 相同!
print(f"是否共享 storage: {y.data_ptr() == x.data_ptr()}") # True
原始张量 x:
tensor([[ 0, 1, 2, 3],
4, 5, 6, 7\], \[ 8, 9, 10, 11\]\]) shape: (3, 4) stride: (4, 1) 连续? True storage 地址: 140234567890123 转置后 y: tensor(\[\[ 0, 4, 8\], \[ 1, 5, 9\], \[ 2, 6, 10\], \[ 3, 7, 11\]\]) shape: (4, 3) stride: (1, 4) ← 交换了! 连续? False storage 地址: 140234567890123 ← 相同地址! 是否共享 storage: True #### permute: ```python # 3D 张量示例 x = torch.arange(24).reshape(2, 3, 4) print(f"原始 shape: {x.shape}") # (2, 3, 4) print(f"原始 stride: {x.stride()}") # (12, 4, 1) print(f"连续? {x.is_contiguous()}") # True # 重排维度:把维度 (0,1,2) 变成 (2,0,1) y = x.permute(2, 0, 1) print(f"\npermute 后 shape: {y.shape}") # (4, 2, 3) print(f"permute 后 stride: {y.stride()}") # (1, 12, 4) ← 按 permute 规则重排 print(f"连续? {y.is_contiguous()}") # False print(f"共享 storage: {y.data_ptr() == x.data_ptr()}") # True ``` ### 4.cat,stack **`cat` 是在现有轨道上"接火车",而 `stack` 是给火车"加盖一层新轨道"。** 以下是结合底层原理和实际应用的深度解析: #### 1. 核心区别:维度与内存视角 | 特性 | torch.cat (拼接) | torch.stack (堆叠) | |----------|---------------------|------------------| | **核心逻辑** | **沿现有维度**连接 | **沿新维度**连接 | | **维度变化** | 维度数**不变**,指定维度的长度增加 | 维度数 **+1** | | **形状要求** | **非拼接维度**必须完全一致 | **所有维度**必须完全一致 | | **底层操作** | 数据拷贝与连续化 | 增加维度信息,重新索引 | *** ** * ** *** #### 2. torch.cat:底层的"内存搬运工" `torch.cat` 的本质是将多个张量在**物理内存**上(逻辑上)首尾相连。 * **底层原理** : 它不会改变张量的本质结构,而是沿着你指定的轴(`dim`),将输入张量的数据块像"接龙"一样拼在一起。 * **内存视角** :假设你有两个形状为 `(2, 3)` 的张量。如果你沿 `dim=0` 拼接,PyTorch 会申请一块新的连续内存,大小为 `(4, 3)`,然后把第一个张量的数据复制进去,紧接着复制第二个张量的数据。 * **约束来源**:为什么要求非拼接维度必须一致?因为如果维度不对齐(比如一个是 3 列,一个是 4 列),它们在内存中就无法形成整齐的矩形块,破坏了张量的规则结构。 * **代码直观理解**: import torch a = torch.tensor([, ]) # 形状 (2, 2) b = torch.tensor([]) # 形状 (1, 2) # 沿第0维(行)拼接,就像把 b 贴在 a 的下面 result = torch.cat([a, b], dim=0) # 结果形状: (3, 2) -> 维度没变,行数变多了 *** ** * ** *** #### 3. torch.stack:底层的"维度升维器" `torch.stack` 的本质是**创造一个新的维度**,用来索引这些张量。 * **底层原理**: 它不仅仅是数据的组合,更是\*\*元数据(Metadata)\*\*的重构。 * **内存视角** :`stack` 操作会在张量的形状信息(Shape/Stride)中插入一个新的维度。它相当于把多个张量"打包"进一个新的容器里。 * **数据布局** :虽然底层数据在内存中可能依然是连续存储的,但在逻辑上,PyTorch 增加了一个"层"的概念。比如将两个 `(3, 4)` 的张量堆叠,结果变成 `(2, 3, 4)`。那个新增的 `2` 就是新维度的长度,代表"你有2个这样的张量"。 * **约束来源** :因为要把它们整齐地码放在这个新维度里,所以**所有**输入张量的形状必须一模一样,否则无法对齐。 * **代码直观理解**: import torch a = torch.tensor() # 形状 (2,) b = torch.tensor() # 形状 (2,) # 沿新维度堆叠,相当于把它们叠罗汉 result = torch.stack([a, b], dim=0) # 结果形状: (2, 2) -> 维度从1维变成了2维 # 结果: tensor([, # ]) *** ** * ** *** #### 4. 深度对比与避坑指南 为了帮你彻底搞懂,我用一个表格总结它们的"脾气": | 场景 | 应该用谁? | 为什么? | |------------------|----------------|--------------------------------------------------------------------------------------------| | **合并数据集** | `cat` | 比如你有 100 条数据和另外 50 条数据,你想合并成一个 150 条的大列表。 | | **构建批次 (Batch)** | `stack` | 比如你有 3 张大小为 `(3, 224, 224)` 的图片,你想组成一个 Batch 输入模型,变成 `(3, 3, 224, 224)`。 | | **特征融合** | `cat` | 在神经网络层之间,经常把不同通道的特征图拼在一起(如 Inception 结构)。 | | **常见报错** | `RuntimeError` | **Cat 报错** :通常是因为除了拼接维度外,其他维度大小不一样(比如想拼两个矩阵,但一个宽3,一个宽4)。 **Stack 报错**:通常是因为输入的两个张量形状不完全相同。 | #### 总结 * 如果你想\*\*"变长"\*\*(让数据更多),用 **`cat`**。 * 如果你想\*\*"变厚"\*\*(让结构更复杂,增加层级),用 **`stack`**。 理解这一点,你在处理 PyTorch 张量形状变换(Reshape/View)时就会清晰很多,不再容易报 `size mismatch` 的错误了。