PyTorch学习阶段一:前向传播 - Tensor 的内存模型与高性能算子

阶段一:前向传播 - Tensor 的内存模型与高性能算子

PyTorch 架构级学习系列 - 第 2 篇

本文将深入探讨 PyTorch 的核心:Tensor 的物理存储模型。你将理解为什么某些操作"零开销",而另一些需要复制内存;为什么 transpose 后经常需要 contiguous();以及如何利用这些知识写出高性能的代码。


📚 目录

  1. Tensor 的三要素:Storage、Shape、Stride
  2. 设计动机:为什么要分离 Storage 和 Shape?
  3. 零拷贝的魔法:View 操作的实现原理
  4. 代价与限制:理解 Contiguous
  5. 掌握工具箱:各种维度操作
  6. 从 Tensor 到模型:nn.Module 的设计
  7. 实战练习:手写 MLP 和多头注意力

🔍 Part 1: Tensor 的三要素

💡 核心理念

Tensor = Storage(物理存储)+ (Shape, Stride, Offset)(投影规则)

理解这个分离式设计,你就理解了 PyTorch 性能优化的核心。

1.1 从一个简单例子开始

python 复制代码
import torch

# 创建一个 2x3 的张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(x.shape)      # torch.Size([2, 3])
print(x.stride())   # (3, 1)
print(x.storage())  # [1, 2, 3, 4, 5, 6](底层是一维数组)

惊人的发现: 虽然我们看到的是 2D 矩阵,但底层存储的是 1D 数组!

这引出了一个关键问题:Tensor 如何将多维的逻辑视图映射到一维的物理存储?

1.2 三要素详解

要素 1:Storage - 物理存储

Storage 是 PyTorch 中真正存储数据的地方,它是一个一维的连续内存块

python 复制代码
x = torch.tensor([1, 2, 3, 4, 5, 6])
print(x.storage())
# 输出:
#  1
#  2
#  3
#  4
#  5
#  6
# [torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]

关键特性:

  • Storage 总是一维的
  • 多个 Tensor 可以共享同一个 Storage
  • Storage 存储的是原始数据(字节)
要素 2:Shape - 逻辑形状

Shape(形状) 定义了 Tensor 的维度和每个维度的大小。

python 复制代码
x = torch.arange(12)  # shape: (12,)

# 同一个 Storage,不同的 Shape
y = x.view(3, 4)      # shape: (3, 4) - 3行4列
z = x.view(2, 2, 3)   # shape: (2, 2, 3) - 2个2x3矩阵

重要: Shape 只是"元数据",改变 Shape 不需要移动数据。

要素 3:Stride - 映射关键

Stride(步长) 是理解 Tensor 内存布局的核心概念。

定义: Stride 告诉你,在某个维度上前进一步,需要在 Storage 中跳过多少个元素。

实例解析:二维 Tensor
python 复制代码
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(x.shape)      # (2, 3) - 2行3列
print(x.stride())   # (3, 1)

Stride (3, 1) 的含义:

  • 第 0 维(行)的 stride = 3:从一行跳到下一行,需要跳过 3 个元素
  • 第 1 维(列)的 stride = 1:从一列跳到下一列,需要跳过 1 个元素

内存布局可视化:

ini 复制代码
Storage: [1, 2, 3, 4, 5, 6]
          ↑        ↑
         [0,0]   [1,0]
         stride=0×3+0×1=0
                stride=1×3+0×1=3

访问 x[i, j] 的内存地址 = base_offset + i × stride[0] + j × stride[1]
                        = 0 + i × 3 + j × 1

x[0, 0] → storage[0] = 1
x[0, 1] → storage[1] = 2
x[1, 0] → storage[3] = 4
x[1, 2] → storage[5] = 6

1.3 投影公式:多维到一维的映射

Tensor 是通过 (Shape, Stride, Offset) 三元组定义的、对一维 Storage 的多维投影。

形式化定义:

text 复制代码
Storage S: 一维数组 [s₀, s₁, s₂, ..., sₙ₋₁]

Tensor T: 通过投影函数 f 定义的多维视图
  - Shape: (d₀, d₁, d₂, ..., dₖ)
  - Stride: (σ₀, σ₁, σ₂, ..., σₖ)
  - Offset: o

投影函数 f: 多维索引 → 一维索引
  f(i₀, i₁, i₂, ..., iₖ) = o + i₀×σ₀ + i₁×σ₁ + i₂×σ₂ + ... + iₖ×σₖ

因此:
  T[i₀, i₁, ..., iₖ] = S[f(i₀, i₁, ..., iₖ)]

三维示例:

python 复制代码
x = torch.arange(24).view(2, 3, 4)
print(x.stride())  # (12, 4, 1)

# 访问 x[1, 2, 3]:
# storage_index = 1×12 + 2×4 + 3×1 = 12 + 8 + 3 = 23
print(x[1, 2, 3])  # tensor(23)

1.4 多个 Tensor 可以共享同一个 Storage

python 复制代码
x = torch.arange(12)
print(f"x 的 Storage 指针: {x.storage().data_ptr()}")

# 创建多个不同的投影
y = x.view(3, 4)
z = x.view(2, 6)
w = x.view(2, 3, 2)

# 验证:它们共享同一个 Storage
print(f"y 的 Storage 指针: {y.storage().data_ptr()}")
print(f"z 的 Storage 指针: {z.storage().data_ptr()}")
print(f"w 的 Storage 指针: {w.storage().data_ptr()}")
# 所有指针相同!

# 修改任何一个会影响所有
x[0] = 999
print(f"x[0] = {x[0]}")  # 999
print(f"y[0,0] = {y[0, 0]}")  # 999
print(f"z[0,0] = {z[0, 0]}")  # 999

