阶段一:前向传播 - Tensor 的内存模型与高性能算子
PyTorch 架构级学习系列 - 第 2 篇
本文将深入探讨 PyTorch 的核心:Tensor 的物理存储模型。你将理解为什么某些操作"零开销",而另一些需要复制内存;为什么
transpose后经常需要contiguous();以及如何利用这些知识写出高性能的代码。
📚 目录
- Tensor 的三要素:Storage、Shape、Stride
- 设计动机:为什么要分离 Storage 和 Shape?
- 零拷贝的魔法:View 操作的实现原理
- 代价与限制:理解 Contiguous
- 掌握工具箱:各种维度操作
- 从 Tensor 到模型:nn.Module 的设计
- 实战练习:手写 MLP 和多头注意力
🔍 Part 1: Tensor 的三要素
💡 核心理念
Tensor = Storage(物理存储)+ (Shape, Stride, Offset)(投影规则)
理解这个分离式设计,你就理解了 PyTorch 性能优化的核心。
1.1 从一个简单例子开始
python
import torch
# 创建一个 2x3 的张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(x.shape) # torch.Size([2, 3])
print(x.stride()) # (3, 1)
print(x.storage()) # [1, 2, 3, 4, 5, 6](底层是一维数组)
惊人的发现: 虽然我们看到的是 2D 矩阵,但底层存储的是 1D 数组!
这引出了一个关键问题:Tensor 如何将多维的逻辑视图映射到一维的物理存储?
1.2 三要素详解
要素 1:Storage - 物理存储
Storage 是 PyTorch 中真正存储数据的地方,它是一个一维的连续内存块。
python
x = torch.tensor([1, 2, 3, 4, 5, 6])
print(x.storage())
# 输出:
# 1
# 2
# 3
# 4
# 5
# 6
# [torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
关键特性:
- Storage 总是一维的
- 多个 Tensor 可以共享同一个 Storage
- Storage 存储的是原始数据(字节)
要素 2:Shape - 逻辑形状
Shape(形状) 定义了 Tensor 的维度和每个维度的大小。
python
x = torch.arange(12) # shape: (12,)
# 同一个 Storage,不同的 Shape
y = x.view(3, 4) # shape: (3, 4) - 3行4列
z = x.view(2, 2, 3) # shape: (2, 2, 3) - 2个2x3矩阵
重要: Shape 只是"元数据",改变 Shape 不需要移动数据。
要素 3:Stride - 映射关键
Stride(步长) 是理解 Tensor 内存布局的核心概念。
定义: Stride 告诉你,在某个维度上前进一步,需要在 Storage 中跳过多少个元素。
实例解析:二维 Tensor
python
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(x.shape) # (2, 3) - 2行3列
print(x.stride()) # (3, 1)
Stride (3, 1) 的含义:
- 第 0 维(行)的 stride = 3:从一行跳到下一行,需要跳过 3 个元素
- 第 1 维(列)的 stride = 1:从一列跳到下一列,需要跳过 1 个元素
内存布局可视化:
ini
Storage: [1, 2, 3, 4, 5, 6]
↑ ↑
[0,0] [1,0]
stride=0×3+0×1=0
stride=1×3+0×1=3
访问 x[i, j] 的内存地址 = base_offset + i × stride[0] + j × stride[1]
= 0 + i × 3 + j × 1
x[0, 0] → storage[0] = 1
x[0, 1] → storage[1] = 2
x[1, 0] → storage[3] = 4
x[1, 2] → storage[5] = 6
1.3 投影公式:多维到一维的映射
Tensor 是通过 (Shape, Stride, Offset) 三元组定义的、对一维 Storage 的多维投影。
形式化定义:
text
Storage S: 一维数组 [s₀, s₁, s₂, ..., sₙ₋₁]
Tensor T: 通过投影函数 f 定义的多维视图
- Shape: (d₀, d₁, d₂, ..., dₖ)
- Stride: (σ₀, σ₁, σ₂, ..., σₖ)
- Offset: o
投影函数 f: 多维索引 → 一维索引
f(i₀, i₁, i₂, ..., iₖ) = o + i₀×σ₀ + i₁×σ₁ + i₂×σ₂ + ... + iₖ×σₖ
因此:
T[i₀, i₁, ..., iₖ] = S[f(i₀, i₁, ..., iₖ)]
三维示例:
python
x = torch.arange(24).view(2, 3, 4)
print(x.stride()) # (12, 4, 1)
# 访问 x[1, 2, 3]:
# storage_index = 1×12 + 2×4 + 3×1 = 12 + 8 + 3 = 23
print(x[1, 2, 3]) # tensor(23)
1.4 多个 Tensor 可以共享同一个 Storage
python
x = torch.arange(12)
print(f"x 的 Storage 指针: {x.storage().data_ptr()}")
# 创建多个不同的投影
y = x.view(3, 4)
z = x.view(2, 6)
w = x.view(2, 3, 2)
# 验证:它们共享同一个 Storage
print(f"y 的 Storage 指针: {y.storage().data_ptr()}")
print(f"z 的 Storage 指针: {z.storage().data_ptr()}")
print(f"w 的 Storage 指针: {w.storage().data_ptr()}")
# 所有指针相同!
# 修改任何一个会影响所有
x[0] = 999
print(f"x[0] = {x[0]}") # 999
print(f"y[0,0] = {y[0, 0]}") # 999
print(f"z[0,0] = {z[0, 0]}") # 999
可视化:同一个 Storage 的多种投影
text
Storage [0, 1, 2, 3, 4, 5, ..., 11]
(一维数组,永远不变)
│
┌───────────────┼───────────────┐
│ │ │
↓ ↓ ↓
投影 1 投影 2 投影 3
shape=(3,4) shape=(2,6) shape=(2,3,2)
stride=(4,1) stride=(6,1) stride=(6,2,1)
小结: 现在你理解了 Tensor 的物理本质:一个一维 Storage + 一套投影规则。
但你可能会问:为什么要这样设计?为什么不直接把数据和形状存在一起?
这就引出了 Part 2...
🎯 Part 2: 设计动机 - 为什么要分离 Storage 和 Shape?
2.1 传统设计的问题
假设我们用传统的方式设计 Tensor:
python
# 传统设计:数据和形状紧密绑定
class NaiveTensor:
def __init__(self, data, shape):
self.data = data # 2D 数组 [[1,2,3], [4,5,6]]
self.shape = shape # (2, 3)
def transpose(self):
# 必须创建新数组!
new_data = [[self.data[j][i] for j in range(self.shape[0])]
for i in range(self.shape[1])]
return NaiveTensor(new_data, (self.shape[1], self.shape[0]))
问题: 每次转置都要复制数据!
性能测试:
python
import time
import torch
# 测试:1000×1000 矩阵转置
x = torch.randn(1000, 1000)
# 方法 1:PyTorch 的 transpose(只改元数据)
start = time.perf_counter()
y = x.transpose(0, 1)
time_pytorch = (time.perf_counter() - start) * 1000
# 方法 2:手动复制(模拟传统设计)
start = time.perf_counter()
y_copy = x.t().contiguous().clone()
time_copy = (time.perf_counter() - start) * 1000
print(f"PyTorch transpose: {time_pytorch:.4f} ms") # < 0.001 ms
print(f"手动复制: {time_copy:.4f} ms") # ~ 10 ms
print(f"加速比: {time_copy / time_pytorch:.0f}x")
# 快 10000 倍!
2.2 PyTorch 的解决方案:分离式设计
核心思想: 将"物理存储"和"逻辑视图"分离。
text
传统设计:
Tensor = 数据 + 形状
→ 改变形状 = 复制数据 ❌
PyTorch 设计:
Tensor = Storage(物理) + (Shape, Stride)(逻辑)
→ 改变形状 = 只改元数据 ✅
优势 1:零拷贝操作
python
x = torch.randn(1000, 1000)
# x: shape=(1000, 1000), stride=(1000, 1)
# 含义:1000×1000 的矩阵,Storage 有 100万个元素
# 所有这些操作都是瞬间完成,不复制数据:
y = x.transpose(0, 1)
# y: shape=(1000, 1000), stride=(1, 1000) ← 只交换了 stride
# 含义:转置矩阵,行列互换
# 访问 y[i,j] = storage[i×1 + j×1000] = x[j,i]
# Storage 完全没动!
z = x[::2, ::2]
# z: shape=(500, 500), stride=(2000, 2) ← stride 变成 2 倍
# 含义:每隔一行、每隔一列采样(下采样)
# 访问 z[i,j] = storage[i×2000 + j×2] = x[2i, 2j]
# Storage 完全没动!
w = x.unsqueeze(0)
# w: shape=(1, 1000, 1000), stride=(1000000, 1000, 1)
# 含义:在第 0 维插入大小为 1 的维度(增加批次维度)
# 访问 w[0,i,j] = storage[0×1000000 + i×1000 + j×1] = x[i,j]
# Storage 完全没动!
v = x.expand(10, 1000, 1000)
# v: shape=(10, 1000, 1000), stride=(0, 1000, 1) ← 注意第 0 维 stride=0!
# 含义:虚拟复制 10 份,看起来有 10 个矩阵,但实际内存只有 1 份
# 访问 v[i,j,k] = storage[i×0 + j×1000 + k×1] = x[j,k]
# 无论 i 是多少,都访问同一个位置!
# Storage 完全没动!
优势 2:内存共享
python
x = torch.randn(1000, 1000)
# x 占用约 8MB 内存(1000×1000×8字节)
y = x.transpose(0, 1)
# y 不占用额外内存!仍然是 8MB
# 验证:它们共享同一块内存
print(x.storage().data_ptr() == y.storage().data_ptr()) # True
# data_ptr() 返回 Storage 的内存地址,地址相同说明是同一块内存
# 因为共享内存,修改 y 会影响 x
y[0, 0] = 999
print(x[0, 0]) # 999
# 为什么?因为 y[0,0] 和 x[0,0] 访问的是 storage[0],同一个位置!
# 同理,修改 x 也会影响 y
x[0, 1] = 888
print(y[1, 0]) # 888
# 为什么?因为 x[0,1]=storage[1],y[1,0]=storage[1],都是同一个位置!
优势 3:延迟物化(Lazy Materialization)
什么是延迟物化? 就像"拖延症":能不干活就不干活,直到万不得已才真正干活。
python
x = torch.randn(1000, 1000)
# Storage: 100万个元素
# shape=(1000, 1000), stride=(1000, 1)
# ===== 第 1 步:转置 =====
y = x.transpose(0, 1)
# ✅ 瞬间完成(< 0.001ms)
# 只改了元数据:shape=(1000, 1000), stride=(1, 1000)
# Storage 还是原来的 100万个元素,位置没变!
print(f"y 和 x 共享内存: {y.storage().data_ptr() == x.storage().data_ptr()}") # True
# ===== 第 2 步:隔列采样 =====
z = y[:, ::2]
# ✅ 瞬间完成(< 0.001ms)
# 只改了元数据:shape=(1000, 500), stride=(1, 2000)
# Storage 还是原来的 100万个元素,位置没变!
print(f"z 和 x 共享内存: {z.storage().data_ptr() == x.storage().data_ptr()}") # True
# ===== 第 3 步:增加批次维度 =====
w = z.unsqueeze(0)
# ✅ 瞬间完成(< 0.001ms)
# 只改了元数据:shape=(1, 1000, 500), stride=(500000, 1, 2000)
# Storage 还是原来的 100万个元素,位置没变!
print(f"w 和 x 共享内存: {w.storage().data_ptr() == x.storage().data_ptr()}") # True
# 此时:做了 3 个操作,但一次数据复制都没有发生!
# w 现在是一个"不连续"的 Tensor,访问它会跳来跳去
# ===== 第 4 步:真正需要连续内存时 =====
w_cont = w.contiguous()
# ❌ 这里才真正复制数据(~ 2ms)
# 创建新的 Storage,按照 w 的逻辑顺序重新排列数据
# w_cont: shape=(1, 1000, 500), stride=(500000, 500, 1) ← 标准 stride
print(f"w_cont 和 x 共享内存: {w_cont.storage().data_ptr() == x.storage().data_ptr()}") # False
# 为什么需要 contiguous?
# 因为某些操作(如 view)要求连续的内存布局
# 现在可以安全地 reshape 了:
result = w_cont.view(1, -1) # 成功!shape=(1, 500000)
类比理解:
想象你有一本 1000 页的书(Storage),你想:
- 倒序阅读(transpose)→ 不需要重新打印书,只需要从后往前翻页
- 只读偶数页(切片)→ 不需要重新打印书,只需要跳着读
- 把书装进袋子(unsqueeze)→ 不需要重新打印书,只是加了个容器
这就是"延迟":你一直在操作原书,没有复制任何一页!
但如果你要: 4. 把选中的页装订成新书 (contiguous)→ 现在才需要复印机!
性能对比:
python
import time
x = torch.randn(1000, 1000)
# 测试延迟物化的性能优势
start = time.perf_counter()
y = x.transpose(0, 1)
z = y[:, ::2]
w = z.unsqueeze(0)
time_lazy = (time.perf_counter() - start) * 1000
start = time.perf_counter()
w_cont = w.contiguous()
time_materialize = (time.perf_counter() - start) * 1000
print(f"3 个 View 操作: {time_lazy:.4f} ms") # < 0.01 ms
print(f"最终物化(复制数据): {time_materialize:.4f} ms") # ~ 2 ms
print(f"物化慢了: {time_materialize / time_lazy:.0f}x")
关键洞察:
- View 操作链可以"积累"很多变换,但都不复制数据
- 只在最后真正需要连续内存时才一次性复制
- 这样可以避免中间的多次复制,节省时间和内存
2.3 对比总结
| 操作 | 传统设计 | PyTorch 设计 | 加速比 |
|---|---|---|---|
transpose() |
复制数据(10ms) | 只改 stride(< 0.001ms) | 10000x |
[::2] 切片 |
复制数据 | 只改 stride 和 offset | 1000x |
unsqueeze() |
复制并扩展 | 只改 shape | 无限 |
expand() |
复制多份 | stride 设为 0 | 无限 |
小结: 分离式设计让 PyTorch 可以实现零拷贝的视图操作,性能提升数千倍。
但这引出了新问题:如何通过只改变 Stride 来实现 transpose?Stride 到底是如何工作的?
这就是 Part 3 要解答的...
🪄 Part 3: 零拷贝的魔法 - View 操作的实现原理
3.1 什么是 View?
View 是对同一个 Storage 的不同"解释方式"。
核心原则:
- ✅ View 操作只改变元数据(Shape、Stride、Offset)
- ✅ 不移动或复制数据
- ✅ 多个 View 共享同一个 Storage
3.2 案例 1:Transpose 的零拷贝实现
python
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print("原始:")
print("Shape:", x.shape) # (2, 3)
print("Stride:", x.stride()) # (3, 1)
print("Storage:", list(x.storage())) # [1, 2, 3, 4, 5, 6]
# 转置
y = x.transpose(0, 1)
print("\n转置后:")
print("Shape:", y.shape) # (3, 2)
print("Stride:", y.stride()) # (1, 3) ← 注意 stride 变了!
print("Storage:", list(y.storage())) # [1, 2, 3, 4, 5, 6] ← Storage 没变!
关键发现: transpose 并没有移动数据,只是交换了 Stride!
可视化理解:
ini
原始 x (shape=(2,3), stride=(3,1)):
Storage: [1, 2, 3, 4, 5, 6]
投影规则: x[i,j] = storage[i×3 + j×1]
x[0,0] = storage[0×3 + 0×1] = storage[0] = 1
x[0,1] = storage[0×3 + 1×1] = storage[1] = 2
x[1,0] = storage[1×3 + 0×1] = storage[3] = 4
逻辑视图:
[[1, 2, 3],
[4, 5, 6]]
─────────────────────────────────────
转置后 y (shape=(3,2), stride=(1,3)):
Storage: [1, 2, 3, 4, 5, 6] ← 完全相同的存储!
投影规则: y[i,j] = storage[i×1 + j×3]
y[0,0] = storage[0×1 + 0×3] = storage[0] = 1
y[0,1] = storage[0×1 + 1×3] = storage[3] = 4
y[1,0] = storage[1×1 + 0×3] = storage[1] = 2
逻辑视图:
[[1, 4],
[2, 5],
[3, 6]]
实现原理: 通过改变投影规则,同一个 Storage 被"重新解释"了!
3.3 案例 2:Slice(切片)的零拷贝实现
python
x = torch.arange(12).view(3, 4)
# x: shape=(3, 4), stride=(4, 1), offset=0
# Storage: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# 逻辑视图:
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
# 切片:取第 1 列(第 2 列,索引从 0 开始)
y = x[:, 1]
# 含义:取所有行的第 1 列
# 结果:[x[0,1], x[1,1], x[2,1]] = [1, 5, 9]
print(y) # tensor([1, 5, 9])
print("Shape:", y.shape) # (3,)
print("Stride:", y.stride()) # (4,) ← 每次跳 4 个元素
print("Storage offset:", y.storage_offset()) # 1 ← 从 storage[1] 开始
# 验证:共享 Storage(没有复制数据)
print(x.storage().data_ptr() == y.storage().data_ptr()) # True
# 验证:修改 y 会影响 x
y[0] = 999
print(x)
# tensor([[ 0, 999, 2, 3], ← x[0,1] 变成了 999
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
实现原理:
三要素的变化:
- Offset 改为 1(从 storage[1] 开始,不是 storage[0])
- Stride 改为 (4,)(每次跳 4 个元素)
- Shape 改为 (3,)(只有 3 个元素)
访问公式:
text
y[i] = storage[offset + i×stride[0]]
= storage[1 + i×4]
y[0] = storage[1 + 0×4] = storage[1] = 1
y[1] = storage[1 + 1×4] = storage[5] = 5
y[2] = storage[1 + 2×4] = storage[9] = 9
可视化:
Storage: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
↑ ↑ ↑
offset +4跳 +4跳
y[0]=1 y[1]=5 y[2]=9
3.4 案例 3:Unsqueeze(增维)的零拷贝实现
python
x = torch.tensor([1, 2, 3]) # shape: (3,)
y = x.unsqueeze(0) # shape: (1, 3)
print("Shape:", y.shape) # (1, 3)
print("Stride:", y.stride()) # (3, 1)
实现原理:
- 在第 0 维插入大小为 1 的维度
- 新维度的 stride 等于原来所有元素的跨度
text
x: shape=(3,), stride=(1,)
[1, 2, 3]
y: shape=(1, 3), stride=(3, 1)
[[1, 2, 3]]
y[0, 0] = storage[0×3 + 0×1] = storage[0] = 1
y[0, 1] = storage[0×3 + 1×1] = storage[1] = 2
3.5 案例 4:Expand(虚拟扩展)的零拷贝实现
python
x = torch.tensor([[1], [2], [3]])
# x: shape=(3, 1), stride=(1, 1), Storage=[1, 2, 3]
# 含义:3 行 1 列的矩阵
# [[1],
# [2],
# [3]]
y = x.expand(3, 4)
# y: shape=(3, 4), stride=(1, 0) ← 关键:第 1 维 stride=0!
# 含义:扩展成 3 行 4 列,但不复制数据
# Storage 还是 [1, 2, 3],只有 3 个元素!
print(y)
# tensor([[1, 1, 1, 1], ← 第 0 行重复 storage[0]
# [2, 2, 2, 2], ← 第 1 行重复 storage[1]
# [3, 3, 3, 3]]) ← 第 2 行重复 storage[2]
print("Stride:", y.stride()) # (1, 0) ← 第 1 维 stride 是 0!
print("共享 Storage:", x.storage().data_ptr() == y.storage().data_ptr()) # True
# 验证:修改 x 会影响 y 的整行
x[0, 0] = 999
print(y)
# tensor([[999, 999, 999, 999], ← 整行都变了!
# [ 2, 2, 2, 2],
# [ 3, 3, 3, 3]])
神奇之处:Stride = (1, 0) 的魔法
text
访问公式:
y[i, j] = storage[i×1 + j×0] = storage[i]
关键:j×0 = 0,无论 j 是多少!
第 0 行(i=0):
y[0, 0] = storage[0×1 + 0×0] = storage[0] = 1
y[0, 1] = storage[0×1 + 1×0] = storage[0] = 1 ← 和上面一样!
y[0, 2] = storage[0×1 + 2×0] = storage[0] = 1 ← 还是一样!
y[0, 3] = storage[0×1 + 3×0] = storage[0] = 1 ← 都访问 storage[0]
第 1 行(i=1):
y[1, 0] = storage[1×1 + 0×0] = storage[1] = 2
y[1, 1] = storage[1×1 + 1×0] = storage[1] = 2 ← 都访问 storage[1]
y[1, 2] = storage[1×1 + 2×0] = storage[1] = 2
y[1, 3] = storage[1×1 + 3×0] = storage[1] = 2
可视化:
Storage: [1, 2, 3] ← 只有 3 个元素
↓ ↓ ↓
逻辑视图:
j=0 j=1 j=2 j=3
i=0 1 1 1 1 ← 都指向 storage[0]
i=1 2 2 2 2 ← 都指向 storage[1]
i=2 3 3 3 3 ← 都指向 storage[2]
内存节省:
- Storage 只有 3 个元素(24 字节)
- 看起来有 3×4=12 个元素
- 节省了 75% 的内存!
3.6 哪些操作是 View?
python
x = torch.randn(4, 4)
# 以下操作都是 View(零拷贝):
y1 = x.view(16) # 改变形状
y2 = x.reshape(2, 8) # 大多数情况是 View
y3 = x.transpose(0, 1) # 转置
y4 = x.permute(1, 0) # 维度重排
y5 = x[:, 2:] # 切片
y6 = x.unsqueeze(0) # 增加维度
y7 = x.squeeze() # 删除大小为 1 的维度
y8 = x.expand(4, 4, 4) # 虚拟扩展
y9 = x[::2, ::2] # 步进切片
# 验证:检查底层指针
print(x.data_ptr() == y1.data_ptr()) # True
3.7 哪些操作会复制数据?
python
x = torch.randn(4, 4)
# 以下操作会复制数据:
y1 = x.clone() # 显式复制
y2 = x.contiguous() # 如果不连续,则复制
y3 = x + 1 # 任何算术运算
y4 = x[x > 0] # 布尔索引
y5 = x[[0, 2]] # 高级索引(非连续选择)
y6 = torch.cat([x, x]) # 拼接
# 验证:不同的内存地址
print(x.data_ptr() == y1.data_ptr()) # False
小结: View 操作通过巧妙地改变投影规则(Stride、Offset),实现了零拷贝。这是 PyTorch 高性能的秘密。
但这引出了新问题:既然 transpose 只改变 Stride,为什么有时候需要调用 contiguous()?什么是"连续"?
这就是 Part 4 要解答的...
🔧 Part 4: 代价与限制 - 理解 Contiguous
4.1 什么是"连续"(Contiguous)?
一个 Tensor 是 contiguous 的,当且仅当它的元素在 Storage 中是按照**标准顺序(C-order,行优先)**排列的。
标准顺序的定义:
python
# 对于 shape = (d₀, d₁, d₂, ..., dₙ)
# 标准 stride(C-order,行优先):
stride[n] = 1
stride[n-1] = d[n]
stride[n-2] = d[n-1] × d[n]
...
stride[0] = d[1] × d[2] × ... × d[n]
示例:
python
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(x.shape) # (2, 3)
print(x.stride()) # (3, 1)
# 检查是否连续:
expected_stride = (3, 1) # 3 = 第 1 维的大小, 1 = 最后一维
print(x.is_contiguous()) # True
直观理解:
text
Contiguous 意味着:
- 逻辑顺序:[[1,2,3], [4,5,6]] → 扁平化 → [1,2,3,4,5,6]
- Storage 顺序:[1, 2, 3, 4, 5, 6]
- 完美匹配!✓
4.2 为什么 Transpose 后不连续?
python
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
y = x.transpose(0, 1)
print(y.is_contiguous()) # False
# 检查 y 的逻辑顺序:
print(y)
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
# 逻辑顺序:[[1,4], [2,5], [3,6]] → 扁平化 → [1,4,2,5,3,6]
# Storage 顺序:[1, 2, 3, 4, 5, 6]
# 不匹配!✗
可视化:
text
Storage: [1, 2, 3, 4, 5, 6]
↓ ↓ ↓ ↓
逻辑读取: 1 → 4 2 → 5 ...
↑ ↑ ↑
跳跃访问!不连续!
4.3 为什么某些操作要求连续?
案例:view() 要求输入连续
python
x = torch.randn(3, 4)
y = x.transpose(0, 1) # y 不连续
# 这会报错!
try:
z = y.view(12)
except RuntimeError as e:
print(e)
# RuntimeError: view size is not compatible with input tensor's
# size and stride. Use .reshape(...) or .contiguous().view(...)
为什么 view 要求连续?
view 只改变元数据,不移动数据。如果 Tensor 不连续,就无法通过简单的 stride 计算来重新解释 Storage。
例子:
text
y 的 Storage: [1, 2, 3, 4, 5, 6]
y 的逻辑顺序: [1, 4, 2, 5, 3, 6]
如果要 view 成 (12,):
需要访问顺序: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
但 Storage 只有 6 个元素!
无法用简单的 stride 公式实现 → 必须复制数据
4.4 contiguous() 做了什么?
python
y = x.transpose(0, 1) # 不连续
z = y.contiguous() # 创建新的连续 Storage
print("y 的 Storage 指针:", y.data_ptr())
print("z 的 Storage 指针:", z.data_ptr())
# 不同!z 是新的 Tensor
print("z 是否连续:", z.is_contiguous()) # True
contiguous() 的工作流程:
- 检查 Tensor 是否已经连续
- 如果是,直接返回自己(不复制)
- 如果不是,继续下一步
- 按照逻辑顺序创建新的 Storage
- 复制数据到新 Storage
- 返回新的 Tensor(连续的)
python
# 源码简化版:
def contiguous(tensor):
if tensor.is_contiguous():
return tensor # 已经连续,不复制
# 创建新 Storage,按逻辑顺序复制数据
new_storage = allocate_storage(tensor.numel())
for idx in ndindex(tensor.shape): # 遍历所有索引
new_storage[flat_index(idx)] = tensor[idx]
return Tensor(new_storage, tensor.shape, standard_stride(tensor.shape))
4.5 解决方案对比
方案 1:先 contiguous,再 view
python
y = x.transpose(0, 1) # 不连续
z = y.contiguous().view(12) # OK
方案 2:使用 reshape(推荐)
python
y = x.transpose(0, 1) # 不连续
z = y.reshape(12) # OK,自动处理
reshape vs view:
| 特性 | view |
reshape |
|---|---|---|
| 要求连续 | ✅ 是 | ❌ 否 |
| 总是返回 View | ✅ 是 | ❌ 不一定 |
| 性能 | 最快 | 智能选择 |
| 推荐 | 性能关键路径 | 日常使用 |
4.6 性能影响
python
import time
x = torch.randn(1000, 1000)
y = x.t() # 不连续
# 测试访问速度
start = time.perf_counter()
_ = x.sum()
time_cont = time.perf_counter() - start
start = time.perf_counter()
_ = y.sum()
time_non_cont = time.perf_counter() - start
print(f"连续 Tensor: {time_cont * 1000:.2f} ms")
print(f"不连续 Tensor: {time_non_cont * 1000:.2f} ms")
print(f"慢了: {time_non_cont / time_cont:.2f}x")
# 不连续的 Tensor 通常慢 1.5-3x(缓存不友好)
为什么慢?
- 连续:内存访问是顺序的,CPU 缓存命中率高
- 不连续:内存访问是跳跃的,缓存频繁失效
小结: View 操作的代价是可能产生不连续的 Tensor,某些操作需要先调用 contiguous() 复制数据。
现在你理解了 Tensor 的底层机制。接下来,让我们系统学习 PyTorch 提供的各种维度操作工具。
📐 Part 5: 掌握工具箱 - 各种维度操作
理解了 Stride 的原理后,所有维度操作都只是在改变投影规则。
5.1 Reshape 家族:改变形状
view() - 严格的零拷贝
python
x = torch.arange(12)
# view 要求 Tensor 连续
y = x.view(3, 4) # OK
y = x.view(2, 6) # OK
y = x.view(-1) # OK,-1 表示自动推断
# 元素总数必须匹配
try:
y = x.view(3, 5) # 错误!3×5=15 ≠ 12
except RuntimeError as e:
print(e)
reshape() - 智能的形状变换
python
x = torch.randn(3, 4)
y = x.transpose(0, 1) # y 不连续
# reshape 会自动处理:
# - 如果可以,返回 View(零拷贝)
# - 如果不行,自动调用 contiguous()
z = y.reshape(12) # OK,自动处理了不连续的情况
flatten() - 扁平化
python
x = torch.randn(2, 3, 4)
# 从某个维度开始展平
y = x.flatten(start_dim=1) # shape: (2, 12)
# 等价于:x.reshape(2, -1)
5.2 维度添加与删除
unsqueeze() - 添加大小为 1 的维度
python
x = torch.randn(3, 4) # shape: (3, 4)
y = x.unsqueeze(0) # shape: (1, 3, 4)
z = x.unsqueeze(1) # shape: (3, 1, 4)
w = x.unsqueeze(-1) # shape: (3, 4, 1)
工程用途:
python
# 场景 1:匹配批次维度
image = torch.randn(3, 224, 224) # 单张图片 (C, H, W)
batch = image.unsqueeze(0) # (1, C, H, W)
# 场景 2:为广播做准备
a = torch.randn(5) # (5,)
b = torch.randn(3) # (3,)
a = a.unsqueeze(1) # (5, 1)
b = b.unsqueeze(0) # (1, 3)
c = a * b # (5, 3) - 广播乘法
squeeze() - 删除大小为 1 的维度
python
x = torch.randn(1, 3, 1, 4)
y = x.squeeze() # (3, 4) - 删除所有大小为 1 的维度
z = x.squeeze(0) # (3, 1, 4) - 只删除第 0 维
w = x.squeeze(2) # (1, 3, 4) - 只删除第 2 维
5.3 维度重排
transpose() - 交换两个维度
python
x = torch.randn(2, 3, 4)
y = x.transpose(0, 1) # shape: (3, 2, 4)
z = x.transpose(1, 2) # shape: (2, 4, 3)
原理: 交换对应的 shape 和 stride
python
x: shape=(2, 3, 4), stride=(12, 4, 1)
y: shape=(3, 2, 4), stride=(4, 12, 1) # 交换了前两个
permute() - 任意维度重排
python
x = torch.randn(2, 3, 4, 5)
# 重新排列所有维度
y = x.permute(3, 1, 0, 2) # shape: (5, 3, 2, 4)
# 常见用法:(B, C, H, W) → (B, H, W, C)
images = torch.randn(32, 3, 224, 224) # PyTorch 格式
images_hwc = images.permute(0, 2, 3, 1) # (32, 224, 224, 3)
5.4 Broadcasting - 自动维度扩展
Broadcasting 是 PyTorch 最强大的特性之一。
Broadcasting 规则
从右向左对齐维度,然后:
- 如果两个维度大小相同 → OK
- 如果其中一个是 1 → 可以 broadcast
- 如果其中一个不存在 → 自动添加大小为 1 的维度,然后 broadcast
示例 1:矩阵 + 向量
python
matrix = torch.tensor([[1, 2, 3],
[4, 5, 6]]) # (2, 3)
vector = torch.tensor([10, 20, 30]) # (3,)
result = matrix + vector # (2, 3)
# 对齐过程:
# matrix: (2, 3)
# vector: (3,) → 自动变为 (1, 3) → broadcast 到 (2, 3)
示例 2:复杂的 Broadcasting
python
a = torch.randn(5, 1, 3) # (5, 1, 3)
b = torch.randn(1, 4, 3) # (1, 4, 3)
c = a + b # shape: (5, 4, 3)
# 对齐过程:
# a: (5, 1, 3)
# b: (1, 4, 3)
# ↓ ↓ ↓
# c: (5, 4, 3)
expand() - 显式 Broadcasting
python
x = torch.tensor([[1], [2], [3]]) # shape: (3, 1)
# 虚拟扩展(不复制数据)
y = x.expand(3, 4) # shape: (3, 4)
print(y)
# tensor([[1, 1, 1, 1],
# [2, 2, 2, 2],
# [3, 3, 3, 3]])
# 验证:Stride 的第 1 维是 0
print(y.stride()) # (1, 0)
5.5 操作总结表
| 操作 | 改变的元数据 | 是否 View | 说明 |
|---|---|---|---|
view(shape) |
shape, stride | ✅ | 要求连续 |
reshape(shape) |
shape, stride | 通常 | 自动处理不连续 |
transpose(a, b) |
shape, stride | ✅ | 交换两维 |
permute(dims) |
shape, stride | ✅ | 重排所有维度 |
x[a:b] |
offset, shape, stride | ✅ | 切片 |
x[::step] |
stride | ✅ | 步进 |
unsqueeze(dim) |
shape, stride | ✅ | 插入维度 |
squeeze() |
shape, stride | ✅ | 删除大小为 1 的维度 |
expand(...) |
shape | ✅ | 虚拟扩展(stride 设为 0) |
flatten() |
shape, stride | 通常 | 展平 |
小结: 掌握这些维度操作,你就可以灵活地重塑 Tensor,为计算做准备。
现在你已经理解了 Tensor 的所有核心机制。接下来,让我们看看如何将 Tensor 组织成神经网络模型。
🧱 Part 6: 从 Tensor 到模型 - nn.Module 的设计
理解了 Tensor 之后,我们需要一个容器来组织它们,构建神经网络。
6.1 nn.Module 的基本结构
python
import torch
import torch.nn as nn
class SimpleModule(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__() # 必须调用!
# 可训练参数
self.weight = nn.Parameter(torch.randn(input_dim, output_dim))
self.bias = nn.Parameter(torch.zeros(output_dim))
# 不可训练的状态(如 BatchNorm 的 running_mean)
self.register_buffer('running_sum', torch.zeros(output_dim))
def forward(self, x):
output = torch.matmul(x, self.weight) + self.bias
self.running_sum += output.sum(dim=0)
return output
6.2 Parameter vs Buffer vs Tensor
| 类型 | 用途 | 是否可训练 | 是否保存 | 是否移动到 device |
|---|---|---|---|---|
| Parameter | 模型参数 | ✅ | ✅ | ✅ |
| Buffer | 模型状态 | ❌ | ✅ | ✅ |
| 普通 Tensor | 临时变量 | ❌ | ❌ | ❌ |
实例:
python
class Model(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(10, 5))
self.register_buffer('mean', torch.zeros(5))
self.scale = torch.tensor(2.0)
model = Model()
# 移动到 GPU
model = model.cuda()
print(model.weight.device) # cuda:0
print(model.mean.device) # cuda:0
print(model.scale.device) # cpu ← 普通 Tensor 不会移动!
6.3 自动注册机制
python
class MyModule(nn.Module):
def __init__(self):
super().__init__()
# 自动注册为子模块
self.layer1 = nn.Linear(10, 20)
self.layer2 = nn.Linear(20, 5)
model = MyModule()
# 自动递归收集所有参数
print(len(list(model.parameters()))) # 4 个(2 个 weight + 2 个 bias)
6.4 ModuleList vs Python List
python
# ❌ 错误示例
class BadModule(nn.Module):
def __init__(self):
super().__init__()
self.layers = [nn.Linear(10, 10) for _ in range(3)]
model = BadModule()
print(len(list(model.parameters()))) # 0 ← 参数丢失!
# ✅ 正确示例
class GoodModule(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(10, 10) for _ in range(3)
])
model = GoodModule()
print(len(list(model.parameters()))) # 6
小结: nn.Module 提供了一个优雅的容器,自动管理参数、状态和子模块。
现在,让我们通过实战练习巩固所有知识。
🎯 Part 7: 实战练习
实战 1:手写 MLP(多层感知机)
目标: 不使用 nn.Linear,只用 torch.matmul 实现两层 MLP。
python
import torch
import torch.nn as nn
class ManualMLP(nn.Module):
"""手动实现的多层感知机"""
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
# 第一层权重和偏置
self.w1 = nn.Parameter(
torch.randn(input_dim, hidden_dim) * (2.0 / input_dim) ** 0.5
)
self.b1 = nn.Parameter(torch.zeros(hidden_dim))
# 第二层权重和偏置
self.w2 = nn.Parameter(
torch.randn(hidden_dim, output_dim) * (2.0 / hidden_dim) ** 0.5
)
self.b2 = nn.Parameter(torch.zeros(output_dim))
def forward(self, x):
"""
Args:
x: shape (batch_size, input_dim)
Returns:
output: shape (batch_size, output_dim)
"""
# 第一层:x @ W1 + b1
hidden = torch.matmul(x, self.w1) + self.b1
# 激活函数:ReLU
hidden = torch.clamp(hidden, min=0)
# 第二层:hidden @ W2 + b2
output = torch.matmul(hidden, self.w2) + self.b2
return output
# 测试
model = ManualMLP(784, 128, 10)
x = torch.randn(32, 784) # 模拟 MNIST
output = model(x)
print(output.shape) # torch.Size([32, 10])
实战 2:手写 Multi-Head Attention
目标: 理解 Transformer 的核心,掌握复杂的维度变换。
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class ManualMultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
assert embed_dim % num_heads == 0
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Q、K、V 的投影矩阵
self.w_q = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)
self.w_k = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)
self.w_v = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)
self.w_o = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)
def forward(self, x):
"""
Args:
x: (batch, seq_len, embed_dim)
Returns:
output: (batch, seq_len, embed_dim)
"""
batch_size, seq_len, embed_dim = x.shape
# 步骤 1:线性投影
Q = torch.matmul(x, self.w_q) # (B, L, E)
K = torch.matmul(x, self.w_k)
V = torch.matmul(x, self.w_v)
# 步骤 2:切分成多头
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
Q = Q.transpose(1, 2) # (B, H, L, D)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
K = K.transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
V = V.transpose(1, 2)
# 步骤 3:计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) # (B, H, L, L)
scores = scores / (self.head_dim ** 0.5)
# 步骤 4:Softmax
attn_weights = F.softmax(scores, dim=-1)
# 步骤 5:加权求和
attn_output = torch.matmul(attn_weights, V) # (B, H, L, D)
# 步骤 6:合并多头
attn_output = attn_output.transpose(1, 2) # (B, L, H, D)
attn_output = attn_output.contiguous()
attn_output = attn_output.view(batch_size, seq_len, embed_dim)
# 步骤 7:输出投影
output = torch.matmul(attn_output, self.w_o)
return output
# 测试
model = ManualMultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)
output = model(x)
print(output.shape) # torch.Size([2, 10, 512])
关键维度变换可视化:
yaml
输入: (B, L, E)
↓ Linear projection
Q, K, V: (B, L, E)
↓ view + transpose
Q, K, V: (B, H, L, D) 其中 H=num_heads, D=E/H
↓ Attention 计算
Output: (B, H, L, D)
↓ transpose + contiguous + view
Output: (B, L, E)
实战 3:手写 LayerNorm
python
def manual_layer_norm(x, eps=1e-5):
"""
Args:
x: shape (batch, ..., features)
Returns:
normalized: shape (batch, ..., features)
"""
# 在最后一维计算统计量
mean = x.mean(dim=-1, keepdim=True) # (B, ..., 1)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# 归一化
x_normalized = (x - mean) / torch.sqrt(var + eps)
return x_normalized
# 测试
x = torch.randn(32, 10, 512)
output = manual_layer_norm(x)
# 验证
print(output[0, 0].mean()) # 接近 0
print(output[0, 0].std()) # 接近 1
🎯 本章总结
核心要点回顾
-
Tensor = Storage(物理)+ (Shape, Stride, Offset)(投影)
- Storage 是一维数组,永远不变
- 投影规则定义如何解释 Storage
- 访问公式:
T[i,j,k] = S[offset + i×stride[0] + j×stride[1] + k×stride[2]]
-
分离式设计的动机
- 零拷贝操作(快 10000 倍)
- 内存共享
- 延迟物化
-
View 操作的原理
- 只改变投影规则,不移动数据
- transpose 交换 stride
- slice 改变 offset
- expand 将 stride 设为 0
-
Contiguous 的必要性
- 某些操作(如 view)要求连续
- contiguous() 复制数据以保证连续
- 不连续的 Tensor 访问慢 1.5-3x
-
维度操作的统一理解
- 所有操作都是在改变投影规则
- 理解 Stride 就理解了一切
-
nn.Module 的设计
- Parameter:可训练参数
- Buffer:模型状态
- 自动递归管理子模块
工程最佳实践
- ✅ 优先使用
reshape()而非view() - ✅ 利用 Broadcasting 代替循环
- ✅ 批量处理数据
- ✅ 使用
ModuleList管理子模块 - ⚠️ 注意
transpose后可能需要contiguous() - ⚠️ 避免在需要梯度的 Tensor 上 in-place 操作
下一步
在下一篇文章中,我们将深入自动微分引擎(Autograd),理解:
- 计算图是如何动态构建的
backward()的实现原理- 如何实现自定义的 autograd.Function
🎯 检查清单
在进入下一阶段前,请确保你能回答:
- Storage、Shape、Stride 的关系是什么?
- 为什么要分离 Storage 和 Shape?
- transpose 如何通过改变 Stride 实现零拷贝?
- 什么是 Contiguous?为什么需要它?
- view 和 reshape 的区别?
- Broadcasting 的规则是什么?
- Parameter 和 Buffer 的区别?
如果你能清晰回答这些问题,恭喜你,你已经掌握了 PyTorch 的核心机制!
下一篇预告: 《阶段二:后向传播 - Autograd 自动微分引擎揭秘》
本文是 PyTorch 架构级学习系列的第 2 篇。