深入理解 PyTorch 中的 unsqueeze 操作
一、引言
在 RoPE 旋转位置编码的预计算代码中,有这样两行:
python
cos = cos.unsqueeze(0).unsqueeze(0) # (L, d) → (1, 1, L, d)
sin = sin.unsqueeze(0).unsqueeze(0) # (L, d) → (1, 1, L, d)
unsqueeze 是 PyTorch 中一个非常基础但极其重要的操作。本文将从基本概念到实际应用,彻底讲清楚它。
二、一句话定义
unsqueeze(dim)在指定位置插入一个大小为 1 的新维度。
它不改变任何数据 ,只改变张量的形状(shape)。
三、基础用法
3.1 一维 → 二维
python
import torch
x = torch.tensor([1, 2, 3])
print(x.shape) # torch.Size([3])
python
# 在第 0 维插入
a = x.unsqueeze(0)
print(a.shape) # torch.Size([1, 3])
print(a) # tensor([[1, 2, 3]])
# 在第 1 维插入
b = x.unsqueeze(1)
print(b.shape) # torch.Size([3, 1])
print(b) # tensor([[1],
# [2],
# [3]])
直觉理解:
原始 x: [1, 2, 3] shape: (3,)
unsqueeze(0): [[1, 2, 3]] shape: (1, 3) → 外面套了一层"行"
unsqueeze(1): [[1], [2], [3]] shape: (3, 1) → 每个元素变成一"列"
3.2 二维 → 三维
python
x = torch.tensor([[1, 2],
[3, 4]])
print(x.shape) # torch.Size([2, 2])
python
x.unsqueeze(0).shape # torch.Size([1, 2, 2]) → 最前面加一维
x.unsqueeze(1).shape # torch.Size([2, 1, 2]) → 中间加一维
x.unsqueeze(2).shape # torch.Size([2, 2, 1]) → 最后面加一维
可视化维度插入位置:
原始 shape: (2, 2)
dim=0 插入: (①, 2, 2) → (1, 2, 2)
dim=1 插入: (2, ①, 2) → (2, 1, 2)
dim=2 插入: (2, 2, ①) → (2, 2, 1)
① 表示新插入的维度,大小为 1
3.3 负数索引
unsqueeze 也支持负数索引,-1 表示最后一个位置:
python
x = torch.tensor([1, 2, 3]) # shape: (3,)
x.unsqueeze(-1).shape # torch.Size([3, 1]) 等价于 unsqueeze(1)
x.unsqueeze(-2).shape # torch.Size([1, 3]) 等价于 unsqueeze(0)
规则: 对于结果张量的维度,dim=-1 指最后一维,dim=-2 指倒数第二维,以此类推。
四、连续调用 unsqueeze
回到 RoPE 的代码:
python
cos = cos.unsqueeze(0).unsqueeze(0)
逐步追踪形状变化:
cos 初始形状: (L, d) 例如 (4096, 64)
↓
cos.unsqueeze(0): (1, L, d) 例如 (1, 4096, 64)
↓
.unsqueeze(0): (1, 1, L, d) 例如 (1, 1, 4096, 64)
第一次 unsqueeze(0): 在最前面加一维
┌───┐
│ 1 │ L d
└───┘
↑新增
第二次 unsqueeze(0): 再在最前面加一维
┌───┬───┐
│ 1 │ 1 │ L d
└───┴───┘
↑新增
等价的一步写法:
python
cos = cos.unsqueeze(0).unsqueeze(0)
# 等价于
cos = cos.reshape(1, 1, L, d)
# 也等价于
cos = cos[None, None, :, :]
五、为什么要 unsqueeze?------ 广播机制
5.1 问题背景
在 apply_rotary_emb 中:
python
def apply_rotary_emb(x, cos, sin):
# x 形状: (B, H, L, d) 例如 (2, 32, 4096, 64)
# cos 形状: (L, d) 例如 (4096, 64) ← 只有2维!
x1 = x[..., :d]
y1 = x1 * cos # ← 这里会报错或结果不对!
x1 是 4 维 (B, H, L, d),cos 是 2 维 (L, d),维度数量不匹配。
虽然 PyTorch 的广播机制可以从右往左自动对齐,但这可能导致语义错误。显式 unsqueeze 更安全、更清晰。
5.2 广播机制回顾
PyTorch 广播的核心规则:
从最后一维开始向前对齐,每一维要么相等,要么其中一个为 1。大小为 1 的维度会被"复制"扩展。
x1: (B, H, L, d) = (2, 32, 4096, 64)
cos: (1, 1, L, d) = (1, 1, 4096, 64) ← unsqueeze 后
对齐过程:
维度: dim0 dim1 dim2 dim3
x1: 2 32 4096 64
cos: 1 1 4096 64
↓ ↓ ↓ ↓
结果: 2 32 4096 64 ✅ 每维要么相等,要么为1
cos 的 dim0 和 dim1 大小为 1,会自动广播(复制)到 2 和 32:
cos 广播过程:
(1, 1, 4096, 64)
↓ 沿 dim0 复制 2 份
(2, 1, 4096, 64)
↓ 沿 dim1 复制 32 份
(2, 32, 4096, 64) ← 与 x1 形状一致,可以逐元素相乘
5.3 广播的物理意义
x 的 4 个维度: (B, H, L, d)
↑ ↑ ↑ ↑
批次 注意力头 位置 维度组
cos 的含义: 每个"位置"的每个"维度组"有一个 cos 值
与"批次"和"注意力头"无关 → 这两维设为 1,广播共享
所有 batch、所有 head 共享同一套 cos/sin,因为位置编码只取决于位置和维度,与具体内容无关。
六、如果不 unsqueeze 会怎样?
情况 1:PyTorch 自动广播(2维 vs 4维)
python
cos_2d = cos # shape: (4096, 64)
x1 = ... # shape: (2, 32, 4096, 64)
result = x1 * cos_2d # PyTorch 会自动从右对齐
PyTorch 的自动对齐过程:
x1: (2, 32, 4096, 64)
cos_2d: (4096, 64) ← 从右对齐
↓
等价于: ( 1, 1, 4096, 64) ← 自动在前面补 1
结果:在这个特定例子中碰巧是对的! 但这依赖隐式行为,不够清晰。
情况 2:隐式广播出错的例子
假设不小心把 cos 的形状搞成 (4096, 1):
python
cos_wrong = cos[:, 0:1] # shape: (4096, 1)
x1 = ... # shape: (2, 32, 4096, 64)
result = x1 * cos_wrong # 不会报错!但语义完全错误
x1: (2, 32, 4096, 64)
cos_wrong: (4096, 1) ← 从右对齐
等价于: (1, 1, 4096, 1)
广播结果: (2, 32, 4096, 64) ← 形状对了,但值全错了!
显式 unsqueeze 让代码意图一目了然,减少这类隐蔽的 bug。
七、unsqueeze vs 其他等价操作
| 方法 | 代码 | 结果形状 |
|---|---|---|
unsqueeze |
cos.unsqueeze(0).unsqueeze(0) |
(1,1,L,d) |
reshape |
cos.reshape(1, 1, L, d) |
(1,1,L,d) |
view |
cos.view(1, 1, L, d) |
(1,1,L,d) |
None 索引 |
cos[None, None, :, :] |
(1,1,L,d) |
expand_dims(NumPy风格) |
不直接支持 | --- |
python
# 这四种写法完全等价
cos_a = cos.unsqueeze(0).unsqueeze(0)
cos_b = cos.reshape(1, 1, *cos.shape)
cos_c = cos.view(1, 1, *cos.shape)
cos_d = cos[None, None, :, :]
# 验证
assert torch.equal(cos_a, cos_b)
assert torch.equal(cos_a, cos_c)
assert torch.equal(cos_a, cos_d) # 全部相等
unsqueeze 的优势: 语义最清晰------"在某个位置插入一个维度",不需要知道其他维度的具体大小。
八、unsqueeze 的逆操作:squeeze
| 操作 | 功能 | 示例 |
|---|---|---|
unsqueeze(dim) |
在 dim 处插入大小为 1 的维度 | (L,d) → (1,L,d) |
squeeze(dim) |
移除 dim 处大小为 1 的维度 | (1,L,d) → (L,d) |
squeeze() |
移除所有大小为 1 的维度 | (1,1,L,d) → (L,d) |
python
x = torch.zeros(1, 1, 4096, 64)
x.squeeze(0).shape # (1, 4096, 64) 只移除 dim0
x.squeeze().shape # (4096, 64) 移除所有大小为1的维度
九、unsqueeze 在深度学习中的常见场景
场景 1:单样本推理时添加 batch 维度
python
# 模型期望输入: (B, C, H, W)
# 单张图片: (C, H, W) = (3, 224, 224)
image = load_image() # (3, 224, 224)
batch = image.unsqueeze(0) # (1, 3, 224, 224) ← 添加batch维
output = model(batch)
场景 2:计算两两距离矩阵
python
# 计算 N 个点之间的距离
points = torch.randn(N, 3) # N 个三维点
# 需要 (N, 1, 3) - (1, N, 3) → 广播得到 (N, N, 3)
diff = points.unsqueeze(1) - points.unsqueeze(0)
dist = diff.norm(dim=-1) # (N, N) 距离矩阵
场景 3:通道注意力(SE-Net)
python
# 特征图: (B, C, H, W)
# 通道权重: (B, C)
weights = channel_attention(x) # (B, C)
weights = weights.unsqueeze(-1).unsqueeze(-1) # (B, C, 1, 1)
x = x * weights # 广播到 (B, C, H, W)
场景 4:RoPE 位置编码(本文场景)
python
cos = angles.cos() # (L, d)
cos = cos.unsqueeze(0).unsqueeze(0) # (1, 1, L, d)
# 广播到 (B, H, L, d),与 Q/K 逐元素相乘
十、完整回顾 RoPE 中的 unsqueeze
python
def precompute_freqs(dim, max_seq_len, base=10000):
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).float()
angles = torch.outer(t, freqs) # (L, d)
cos = angles.cos() # (L, d)
sin = angles.sin() # (L, d)
# (L, d)
# ↓ unsqueeze(0): 插入 batch 维
# (1, L, d)
# ↓ unsqueeze(0): 插入 head 维
# (1, 1, L, d)
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
return cos, sin
# 使用时:
# x: (B, H, L, d) = (2, 32, 4096, 64) ← Q 或 K
# cos: (1, 1, L, d) = (1, 1, 4096, 64)
#
# x * cos → 广播机制自动将 cos 复制到每个 batch 和 head
#
# 含义:位置编码与"是哪个样本""是哪个头"无关
# 只与"在哪个位置""是哪个维度组"有关
十一、总结
| 要点 | 说明 |
|---|---|
| 功能 | 在指定位置插入一个大小为 1 的维度 |
| 数据 | 不改变任何数据,只改变形状 |
| 目的 | 为广播(broadcasting)做准备 |
| 语法 | tensor.unsqueeze(dim) 或 tensor[None] |
| 逆操作 | squeeze(dim) 移除大小为 1 的维度 |
| RoPE 中 | 将 (L, d) 扩展为 (1, 1, L, d),使 cos/sin 能广播到 (B, H, L, d) |
一句话总结:
unsqueeze不改变数据,只是给张量"升维",让它能通过广播机制与更高维度的张量进行运算------在 RoPE 中,它让位置编码能无缝应用到每个 batch 的每个注意力头上。
后记
2026年5月15日于上海,在claude opus 4.6辅助下完成。