可视化:同一个 Storage 的多种投影

text 复制代码
                Storage [0, 1, 2, 3, 4, 5, ..., 11]
                (一维数组,永远不变)
                        │
        ┌───────────────┼───────────────┐
        │               │               │
        ↓               ↓               ↓
    投影 1           投影 2           投影 3
shape=(3,4)      shape=(2,6)      shape=(2,3,2)
stride=(4,1)     stride=(6,1)     stride=(6,2,1)

小结: 现在你理解了 Tensor 的物理本质:一个一维 Storage + 一套投影规则。

但你可能会问:为什么要这样设计?为什么不直接把数据和形状存在一起?

这就引出了 Part 2...


🎯 Part 2: 设计动机 - 为什么要分离 Storage 和 Shape?

2.1 传统设计的问题

假设我们用传统的方式设计 Tensor:

python 复制代码
# 传统设计:数据和形状紧密绑定
class NaiveTensor:
    def __init__(self, data, shape):
        self.data = data        # 2D 数组 [[1,2,3], [4,5,6]]
        self.shape = shape      # (2, 3)

    def transpose(self):
        # 必须创建新数组!
        new_data = [[self.data[j][i] for j in range(self.shape[0])]
                    for i in range(self.shape[1])]
        return NaiveTensor(new_data, (self.shape[1], self.shape[0]))

问题: 每次转置都要复制数据!

性能测试:

python 复制代码
import time
import torch

# 测试:1000×1000 矩阵转置
x = torch.randn(1000, 1000)

# 方法 1:PyTorch 的 transpose(只改元数据)
start = time.perf_counter()
y = x.transpose(0, 1)
time_pytorch = (time.perf_counter() - start) * 1000

# 方法 2:手动复制(模拟传统设计)
start = time.perf_counter()
y_copy = x.t().contiguous().clone()
time_copy = (time.perf_counter() - start) * 1000

print(f"PyTorch transpose: {time_pytorch:.4f} ms")   # < 0.001 ms
print(f"手动复制:          {time_copy:.4f} ms")      # ~ 10 ms
print(f"加速比:            {time_copy / time_pytorch:.0f}x")
# 快 10000 倍!

2.2 PyTorch 的解决方案:分离式设计

核心思想: 将"物理存储"和"逻辑视图"分离。

text 复制代码
传统设计:
Tensor = 数据 + 形状
  → 改变形状 = 复制数据 ❌

PyTorch 设计:
Tensor = Storage(物理) + (Shape, Stride)(逻辑)
  → 改变形状 = 只改元数据 ✅

优势 1:零拷贝操作

python 复制代码
x = torch.randn(1000, 1000)
# x: shape=(1000, 1000), stride=(1000, 1)
# 含义:1000×1000 的矩阵,Storage 有 100万个元素

# 所有这些操作都是瞬间完成,不复制数据:

y = x.transpose(0, 1)
# y: shape=(1000, 1000), stride=(1, 1000)  ← 只交换了 stride
# 含义:转置矩阵,行列互换
# 访问 y[i,j] = storage[i×1 + j×1000] = x[j,i]
# Storage 完全没动!

z = x[::2, ::2]
# z: shape=(500, 500), stride=(2000, 2)  ← stride 变成 2 倍
# 含义:每隔一行、每隔一列采样(下采样)
# 访问 z[i,j] = storage[i×2000 + j×2] = x[2i, 2j]
# Storage 完全没动!

w = x.unsqueeze(0)
# w: shape=(1, 1000, 1000), stride=(1000000, 1000, 1)
# 含义:在第 0 维插入大小为 1 的维度(增加批次维度)
# 访问 w[0,i,j] = storage[0×1000000 + i×1000 + j×1] = x[i,j]
# Storage 完全没动!

v = x.expand(10, 1000, 1000)
# v: shape=(10, 1000, 1000), stride=(0, 1000, 1)  ← 注意第 0 维 stride=0!
# 含义:虚拟复制 10 份,看起来有 10 个矩阵,但实际内存只有 1 份
# 访问 v[i,j,k] = storage[i×0 + j×1000 + k×1] = x[j,k]
# 无论 i 是多少,都访问同一个位置!
# Storage 完全没动!

优势 2:内存共享

python 复制代码
x = torch.randn(1000, 1000)
# x 占用约 8MB 内存(1000×1000×8字节)

y = x.transpose(0, 1)
# y 不占用额外内存!仍然是 8MB

# 验证:它们共享同一块内存
print(x.storage().data_ptr() == y.storage().data_ptr())  # True
# data_ptr() 返回 Storage 的内存地址,地址相同说明是同一块内存

# 因为共享内存,修改 y 会影响 x
y[0, 0] = 999
print(x[0, 0])  # 999
# 为什么?因为 y[0,0] 和 x[0,0] 访问的是 storage[0],同一个位置!

# 同理,修改 x 也会影响 y
x[0, 1] = 888
print(y[1, 0])  # 888
# 为什么?因为 x[0,1]=storage[1],y[1,0]=storage[1],都是同一个位置!

优势 3:延迟物化(Lazy Materialization)

什么是延迟物化? 就像"拖延症":能不干活就不干活,直到万不得已才真正干活。

python 复制代码
x = torch.randn(1000, 1000)
# Storage: 100万个元素
# shape=(1000, 1000), stride=(1000, 1)

