目录
[🧩 先给你一个总的路线图(你先不用背,只用看懂结构)](#🧩 先给你一个总的路线图(你先不用背,只用看懂结构))
[🎯 第一部分:形状变换类操作(reshape/view/squeeze/unsqueeze)](#🎯 第一部分:形状变换类操作(reshape/view/squeeze/unsqueeze))
[① reshape ------ 任意改变形状(必要时复制)](#① reshape —— 任意改变形状(必要时复制))
[② view ------ 更严格的 reshape(内存必须连续)](#② view —— 更严格的 reshape(内存必须连续))
[③ unsqueeze ------ 增加一个维度](#③ unsqueeze —— 增加一个维度)
[④ squeeze ------ 去掉大小为 1 的维度](#④ squeeze —— 去掉大小为 1 的维度)
[🚀 第二部分:广播扩展类操作](#🚀 第二部分:广播扩展类操作)
[🎯 1. expand ------ 广播的核心工具](#🎯 1. expand —— 广播的核心工具)
[🎯 1.1 expand 中的 -1 是什么意思?](#🎯 1.1 expand 中的 -1 是什么意思?)
[🎯 2. expand_as ------ 让 A 扩展成与 B 的 shape 相同](#🎯 2. expand_as —— 让 A 扩展成与 B 的 shape 相同)
[🎯 3. broadcast_to ------ 与 NumPy 同名,效果同 expand](#🎯 3. broadcast_to —— 与 NumPy 同名,效果同 expand)
[🎯 4. 为什么广播要配合 unsqueeze 一起使用?](#🎯 4. 为什么广播要配合 unsqueeze 一起使用?)
[✔ 小测试(重要)](#✔ 小测试(重要))
[【判断 1】](#【判断 1】)
[【判断 2】](#【判断 2】)
[🌌 第三部分:维度重排(transpose / permute)](#🌌 第三部分:维度重排(transpose / permute))
[Part 3-1:transpose 基础直觉](#Part 3-1:transpose 基础直觉)
[📌案例 1:交换 seq 和 hidden](#📌案例 1:交换 seq 和 hidden)
[Part 3-2:permute ------ 彻底自由](#Part 3-2:permute —— 彻底自由)
[Part 3-3:真实应用案例(非常关键)](#Part 3-3:真实应用案例(非常关键))
[Step 1:线性层后维度不变](#Step 1:线性层后维度不变)
[Step 2:reshape 把 hidden 拆成 (num_heads, head_dim)](#Step 2:reshape 把 hidden 拆成 (num_heads, head_dim))
[Step 3:permute,让头排到 batch 后面](#Step 3:permute,让头排到 batch 后面)
[Part 3-4:一个直觉测试(你试试看)](#Part 3-4:一个直觉测试(你试试看))
[Part 3-5:mT:矩阵转置(专门给二维矩阵用的)](#Part 3-5:mT:矩阵转置(专门给二维矩阵用的))
[Part 3-6:mH:矩阵共轭转置(矩阵 Hermitian transpose)](#Part 3-6:mH:矩阵共轭转置(矩阵 Hermitian transpose))
[🔥 小练习(你来做)](#🔥 小练习(你来做))
[🌟 第四部分:拼接与复制(cat、stack、repeat、tile)](#🌟 第四部分:拼接与复制(cat、stack、repeat、tile))
[🔹 1. cat:按某个维度拼接(不增加新维度)](#🔹 1. cat:按某个维度拼接(不增加新维度))
[🔹 2. stack:把多个张量叠成一个"新维度"](#🔹 2. stack:把多个张量叠成一个“新维度”)
[🔹 3. repeat:把张量按次数重复(复制数据)](#🔹 3. repeat:把张量按次数重复(复制数据))
[🔹 4. tile:与 repeat 类似,但更 numpy 风格](#🔹 4. tile:与 repeat 类似,但更 numpy 风格)
[🎯 小练习(由你来做)](#🎯 小练习(由你来做))
[1、flatten ------ 把多维压成一维](#1、flatten —— 把多维压成一维)
[2、unflatten ------ 和 flatten 反向,给一维数据重新分块](#2、unflatten —— 和 flatten 反向,给一维数据重新分块)
[3、swapdims / swapaxes ------ 交换两个维度](#3、swapdims / swapaxes —— 交换两个维度)
[4、moveaxis ------ 一次把一个维度移动到目标位置](#4、moveaxis —— 一次把一个维度移动到目标位置)
🧩 先给你一个总的路线图(你先不用背,只用看懂结构)
PyTorch 的维度变形大致分成 5 类:
-
形状变换 (不改变元素顺序)
reshape、view、squeeze、unsqueeze
-
虚拟扩展(广播)
expand、expand_as、broadcast_to
-
维度重排(换顺序)
transpose、permute、mT、mH
-
拼接与切片(加减维度)
cat、stack、repeat、tile
-
高级手段(不常用但很有用)
flatten、unflatten、swapdims、moveaxis
为了不把你淹没,我们先从第一个类别:形状变换讲起。等你确认后,我们再继续往下推进。
🎯 第一部分:形状变换类操作(reshape/view/squeeze/unsqueeze)
这些操作的特点是:
不改变数据内容,只改变维度结构。
你来感受一下每个的"直觉"。
① reshape ------ 任意改变形状(必要时复制)
直觉:把一根绳子卷成不同形状,但绳子顺序不变。
例子:
python
x = torch.arange(12) # [12]
y = x.reshape(3, 4) # [3, 4]
问题给你想想:
如果 reshape 的新形状里有一个 -1,它的意思是什么?(-1表示的自动推断该维度大小,使总元素数量不变)
② view ------ 更严格的 reshape(内存必须连续)
直觉:
reshape 是万能插座,
view 是廉价插座,有要求才能用。(所以日常场景中更推荐使用reshape)
python
x = torch.arange(12)
y = x.view(3, 4)
问题给你判断:
下面这种能不能用 view?
python
x = torch.arange(12).reshape(3, 4).transpose(0, 1)
y = x.view(12)
思考下原因,我再等你回答。(答案是不行,因为交换完维度之后数值不再连续故不能使用view)
③ unsqueeze ------ 增加一个维度
直觉:
你把一排数字装进一个盒子里 → 维度 +1。
x = torch.tensor([1, 2, 3]) # [3]
y = x.unsqueeze(0) # [1, 3]
z = x.unsqueeze(1) # [3, 1]
你来判断一个实际例子:
假设 tokens.shape = [batch, seq],
要把它变成 [batch, seq, 1],你应该用哪一个 unsqueeze?(答案是unsqueeze(2))
④ squeeze ------ 去掉大小为 1 的维度
直觉:
把空盒子压扁。
x = torch.randn(1, 3, 1, 5)
y = x.squeeze() # 去掉所有为1的维度 → [3, 5]
问题给你:
如果我只想去掉第 2 个维度(也就是 size=1 的那个),应该怎么写?(答案是squeeze(2))
🚀 第二部分:广播扩展类操作
操作包括:
-
expand -
expand_as -
broadcast_to(PyTorch 2.0+) -
相关的
unsqueeze(常与 expand 组合)
这些操作有两个核心特征:
-
不会复制数据,只创建"伪扩展"视图
-
只能扩张 size=1 的维度
简单说:
你不能凭空把 3 扩张成 5,但你可以把 1 扩成任何数字。
🎯 1. expand ------ 广播的核心工具
语法:
python
x.expand(new_size_1, new_size_2, ...)
规则:
-
原维度是 1 → 可以扩成任意数
-
原维度不是 1 → 只能保持不变(写 -1)
例子:
python
x = torch.zeros(2, 1) # [2, 1]
y = x.expand(2, 3) # [2, 3]
或者写为
y = x.expand(-1,3)
这里 没有 复制 3 个数据,而是让 [1] 伪装成 [3]。
🎯 1.1 expand 中的 -1 是什么意思?
你的基础已经很稳了,这里加一点:
-1 表示"保持原维度不变"。
例子:
x = torch.zeros(4, 1, 5) # [4, 1, 5]
y = x.expand(-1, 3, -1) # [4, 3, 5]
🎯 2. expand_as ------ 让 A 扩展成与 B 的 shape 相同
例子:
mask = torch.zeros(2, 1, 1)
target = torch.randn(2, 3, 5)
mask2 = mask.expand_as(target) # shape → [2, 3, 5]
这是很多注意力mask的常用写法。
🎯 3. broadcast_to ------ 与 NumPy 同名,效果同 expand
等价于:
x.broadcast_to((...))
例子:
x = torch.tensor([1,2,3]) # [3]
y = torch.broadcast_to(x, (2,3)) # [2,3]
它比 expand 更直观,但需要 PyTorch 新版本。
🎯 4. 为什么广播要配合 unsqueeze 一起使用?
因为广播只能扩张 size = 1 的维度。
但很多时候你需要先造出这个维度,让 expand 能执行。
例子:
你有 [batch, seq],需要广播成 [batch, seq, seq] 做 attention mask。
只靠 expand 不够。
步骤是:
mask = (seq == pad)
mask = mask.unsqueeze(1) # [batch, 1, seq]
mask = mask.expand(-1, seq, -1) # [batch, seq, seq]
注意:
如果不 unsqueeze,多出来的 seq 维度根本没法扩展。
✔ 小测试(重要)
你来判断这两行各会输出什么形状:
【判断 1】
python
x = torch.randn(2, 3, 1)
y = x.expand(2, 3, 5)
y.shape = ?
答案:(2,3,5)
【判断 2】
python
x = torch.randn(2, 4)
y = x.unsqueeze(1).expand(-1, 4, -1)
y.shape = ?
答案:(2,4,4)
那我们顺着前面的劲头,继续闯进 第三部分:维度重排(transpose / permute) 。
这玩意儿看似只是"换个顺序",但在深度学习里,它的地位几乎等于"维度战场的传送门"。Transformer、BERT、CNN,多头注意力......全都离不开它。
我会继续保持我们之前那种"边讲边练"的方式,让你的直觉一步步变得硬核。
🌌 第三部分:维度重排(transpose / permute)
在 PyTorch 里,主要有两个手段:
-
tensor.transpose(dim0, dim1):交换两个维度
-
tensor.permute(dims):根据你给出的列表完全重新排列全部维度
你可以把它们理解成:
-
transpose = 小规模换座位
-
permute = 整个班级重新排布
Part 3-1:transpose 基础直觉
假设我们有:
x = torch.randn(2, 3, 5) # (batch, seq_len, hidden_size)
你知道这个东西在 NLP 中几乎到处都是。
📌案例 1:交换 seq 和 hidden
y = x.transpose(1, 2)
print(y.shape)
维度变化:
(2, 3 , 5 ) → (2, 5 , 3)
维度 1 和 2 换了个位置。
你可以脑补成"把句子由 3 个时间步 × 每步 5 个特征"变成"5 个特征 × 每个特征有 3 个时间步"。
Part 3-2:permute ------ 彻底自由
permute 接受一个"新顺序",你完全控制每个维度的新位置。
x = torch.randn(2, 3, 5)
y = x.permute(2, 0, 1)
print(y.shape)
变化:
(2, 3, 5)
→ (5, 2, 3)
你甚至可以做到 transpose 做不到的三维以上的全面换位。
Transformer 里的 Q、K、V 分头,就是用 permute 完成的。
Part 3-3:真实应用案例(非常关键)
🚀案例:多头注意力的标准维度变化
你一定见过这种张量:
x = (batch, seq_len, hidden_size)
要变成多头注意力输入:
(batch, num_heads, seq_len, head_dim)
假设:
batch=2, seq=3, hidden=12, num_heads=4 -> head_dim=3
我们来一步步变换:
Step 1:线性层后维度不变
x -> (2, 3, 12)
Step 2:reshape 把 hidden 拆成 (num_heads, head_dim)
y = x.reshape(2, 3, 4, 3)
# (batch, seq_len, num_heads, head_dim)
Step 3:permute,让头排到 batch 后面
这是注意力的核心:
y = y.permute(0, 2, 1, 3)
print(y.shape)
结果:
(2, 4, 3, 3)
你已经能读懂这个维度了------
2 个 batch,每个 batch 有 4 个 head,每个 head 有 3 个 token,每个 token 的特征是 3 维。
Transformer 的"多头"的物理意义就在这一步显现。
Part 3-4:一个直觉测试(你试试看)
你一定已经对维度敏感多了,我来给你一个稍微有点花的练习,让你的直觉再上一层:
假设:
x = torch.randn(2, 4, 3, 3) # (batch, head, seq, dim)
我想要把它变回:
(2, 3, 12) # (batch, seq, hidden)
其中 hidden = 4(head) * 3(dim) = 12
想一想:
你需要先 permute 还是先 reshape?
顺序弄错,结果就完全不对。
答案是:步骤:
-
permute(0,2,1,3)
-
reshape 为 (2,3,-1)
这是 多头注意力把 4D 张量再折回 3D 时的典型套路。
Part 3-5:mT:矩阵转置(专门给二维矩阵用的)
你已经知道 permute 是"任意维度排列",transpose 是"交换两个维度"。
现在,我们扩展这两类操作的"矩阵专属版本":
-
mT
-
mH
这些是在矩阵或注意力结构中极常出现的。
当你有二维 Tensor:
python
x.shape = (M, N)
mT 相当于:
python
x.mT == x.transpose(-1, -2)
举例:
python
x: (4, 8)
x.mT → (8, 4)
重点:
-
只能用于二维(matrix)
-
是最常用最直观的"转置"
在注意力中,K 的转置经常写成:
python
K.mT
比 transpose(-1,-2) 更短更安全。
Part 3-6:mH:矩阵共轭转置(矩阵 Hermitian transpose)
多用于复数张量。
和 mT 很像,但是带复数共轭:
python
x.mH == x.transpose(-1, -2).conj()
如果张量是实数,那么 mH == mT。
注意力里一般不用它(一般不涉及复数)。
🔥 小练习(你来做)
保持我们的学习节奏,我给你一个小 tensor,你来写代码表达"我要做什么变换",我检查你是否真正理解。
假设:
python
x.shape = (batch=2, heads=4, seq=3, dim=5)
现在我想把 x 变成:
python
(batch=2, seq=3, heads=4, dim=5)
只问一句:
你会用 permute 怎么写?
答案是permute(0,2,1,3)
你会用 transpose怎么写?
答案是transpose(1,2)
好,那我们就按既定路线进入 第四部分:拼接与切片(cat、stack、repeat、tile) 。
保持学习节奏:我讲一点,你验证一点,不会一股脑塞给你。
🌟 第四部分:拼接与复制(cat、stack、repeat、tile)
这一部分的关键是:
它们都在"增加维度"或"扩展维度",但方式不一样。
为了不堆砌概念,我们一个个来。
🔹 1. cat:按某个维度拼接(不增加新维度)
cat 的本质是:
把两个张量沿着某个已有维度接在一起。
例子:
python
x: (2, 3)
y: (2, 3)
torch.cat([x, y], dim=0) → (4, 3)
torch.cat([x, y], dim=1) → (2, 6)
要点:
-
维度数不变
-
被拼接的维度必须对齐
🔹 2. stack:把多个张量叠成一个"新维度"
stack 的本质是:
一定会增加一个新维度。
例子:
python
x: (3,)
y: (3,)
torch.stack([x, y], dim=0) → (2, 3)
torch.stack([x, y], dim=1) → (3, 2)
和 cat 的区别:
-
cat:不加维度
-
stack:加一个新维度
🔹 3. repeat:把张量按次数重复(复制数据)
repeat 会直接复制数据:
x: (2, 3)
x.repeat(2, 1) → (4, 3)
x.repeat(1, 2) → (2, 6)
写 repeat(n1, n2, ...) 的意思是:
每个维度复制几次。
注意:
-
repeat 会真实复制数据(成本更高)
-
expand 只是广播,不复制
🔹 4. tile:与 repeat 类似,但更 numpy 风格
PyTorch 也有 tile,它等价于 repeat,但写法偏向 numpy:
x: (2, 3)
x.tile(2, 1) → (4, 3)
tile 是 repeat 的更灵活版本,本质一样。
🎯 小练习(由你来做)
给定:
a = torch.randn(2, 3)
b = torch.randn(2, 3)
我给你 3 个目标形状,你来选"应该用 cat / stack / repeat / tile 中哪个操作"。
只需要写方法名,不用写代码。
目标 1:
(4, 3)
答案是:torch.cat([a,b],dim=0)或者a.repeat(2,1) b.repeat(2,1)
目标 2:
(2, 2, 3)
答案是:torch.stack([a,b],dim=0)
目标 3:
(2, 3, 3)
答案是:torch.unsqueeze(-1).repeat(1,1, 3)
第五部分:高级张量操作(不常用但很有用)。
1、flatten ------ 把多维压成一维
你可以把它理解成把一块多层千层饼压成一条长面条。
简单记法:
flatten = 展平
示例:
python
import torch
x = torch.rand(2, 3, 4)
y = torch.flatten(x)
print(y.shape) # torch.Size([24])
你也可以只 flatten 部分维度,比如从第 1 维 flatten 到最后一维:
python
y = torch.flatten(x, start_dim=1)
print(y.shape) # torch.Size([2, 12])
使用场景:
做全连接层前的数据准备------几乎所有 CNN 分类器都会在最后 flatten。
2、unflatten ------ 和 flatten 反向,给一维数据重新分块
你可以把一根长面条重新折叠成带维度的形状。
示例:
python
x = torch.rand(10)
y = torch.unflatten(x, dim=0, sizes=(2, 5))
print(y.shape) # torch.Size([2, 5])
更自由的:
python
x = torch.rand(24)
y = torch.unflatten(x, dim=0, sizes=(2, 3, 4))
# torch.Size([2, 3, 4])
使用场景:
模型输出维度重建、Transformer 中按 head 分组、图像 patch 重建。
3、swapdims / swapaxes ------ 交换两个维度
像交换 Rubik's cube 的两条轴,把数据沿不同方向"旋转"。
示例:
python
x = torch.rand(2, 3, 4)
y = x.swapdims(1, 2) # 或 swapaxes
print(y.shape) # torch.Size([2, 4, 3])
使用场景:
序列维度、batch 维度、channel 维度之间互换。例如:
RNN 想要 (sequence, batch, features),CNN 想要 (batch, channel, height, width)。
4、moveaxis ------ 一次把一个维度移动到目标位置
swap 是"换位置",moveaxis 是"搬家"。
示例:
python
x = torch.rand(2, 3, 4)
y = torch.moveaxis(x, 0, -1)
print(y.shape) # torch.Size([3, 4, 2])
如果想把某几个维度移动也行:
python
x = torch.rand(2, 3, 4, 5)
y = torch.moveaxis(x, (0, 2), (2, 0))
print(y.shape) # torch.Size([4, 3, 2, 5])
使用场景:
图像数据格式转换(NHWC ↔ NCHW)、头维度移动、融合不同网络结构时调维度。
小总结(用一句话记住它们)
flatten:把饼压扁
unflatten:把面条折回来
swapdims:两轴互换
moveaxis:一轴搬家