# ===== 第 1 步:转置 =====
y = x.transpose(0, 1)
# ✅ 瞬间完成(< 0.001ms)
# 只改了元数据:shape=(1000, 1000), stride=(1, 1000)
# Storage 还是原来的 100万个元素,位置没变!
print(f"y 和 x 共享内存: {y.storage().data_ptr() == x.storage().data_ptr()}")  # True

# ===== 第 2 步:隔列采样 =====
z = y[:, ::2]
# ✅ 瞬间完成(< 0.001ms)
# 只改了元数据:shape=(1000, 500), stride=(1, 2000)
# Storage 还是原来的 100万个元素,位置没变!
print(f"z 和 x 共享内存: {z.storage().data_ptr() == x.storage().data_ptr()}")  # True

# ===== 第 3 步:增加批次维度 =====
w = z.unsqueeze(0)
# ✅ 瞬间完成(< 0.001ms)
# 只改了元数据:shape=(1, 1000, 500), stride=(500000, 1, 2000)
# Storage 还是原来的 100万个元素,位置没变!
print(f"w 和 x 共享内存: {w.storage().data_ptr() == x.storage().data_ptr()}")  # True

# 此时:做了 3 个操作,但一次数据复制都没有发生!
# w 现在是一个"不连续"的 Tensor,访问它会跳来跳去

# ===== 第 4 步:真正需要连续内存时 =====
w_cont = w.contiguous()
# ❌ 这里才真正复制数据(~ 2ms)
# 创建新的 Storage,按照 w 的逻辑顺序重新排列数据
# w_cont: shape=(1, 1000, 500), stride=(500000, 500, 1)  ← 标准 stride
print(f"w_cont 和 x 共享内存: {w_cont.storage().data_ptr() == x.storage().data_ptr()}")  # False

# 为什么需要 contiguous?
# 因为某些操作(如 view)要求连续的内存布局
# 现在可以安全地 reshape 了:
result = w_cont.view(1, -1)  # 成功!shape=(1, 500000)

类比理解:

想象你有一本 1000 页的书(Storage),你想:

  1. 倒序阅读(transpose)→ 不需要重新打印书,只需要从后往前翻页
  2. 只读偶数页(切片)→ 不需要重新打印书,只需要跳着读
  3. 把书装进袋子(unsqueeze)→ 不需要重新打印书,只是加了个容器

这就是"延迟":你一直在操作原书,没有复制任何一页!

但如果你要: 4. 把选中的页装订成新书 (contiguous)→ 现在才需要复印机!

性能对比:

python 复制代码
import time

x = torch.randn(1000, 1000)

# 测试延迟物化的性能优势
start = time.perf_counter()
y = x.transpose(0, 1)
z = y[:, ::2]
w = z.unsqueeze(0)
time_lazy = (time.perf_counter() - start) * 1000

start = time.perf_counter()
w_cont = w.contiguous()
time_materialize = (time.perf_counter() - start) * 1000

print(f"3 个 View 操作:        {time_lazy:.4f} ms")      # < 0.01 ms
print(f"最终物化(复制数据):   {time_materialize:.4f} ms")  # ~ 2 ms
print(f"物化慢了:              {time_materialize / time_lazy:.0f}x")

关键洞察:

  • View 操作链可以"积累"很多变换,但都不复制数据
  • 只在最后真正需要连续内存时才一次性复制
  • 这样可以避免中间的多次复制,节省时间和内存

2.3 对比总结

操作 传统设计 PyTorch 设计 加速比
transpose() 复制数据(10ms) 只改 stride(< 0.001ms) 10000x
[::2] 切片 复制数据 只改 stride 和 offset 1000x
unsqueeze() 复制并扩展 只改 shape 无限
expand() 复制多份 stride 设为 0 无限

小结: 分离式设计让 PyTorch 可以实现零拷贝的视图操作,性能提升数千倍。

但这引出了新问题:如何通过只改变 Stride 来实现 transpose?Stride 到底是如何工作的?

这就是 Part 3 要解答的...


🪄 Part 3: 零拷贝的魔法 - View 操作的实现原理

3.1 什么是 View?

View 是对同一个 Storage 的不同"解释方式"。

核心原则:

  • ✅ View 操作只改变元数据(Shape、Stride、Offset)
  • ✅ 不移动或复制数据
  • ✅ 多个 View 共享同一个 Storage

3.2 案例 1:Transpose 的零拷贝实现

python 复制代码
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print("原始:")
print("Shape:", x.shape)        # (2, 3)
print("Stride:", x.stride())    # (3, 1)
print("Storage:", list(x.storage()))  # [1, 2, 3, 4, 5, 6]

# 转置
y = x.transpose(0, 1)

print("\n转置后:")
print("Shape:", y.shape)        # (3, 2)
print("Stride:", y.stride())    # (1, 3) ← 注意 stride 变了!
print("Storage:", list(y.storage()))  # [1, 2, 3, 4, 5, 6] ← Storage 没变!

关键发现: transpose 并没有移动数据,只是交换了 Stride

可视化理解:

ini 复制代码
原始 x (shape=(2,3), stride=(3,1)):
Storage: [1, 2, 3, 4, 5, 6]
投影规则: x[i,j] = storage[i×3 + j×1]

x[0,0] = storage[0×3 + 0×1] = storage[0] = 1
x[0,1] = storage[0×3 + 1×1] = storage[1] = 2
x[1,0] = storage[1×3 + 0×1] = storage[3] = 4

逻辑视图:
[[1, 2, 3],
 [4, 5, 6]]

─────────────────────────────────────

转置后 y (shape=(3,2), stride=(1,3)):
Storage: [1, 2, 3, 4, 5, 6]  ← 完全相同的存储!
投影规则: y[i,j] = storage[i×1 + j×3]

y[0,0] = storage[0×1 + 0×3] = storage[0] = 1
y[0,1] = storage[0×1 + 1×3] = storage[3] = 4
y[1,0] = storage[1×1 + 0×3] = storage[1] = 2

逻辑视图:
[[1, 4],
 [2, 5],
 [3, 6]]

实现原理: 通过改变投影规则,同一个 Storage 被"重新解释"了!

3.3 案例 2:Slice(切片)的零拷贝实现

python 复制代码
x = torch.arange(12).view(3, 4)
# x: shape=(3, 4), stride=(4, 1), offset=0
# Storage: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# 逻辑视图:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 切片:取第 1 列(第 2 列,索引从 0 开始)
y = x[:, 1]
# 含义:取所有行的第 1 列
# 结果:[x[0,1], x[1,1], x[2,1]] = [1, 5, 9]

print(y)  # tensor([1, 5, 9])
print("Shape:", y.shape)   # (3,)
print("Stride:", y.stride())  # (4,) ← 每次跳 4 个元素
print("Storage offset:", y.storage_offset())  # 1 ← 从 storage[1] 开始

# 验证:共享 Storage(没有复制数据)
print(x.storage().data_ptr() == y.storage().data_ptr())  # True

# 验证:修改 y 会影响 x
y[0] = 999
print(x)
# tensor([[  0, 999,   2,   3],  ← x[0,1] 变成了 999
#         [  4,   5,   6,   7],
#         [  8,   9,  10,  11]])

实现原理:

三要素的变化:

  • Offset 改为 1(从 storage[1] 开始,不是 storage[0])
  • Stride 改为 (4,)(每次跳 4 个元素)
  • Shape 改为 (3,)(只有 3 个元素)

访问公式:

text 复制代码
y[i] = storage[offset + i×stride[0]]
     = storage[1 + i×4]

y[0] = storage[1 + 0×4] = storage[1] = 1
y[1] = storage[1 + 1×4] = storage[5] = 5
y[2] = storage[1 + 2×4] = storage[9] = 9

可视化:
Storage: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
            ↑        ↑        ↑
         offset   +4跳      +4跳
         y[0]=1   y[1]=5   y[2]=9

3.4 案例 3:Unsqueeze(增维)的零拷贝实现

python 复制代码
x = torch.tensor([1, 2, 3])  # shape: (3,)

y = x.unsqueeze(0)  # shape: (1, 3)

print("Shape:", y.shape)      # (1, 3)
print("Stride:", y.stride())  # (3, 1)

实现原理:

  • 在第 0 维插入大小为 1 的维度
  • 新维度的 stride 等于原来所有元素的跨度
text 复制代码
x: shape=(3,), stride=(1,)
   [1, 2, 3]

y: shape=(1, 3), stride=(3, 1)
   [[1, 2, 3]]

y[0, 0] = storage[0×3 + 0×1] = storage[0] = 1
y[0, 1] = storage[0×3 + 1×1] = storage[1] = 2

3.5 案例 4:Expand(虚拟扩展)的零拷贝实现

python 复制代码
x = torch.tensor([[1], [2], [3]])
# x: shape=(3, 1), stride=(1, 1), Storage=[1, 2, 3]
# 含义:3 行 1 列的矩阵
# [[1],
#  [2],
#  [3]]

y = x.expand(3, 4)
# y: shape=(3, 4), stride=(1, 0)  ← 关键:第 1 维 stride=0!
# 含义:扩展成 3 行 4 列,但不复制数据
# Storage 还是 [1, 2, 3],只有 3 个元素!

print(y)
# tensor([[1, 1, 1, 1],  ← 第 0 行重复 storage[0]
#         [2, 2, 2, 2],  ← 第 1 行重复 storage[1]
#         [3, 3, 3, 3]]) ← 第 2 行重复 storage[2]

print("Stride:", y.stride())  # (1, 0) ← 第 1 维 stride 是 0!
print("共享 Storage:", x.storage().data_ptr() == y.storage().data_ptr())  # True

# 验证:修改 x 会影响 y 的整行
x[0, 0] = 999
print(y)
# tensor([[999, 999, 999, 999],  ← 整行都变了!
#         [  2,   2,   2,   2],
#         [  3,   3,   3,   3]])

神奇之处:Stride = (1, 0) 的魔法

text 复制代码
访问公式:
y[i, j] = storage[i×1 + j×0] = storage[i]

关键:j×0 = 0,无论 j 是多少!

第 0 行(i=0):
y[0, 0] = storage[0×1 + 0×0] = storage[0] = 1
y[0, 1] = storage[0×1 + 1×0] = storage[0] = 1  ← 和上面一样!
y[0, 2] = storage[0×1 + 2×0] = storage[0] = 1  ← 还是一样!
y[0, 3] = storage[0×1 + 3×0] = storage[0] = 1  ← 都访问 storage[0]

第 1 行(i=1):
y[1, 0] = storage[1×1 + 0×0] = storage[1] = 2
y[1, 1] = storage[1×1 + 1×0] = storage[1] = 2  ← 都访问 storage[1]
y[1, 2] = storage[1×1 + 2×0] = storage[1] = 2
y[1, 3] = storage[1×1 + 3×0] = storage[1] = 2

可视化:
Storage: [1, 2, 3]  ← 只有 3 个元素
          ↓  ↓  ↓
逻辑视图:
     j=0  j=1  j=2  j=3
i=0   1    1    1    1    ← 都指向 storage[0]
i=1   2    2    2    2    ← 都指向 storage[1]
i=2   3    3    3    3    ← 都指向 storage[2]

内存节省:

  • Storage 只有 3 个元素(24 字节)
  • 看起来有 3×4=12 个元素
  • 节省了 75% 的内存!

3.6 哪些操作是 View?

python 复制代码
x = torch.randn(4, 4)

# 以下操作都是 View(零拷贝):
y1 = x.view(16)           # 改变形状
y2 = x.reshape(2, 8)      # 大多数情况是 View
y3 = x.transpose(0, 1)    # 转置
y4 = x.permute(1, 0)      # 维度重排
y5 = x[:, 2:]             # 切片
y6 = x.unsqueeze(0)       # 增加维度
y7 = x.squeeze()          # 删除大小为 1 的维度
y8 = x.expand(4, 4, 4)    # 虚拟扩展
y9 = x[::2, ::2]          # 步进切片

# 验证:检查底层指针
print(x.data_ptr() == y1.data_ptr())  # True

3.7 哪些操作会复制数据?

python 复制代码
x = torch.randn(4, 4)

# 以下操作会复制数据:
y1 = x.clone()            # 显式复制
y2 = x.contiguous()       # 如果不连续,则复制
y3 = x + 1                # 任何算术运算
y4 = x[x > 0]             # 布尔索引
y5 = x[[0, 2]]            # 高级索引(非连续选择)
y6 = torch.cat([x, x])    # 拼接

# 验证:不同的内存地址
print(x.data_ptr() == y1.data_ptr())  # False

小结: View 操作通过巧妙地改变投影规则(Stride、Offset),实现了零拷贝。这是 PyTorch 高性能的秘密。

但这引出了新问题:既然 transpose 只改变 Stride,为什么有时候需要调用 contiguous()?什么是"连续"?

这就是 Part 4 要解答的...


🔧 Part 4: 代价与限制 - 理解 Contiguous

4.1 什么是"连续"(Contiguous)?

一个 Tensor 是 contiguous 的,当且仅当它的元素在 Storage 中是按照**标准顺序(C-order,行优先)**排列的。

标准顺序的定义:

python 复制代码
# 对于 shape = (d₀, d₁, d₂, ..., dₙ)
# 标准 stride(C-order,行优先):

stride[n] = 1
stride[n-1] = d[n]
stride[n-2] = d[n-1] × d[n]
...
stride[0] = d[1] × d[2] × ... × d[n]

示例:

python 复制代码
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(x.shape)   # (2, 3)
print(x.stride())  # (3, 1)

# 检查是否连续:
expected_stride = (3, 1)  # 3 = 第 1 维的大小, 1 = 最后一维
print(x.is_contiguous())  # True

直观理解:

text 复制代码
Contiguous 意味着:
- 逻辑顺序:[[1,2,3], [4,5,6]] → 扁平化 → [1,2,3,4,5,6]
- Storage 顺序:[1, 2, 3, 4, 5, 6]
- 完美匹配!✓

4.2 为什么 Transpose 后不连续?

python 复制代码
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

y = x.transpose(0, 1)

print(y.is_contiguous())  # False

# 检查 y 的逻辑顺序:
print(y)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

# 逻辑顺序:[[1,4], [2,5], [3,6]] → 扁平化 → [1,4,2,5,3,6]
# Storage 顺序:[1, 2, 3, 4, 5, 6]
# 不匹配!✗

可视化:

text 复制代码
Storage: [1, 2, 3, 4, 5, 6]
         ↓     ↓  ↓     ↓
逻辑读取: 1  → 4  2  → 5  ...
          ↑     ↑     ↑
        跳跃访问!不连续!

4.3 为什么某些操作要求连续?

案例:view() 要求输入连续

python 复制代码
x = torch.randn(3, 4)
y = x.transpose(0, 1)  # y 不连续

# 这会报错!
try:
    z = y.view(12)
except RuntimeError as e:
    print(e)
    # RuntimeError: view size is not compatible with input tensor's
    # size and stride. Use .reshape(...) or .contiguous().view(...)

为什么 view 要求连续?

view 只改变元数据,不移动数据。如果 Tensor 不连续,就无法通过简单的 stride 计算来重新解释 Storage。

例子:

text 复制代码
y 的 Storage: [1, 2, 3, 4, 5, 6]
y 的逻辑顺序: [1, 4, 2, 5, 3, 6]

如果要 view 成 (12,):
需要访问顺序: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
但 Storage 只有 6 个元素!

无法用简单的 stride 公式实现 → 必须复制数据

4.4 contiguous() 做了什么?

python 复制代码
y = x.transpose(0, 1)  # 不连续
z = y.contiguous()     # 创建新的连续 Storage

print("y 的 Storage 指针:", y.data_ptr())
print("z 的 Storage 指针:", z.data_ptr())
# 不同!z 是新的 Tensor

print("z 是否连续:", z.is_contiguous())  # True

contiguous() 的工作流程:

  1. 检查 Tensor 是否已经连续
    • 如果是,直接返回自己(不复制)
    • 如果不是,继续下一步
  2. 按照逻辑顺序创建新的 Storage
  3. 复制数据到新 Storage
  4. 返回新的 Tensor(连续的)
python 复制代码
# 源码简化版:
def contiguous(tensor):
    if tensor.is_contiguous():
        return tensor  # 已经连续,不复制

    # 创建新 Storage,按逻辑顺序复制数据
    new_storage = allocate_storage(tensor.numel())
    for idx in ndindex(tensor.shape):  # 遍历所有索引
        new_storage[flat_index(idx)] = tensor[idx]

    return Tensor(new_storage, tensor.shape, standard_stride(tensor.shape))

4.5 解决方案对比

方案 1:先 contiguous,再 view

python 复制代码
y = x.transpose(0, 1)  # 不连续
z = y.contiguous().view(12)  # OK

方案 2:使用 reshape(推荐)

python 复制代码
y = x.transpose(0, 1)  # 不连续
z = y.reshape(12)      # OK,自动处理

reshape vs view

特性 view reshape
要求连续 ✅ 是 ❌ 否
总是返回 View ✅ 是 ❌ 不一定
性能 最快 智能选择
推荐 性能关键路径 日常使用

4.6 性能影响

python 复制代码
import time

x = torch.randn(1000, 1000)
y = x.t()  # 不连续

# 测试访问速度
start = time.perf_counter()
_ = x.sum()
time_cont = time.perf_counter() - start

start = time.perf_counter()
_ = y.sum()
time_non_cont = time.perf_counter() - start

print(f"连续 Tensor:    {time_cont * 1000:.2f} ms")
print(f"不连续 Tensor:  {time_non_cont * 1000:.2f} ms")
print(f"慢了:          {time_non_cont / time_cont:.2f}x")
# 不连续的 Tensor 通常慢 1.5-3x(缓存不友好)

为什么慢?

  • 连续:内存访问是顺序的,CPU 缓存命中率高
  • 不连续:内存访问是跳跃的,缓存频繁失效

小结: View 操作的代价是可能产生不连续的 Tensor,某些操作需要先调用 contiguous() 复制数据。

现在你理解了 Tensor 的底层机制。接下来,让我们系统学习 PyTorch 提供的各种维度操作工具。


📐 Part 5: 掌握工具箱 - 各种维度操作

理解了 Stride 的原理后,所有维度操作都只是在改变投影规则。

5.1 Reshape 家族:改变形状

view() - 严格的零拷贝
python 复制代码
x = torch.arange(12)

# view 要求 Tensor 连续
y = x.view(3, 4)  # OK
y = x.view(2, 6)  # OK
y = x.view(-1)    # OK,-1 表示自动推断

# 元素总数必须匹配
try:
    y = x.view(3, 5)  # 错误!3×5=15 ≠ 12
except RuntimeError as e:
    print(e)
reshape() - 智能的形状变换
python 复制代码
x = torch.randn(3, 4)
y = x.transpose(0, 1)  # y 不连续

# reshape 会自动处理:
# - 如果可以,返回 View(零拷贝)
# - 如果不行,自动调用 contiguous()
z = y.reshape(12)  # OK,自动处理了不连续的情况
flatten() - 扁平化
python 复制代码
x = torch.randn(2, 3, 4)

# 从某个维度开始展平
y = x.flatten(start_dim=1)  # shape: (2, 12)
# 等价于:x.reshape(2, -1)

5.2 维度添加与删除

unsqueeze() - 添加大小为 1 的维度
python 复制代码
x = torch.randn(3, 4)  # shape: (3, 4)

y = x.unsqueeze(0)     # shape: (1, 3, 4)
z = x.unsqueeze(1)     # shape: (3, 1, 4)
w = x.unsqueeze(-1)    # shape: (3, 4, 1)

工程用途:

python 复制代码
# 场景 1:匹配批次维度
image = torch.randn(3, 224, 224)    # 单张图片 (C, H, W)
batch = image.unsqueeze(0)          # (1, C, H, W)

# 场景 2:为广播做准备
a = torch.randn(5)      # (5,)
b = torch.randn(3)      # (3,)

a = a.unsqueeze(1)      # (5, 1)
b = b.unsqueeze(0)      # (1, 3)
c = a * b               # (5, 3) - 广播乘法
squeeze() - 删除大小为 1 的维度
python 复制代码
x = torch.randn(1, 3, 1, 4)

y = x.squeeze()        # (3, 4) - 删除所有大小为 1 的维度
z = x.squeeze(0)       # (3, 1, 4) - 只删除第 0 维
w = x.squeeze(2)       # (1, 3, 4) - 只删除第 2 维

5.3 维度重排

transpose() - 交换两个维度
python 复制代码
x = torch.randn(2, 3, 4)

y = x.transpose(0, 1)   # shape: (3, 2, 4)
z = x.transpose(1, 2)   # shape: (2, 4, 3)

原理: 交换对应的 shape 和 stride

python 复制代码
x: shape=(2, 3, 4), stride=(12, 4, 1)
y: shape=(3, 2, 4), stride=(4, 12, 1)  # 交换了前两个
permute() - 任意维度重排
python 复制代码
x = torch.randn(2, 3, 4, 5)

# 重新排列所有维度
y = x.permute(3, 1, 0, 2)  # shape: (5, 3, 2, 4)

# 常见用法:(B, C, H, W) → (B, H, W, C)
images = torch.randn(32, 3, 224, 224)  # PyTorch 格式
images_hwc = images.permute(0, 2, 3, 1)  # (32, 224, 224, 3)

5.4 Broadcasting - 自动维度扩展

Broadcasting 是 PyTorch 最强大的特性之一。

Broadcasting 规则

从右向左对齐维度,然后:

  1. 如果两个维度大小相同 → OK
  2. 如果其中一个是 1 → 可以 broadcast
  3. 如果其中一个不存在 → 自动添加大小为 1 的维度,然后 broadcast
示例 1:矩阵 + 向量
python 复制代码
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])  # (2, 3)
vector = torch.tensor([10, 20, 30]) # (3,)

result = matrix + vector  # (2, 3)

# 对齐过程:
#   matrix: (2, 3)
#   vector:    (3,) → 自动变为 (1, 3) → broadcast 到 (2, 3)
示例 2:复杂的 Broadcasting
python 复制代码
a = torch.randn(5, 1, 3)   # (5, 1, 3)
b = torch.randn(1, 4, 3)   # (1, 4, 3)

c = a + b  # shape: (5, 4, 3)

# 对齐过程:
#   a: (5, 1, 3)
#   b: (1, 4, 3)
#      ↓  ↓  ↓
#   c: (5, 4, 3)
expand() - 显式 Broadcasting
python 复制代码
x = torch.tensor([[1], [2], [3]])  # shape: (3, 1)

# 虚拟扩展(不复制数据)
y = x.expand(3, 4)  # shape: (3, 4)

print(y)
# tensor([[1, 1, 1, 1],
#         [2, 2, 2, 2],
#         [3, 3, 3, 3]])

# 验证:Stride 的第 1 维是 0
print(y.stride())  # (1, 0)

5.5 操作总结表

操作 改变的元数据 是否 View 说明
view(shape) shape, stride 要求连续
reshape(shape) shape, stride 通常 自动处理不连续
transpose(a, b) shape, stride 交换两维
permute(dims) shape, stride 重排所有维度
x[a:b] offset, shape, stride 切片
x[::step] stride 步进
unsqueeze(dim) shape, stride 插入维度
squeeze() shape, stride 删除大小为 1 的维度
expand(...) shape 虚拟扩展(stride 设为 0)
flatten() shape, stride 通常 展平

小结: 掌握这些维度操作,你就可以灵活地重塑 Tensor,为计算做准备。

现在你已经理解了 Tensor 的所有核心机制。接下来,让我们看看如何将 Tensor 组织成神经网络模型。


🧱 Part 6: 从 Tensor 到模型 - nn.Module 的设计

理解了 Tensor 之后,我们需要一个容器来组织它们,构建神经网络。

6.1 nn.Module 的基本结构

python 复制代码
import torch
import torch.nn as nn

class SimpleModule(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()  # 必须调用!

        # 可训练参数
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.bias = nn.Parameter(torch.zeros(output_dim))

        # 不可训练的状态(如 BatchNorm 的 running_mean)
        self.register_buffer('running_sum', torch.zeros(output_dim))

    def forward(self, x):
        output = torch.matmul(x, self.weight) + self.bias
        self.running_sum += output.sum(dim=0)
        return output

6.2 Parameter vs Buffer vs Tensor

类型 用途 是否可训练 是否保存 是否移动到 device
Parameter 模型参数
Buffer 模型状态
普通 Tensor 临时变量

实例:

python 复制代码
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(10, 5))
        self.register_buffer('mean', torch.zeros(5))
        self.scale = torch.tensor(2.0)

model = Model()

# 移动到 GPU
model = model.cuda()
print(model.weight.device)  # cuda:0
print(model.mean.device)    # cuda:0
print(model.scale.device)   # cpu ← 普通 Tensor 不会移动!

6.3 自动注册机制

python 复制代码
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # 自动注册为子模块
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 5)

model = MyModule()

# 自动递归收集所有参数
print(len(list(model.parameters())))  # 4 个(2 个 weight + 2 个 bias)

6.4 ModuleList vs Python List

python 复制代码
# ❌ 错误示例
class BadModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 10) for _ in range(3)]

model = BadModule()
print(len(list(model.parameters())))  # 0 ← 参数丢失!

# ✅ 正确示例
class GoodModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 10) for _ in range(3)
        ])

model = GoodModule()
print(len(list(model.parameters())))  # 6

小结: nn.Module 提供了一个优雅的容器,自动管理参数、状态和子模块。

现在,让我们通过实战练习巩固所有知识。


🎯 Part 7: 实战练习

实战 1:手写 MLP(多层感知机)

目标: 不使用 nn.Linear,只用 torch.matmul 实现两层 MLP。

python 复制代码
import torch
import torch.nn as nn

class ManualMLP(nn.Module):
    """手动实现的多层感知机"""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        # 第一层权重和偏置
        self.w1 = nn.Parameter(
            torch.randn(input_dim, hidden_dim) * (2.0 / input_dim) ** 0.5
        )
        self.b1 = nn.Parameter(torch.zeros(hidden_dim))

        # 第二层权重和偏置
        self.w2 = nn.Parameter(
            torch.randn(hidden_dim, output_dim) * (2.0 / hidden_dim) ** 0.5
        )
        self.b2 = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        """
        Args:
            x: shape (batch_size, input_dim)
        Returns:
            output: shape (batch_size, output_dim)
        """
        # 第一层:x @ W1 + b1
        hidden = torch.matmul(x, self.w1) + self.b1

        # 激活函数:ReLU
        hidden = torch.clamp(hidden, min=0)

        # 第二层:hidden @ W2 + b2
        output = torch.matmul(hidden, self.w2) + self.b2

        return output

# 测试
model = ManualMLP(784, 128, 10)
x = torch.randn(32, 784)  # 模拟 MNIST

output = model(x)
print(output.shape)  # torch.Size([32, 10])

实战 2:手写 Multi-Head Attention

目标: 理解 Transformer 的核心,掌握复杂的维度变换。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class ManualMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Q、K、V 的投影矩阵
        self.w_q = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)
        self.w_k = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)
        self.w_v = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)
        self.w_o = nn.Parameter(torch.randn(embed_dim, embed_dim) * 0.02)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, embed_dim)
        Returns:
            output: (batch, seq_len, embed_dim)
        """
        batch_size, seq_len, embed_dim = x.shape

        # 步骤 1:线性投影
        Q = torch.matmul(x, self.w_q)  # (B, L, E)
        K = torch.matmul(x, self.w_k)
        V = torch.matmul(x, self.w_v)

        # 步骤 2:切分成多头
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        Q = Q.transpose(1, 2)  # (B, H, L, D)

        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.transpose(1, 2)

        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.transpose(1, 2)

        # 步骤 3:计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, H, L, L)
        scores = scores / (self.head_dim ** 0.5)

        # 步骤 4:Softmax
        attn_weights = F.softmax(scores, dim=-1)

        # 步骤 5:加权求和
        attn_output = torch.matmul(attn_weights, V)  # (B, H, L, D)

        # 步骤 6:合并多头
        attn_output = attn_output.transpose(1, 2)  # (B, L, H, D)
        attn_output = attn_output.contiguous()
        attn_output = attn_output.view(batch_size, seq_len, embed_dim)

        # 步骤 7:输出投影
        output = torch.matmul(attn_output, self.w_o)

        return output

# 测试
model = ManualMultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)

output = model(x)
print(output.shape)  # torch.Size([2, 10, 512])

关键维度变换可视化:

yaml 复制代码
输入:     (B, L, E)
   ↓ Linear projection
Q, K, V: (B, L, E)
   ↓ view + transpose
Q, K, V: (B, H, L, D)  其中 H=num_heads, D=E/H
   ↓ Attention 计算
Output:  (B, H, L, D)
   ↓ transpose + contiguous + view
Output:  (B, L, E)

实战 3:手写 LayerNorm

python 复制代码
def manual_layer_norm(x, eps=1e-5):
    """
    Args:
        x: shape (batch, ..., features)
    Returns:
        normalized: shape (batch, ..., features)
    """
    # 在最后一维计算统计量
    mean = x.mean(dim=-1, keepdim=True)  # (B, ..., 1)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    # 归一化
    x_normalized = (x - mean) / torch.sqrt(var + eps)

    return x_normalized

# 测试
x = torch.randn(32, 10, 512)
output = manual_layer_norm(x)

# 验证
print(output[0, 0].mean())  # 接近 0
print(output[0, 0].std())   # 接近 1

🎯 本章总结

核心要点回顾

  1. Tensor = Storage(物理)+ (Shape, Stride, Offset)(投影)

    • Storage 是一维数组,永远不变
    • 投影规则定义如何解释 Storage
    • 访问公式:T[i,j,k] = S[offset + i×stride[0] + j×stride[1] + k×stride[2]]
  2. 分离式设计的动机

    • 零拷贝操作(快 10000 倍)
    • 内存共享
    • 延迟物化
  3. View 操作的原理

    • 只改变投影规则,不移动数据
    • transpose 交换 stride
    • slice 改变 offset
    • expand 将 stride 设为 0
  4. Contiguous 的必要性

    • 某些操作(如 view)要求连续
    • contiguous() 复制数据以保证连续
    • 不连续的 Tensor 访问慢 1.5-3x
  5. 维度操作的统一理解

    • 所有操作都是在改变投影规则
    • 理解 Stride 就理解了一切
  6. nn.Module 的设计

    • Parameter:可训练参数
    • Buffer:模型状态
    • 自动递归管理子模块

工程最佳实践

  1. ✅ 优先使用 reshape() 而非 view()
  2. ✅ 利用 Broadcasting 代替循环
  3. ✅ 批量处理数据
  4. ✅ 使用 ModuleList 管理子模块
  5. ⚠️ 注意 transpose 后可能需要 contiguous()
  6. ⚠️ 避免在需要梯度的 Tensor 上 in-place 操作

下一步

在下一篇文章中,我们将深入自动微分引擎(Autograd),理解:

  • 计算图是如何动态构建的
  • backward() 的实现原理
  • 如何实现自定义的 autograd.Function

🎯 检查清单

在进入下一阶段前,请确保你能回答:

  • Storage、Shape、Stride 的关系是什么?
  • 为什么要分离 Storage 和 Shape?
  • transpose 如何通过改变 Stride 实现零拷贝?
  • 什么是 Contiguous?为什么需要它?
  • view 和 reshape 的区别?
  • Broadcasting 的规则是什么?
  • Parameter 和 Buffer 的区别?

如果你能清晰回答这些问题,恭喜你,你已经掌握了 PyTorch 的核心机制!


下一篇预告: 《阶段二:后向传播 - Autograd 自动微分引擎揭秘》


本文是 PyTorch 架构级学习系列的第 2 篇。

相关推荐
CoderLiu2 小时前
Agent 沙箱架构深度解析:从 Pattern 选型到生产级框架设计
前端·人工智能·后端
神奇小汤圆2 小时前
Java内存模型(JMM)与 volatile 底层实现全解析
后端
宸津-代码粉碎机2 小时前
SpringBoot 任务执行链路追踪实战:TraceID 透传全解析,实现从调度到执行的全链路可观测
开发语言·人工智能·spring boot·后端·python
FelixBitSoul2 小时前
拒对着 Docker 进度条发呆:深度优化 AI 应用的构建与模型加载
后端
IT_陈寒2 小时前
SpringBoot项目启动速度提升300%?这5个隐藏配置太关键了!
前端·人工智能·后端
DigitalOcean2 小时前
DigitalOcean 亮相 NVIDIA GTC 2026:为智能体时代打造 AI 工厂
aigc
段小二2 小时前
为什么 Claude 不用 RAG?——理解 RAG 的真实边界,再用 Spring AI 落地三种架构(Java 架构师的 AI 工程笔记 06)
后端
Mr.45672 小时前
Spring Boot 3 + EasyExcel 3.x 实战:构建高效、可靠的Excel导入导出服务
spring boot·后端·excel