深入理解 PyTorch 中的 unsqueeze 操作

深入理解 PyTorch 中的 unsqueeze 操作

一、引言

在 RoPE 旋转位置编码的预计算代码中,有这样两行:

python 复制代码
cos = cos.unsqueeze(0).unsqueeze(0)  # (L, d) → (1, 1, L, d)
sin = sin.unsqueeze(0).unsqueeze(0)  # (L, d) → (1, 1, L, d)

unsqueeze 是 PyTorch 中一个非常基础但极其重要的操作。本文将从基本概念到实际应用,彻底讲清楚它。


二、一句话定义

unsqueeze(dim) 在指定位置插入一个大小为 1 的新维度。

不改变任何数据 ,只改变张量的形状(shape)


三、基础用法

3.1 一维 → 二维

python 复制代码
import torch

x = torch.tensor([1, 2, 3])
print(x.shape)  # torch.Size([3])
python 复制代码
# 在第 0 维插入
a = x.unsqueeze(0)
print(a.shape)   # torch.Size([1, 3])
print(a)         # tensor([[1, 2, 3]])

# 在第 1 维插入
b = x.unsqueeze(1)
print(b.shape)   # torch.Size([3, 1])
print(b)         # tensor([[1],
                 #         [2],
                 #         [3]])

直觉理解:

复制代码
原始 x:          [1, 2, 3]           shape: (3,)

unsqueeze(0):    [[1, 2, 3]]         shape: (1, 3)   → 外面套了一层"行"
unsqueeze(1):    [[1], [2], [3]]     shape: (3, 1)   → 每个元素变成一"列"

3.2 二维 → 三维

python 复制代码
x = torch.tensor([[1, 2],
                   [3, 4]])
print(x.shape)  # torch.Size([2, 2])
python 复制代码
x.unsqueeze(0).shape  # torch.Size([1, 2, 2])  → 最前面加一维
x.unsqueeze(1).shape  # torch.Size([2, 1, 2])  → 中间加一维
x.unsqueeze(2).shape  # torch.Size([2, 2, 1])  → 最后面加一维

可视化维度插入位置:

复制代码
原始 shape: (2, 2)

dim=0 插入:  (①, 2, 2) → (1, 2, 2)
dim=1 插入:  (2, ①, 2) → (2, 1, 2)
dim=2 插入:  (2, 2, ①) → (2, 2, 1)

① 表示新插入的维度,大小为 1

3.3 负数索引

unsqueeze 也支持负数索引,-1 表示最后一个位置:

python 复制代码
x = torch.tensor([1, 2, 3])  # shape: (3,)

x.unsqueeze(-1).shape   # torch.Size([3, 1])   等价于 unsqueeze(1)
x.unsqueeze(-2).shape   # torch.Size([1, 3])   等价于 unsqueeze(0)

规则: 对于结果张量的维度,dim=-1 指最后一维,dim=-2 指倒数第二维,以此类推。


四、连续调用 unsqueeze

回到 RoPE 的代码:

python 复制代码
cos = cos.unsqueeze(0).unsqueeze(0)

逐步追踪形状变化:

复制代码
cos 初始形状:        (L, d)         例如 (4096, 64)
                       ↓
cos.unsqueeze(0):    (1, L, d)      例如 (1, 4096, 64)
                       ↓
.unsqueeze(0):       (1, 1, L, d)   例如 (1, 1, 4096, 64)

第一次 unsqueeze(0): 在最前面加一维
     ┌───┐
     │ 1 │  L    d
     └───┘
     ↑新增

第二次 unsqueeze(0): 再在最前面加一维
     ┌───┬───┐
     │ 1 │ 1 │  L    d
     └───┴───┘
     ↑新增

等价的一步写法:

python 复制代码
cos = cos.unsqueeze(0).unsqueeze(0)

# 等价于
cos = cos.reshape(1, 1, L, d)

# 也等价于
cos = cos[None, None, :, :]

五、为什么要 unsqueeze?------ 广播机制

5.1 问题背景

apply_rotary_emb 中:

python 复制代码
def apply_rotary_emb(x, cos, sin):
    # x 形状:   (B, H, L, d)   例如 (2, 32, 4096, 64)
    # cos 形状: (L, d)          例如 (4096, 64)  ← 只有2维!
  
    x1 = x[..., :d]
    y1 = x1 * cos   # ← 这里会报错或结果不对!

x1 是 4 维 (B, H, L, d)cos 是 2 维 (L, d)维度数量不匹配

虽然 PyTorch 的广播机制可以从右往左自动对齐,但这可能导致语义错误。显式 unsqueeze 更安全、更清晰。

5.2 广播机制回顾

PyTorch 广播的核心规则:

从最后一维开始向前对齐,每一维要么相等,要么其中一个为 1。大小为 1 的维度会被"复制"扩展。

复制代码
x1:  (B, H, L, d) = (2, 32, 4096, 64)
cos: (1, 1, L, d) = (1,  1, 4096, 64)   ← unsqueeze 后

对齐过程:

复制代码
维度:    dim0    dim1    dim2    dim3
x1:       2      32     4096     64
cos:      1       1     4096     64
          ↓       ↓       ↓       ↓
结果:     2      32     4096     64     ✅ 每维要么相等,要么为1

cos 的 dim0 和 dim1 大小为 1,会自动广播(复制)到 2 和 32:

复制代码
cos 广播过程:

(1, 1, 4096, 64)
     ↓ 沿 dim0 复制 2 份
(2, 1, 4096, 64)
     ↓ 沿 dim1 复制 32 份
(2, 32, 4096, 64)   ← 与 x1 形状一致,可以逐元素相乘

5.3 广播的物理意义

复制代码
x 的 4 个维度:  (B,    H,    L,    d)
                 ↑     ↑     ↑     ↑
               批次  注意力头  位置  维度组

cos 的含义: 每个"位置"的每个"维度组"有一个 cos 值
            与"批次"和"注意力头"无关 → 这两维设为 1,广播共享

所有 batch、所有 head 共享同一套 cos/sin,因为位置编码只取决于位置和维度,与具体内容无关。


六、如果不 unsqueeze 会怎样?

情况 1:PyTorch 自动广播(2维 vs 4维)

python 复制代码
cos_2d = cos  # shape: (4096, 64)
x1     = ...  # shape: (2, 32, 4096, 64)

result = x1 * cos_2d  # PyTorch 会自动从右对齐

PyTorch 的自动对齐过程:

复制代码
x1:      (2, 32, 4096, 64)
cos_2d:              (4096, 64)    ← 从右对齐
         ↓
等价于:  ( 1,  1, 4096, 64)        ← 自动在前面补 1

结果:在这个特定例子中碰巧是对的! 但这依赖隐式行为,不够清晰。

情况 2:隐式广播出错的例子

假设不小心把 cos 的形状搞成 (4096, 1)

python 复制代码
cos_wrong = cos[:, 0:1]  # shape: (4096, 1)
x1 = ...                 # shape: (2, 32, 4096, 64)

result = x1 * cos_wrong  # 不会报错!但语义完全错误
复制代码
x1:        (2, 32, 4096, 64)
cos_wrong:          (4096,  1)   ← 从右对齐
等价于:    (1,  1, 4096,  1)
广播结果:  (2, 32, 4096, 64)    ← 形状对了,但值全错了!

显式 unsqueeze 让代码意图一目了然,减少这类隐蔽的 bug。


七、unsqueeze vs 其他等价操作

方法 代码 结果形状
unsqueeze cos.unsqueeze(0).unsqueeze(0) (1,1,L,d)
reshape cos.reshape(1, 1, L, d) (1,1,L,d)
view cos.view(1, 1, L, d) (1,1,L,d)
None 索引 cos[None, None, :, :] (1,1,L,d)
expand_dims(NumPy风格) 不直接支持 ---
python 复制代码
# 这四种写法完全等价
cos_a = cos.unsqueeze(0).unsqueeze(0)
cos_b = cos.reshape(1, 1, *cos.shape)
cos_c = cos.view(1, 1, *cos.shape)
cos_d = cos[None, None, :, :]

# 验证
assert torch.equal(cos_a, cos_b)
assert torch.equal(cos_a, cos_c)
assert torch.equal(cos_a, cos_d)  # 全部相等

unsqueeze 的优势: 语义最清晰------"在某个位置插入一个维度",不需要知道其他维度的具体大小。


八、unsqueeze 的逆操作:squeeze

操作 功能 示例
unsqueeze(dim) 在 dim 处插入大小为 1 的维度 (L,d)(1,L,d)
squeeze(dim) 移除 dim 处大小为 1 的维度 (1,L,d)(L,d)
squeeze() 移除所有大小为 1 的维度 (1,1,L,d)(L,d)
python 复制代码
x = torch.zeros(1, 1, 4096, 64)

x.squeeze(0).shape     # (1, 4096, 64)   只移除 dim0
x.squeeze().shape      # (4096, 64)       移除所有大小为1的维度

九、unsqueeze 在深度学习中的常见场景

场景 1:单样本推理时添加 batch 维度

python 复制代码
# 模型期望输入: (B, C, H, W)
# 单张图片:     (C, H, W) = (3, 224, 224)

image = load_image()              # (3, 224, 224)
batch = image.unsqueeze(0)        # (1, 3, 224, 224)  ← 添加batch维
output = model(batch)

场景 2:计算两两距离矩阵

python 复制代码
# 计算 N 个点之间的距离
points = torch.randn(N, 3)       # N 个三维点

# 需要 (N, 1, 3) - (1, N, 3) → 广播得到 (N, N, 3)
diff = points.unsqueeze(1) - points.unsqueeze(0)
dist = diff.norm(dim=-1)          # (N, N) 距离矩阵

场景 3:通道注意力(SE-Net)

python 复制代码
# 特征图: (B, C, H, W)
# 通道权重: (B, C)

weights = channel_attention(x)    # (B, C)
weights = weights.unsqueeze(-1).unsqueeze(-1)  # (B, C, 1, 1)
x = x * weights                  # 广播到 (B, C, H, W)

场景 4:RoPE 位置编码(本文场景)

python 复制代码
cos = angles.cos()                              # (L, d)
cos = cos.unsqueeze(0).unsqueeze(0)             # (1, 1, L, d)
# 广播到 (B, H, L, d),与 Q/K 逐元素相乘

十、完整回顾 RoPE 中的 unsqueeze

python 复制代码
def precompute_freqs(dim, max_seq_len, base=10000):
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    angles = torch.outer(t, freqs)    # (L, d)
  
    cos = angles.cos()                # (L, d)
    sin = angles.sin()                # (L, d)
  
    #          (L, d)
    #            ↓  unsqueeze(0): 插入 batch 维
    #         (1, L, d)
    #            ↓  unsqueeze(0): 插入 head 维
    #        (1, 1, L, d)
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
  
    return cos, sin

# 使用时:
# x:   (B, H, L, d) = (2, 32, 4096, 64)   ← Q 或 K
# cos: (1, 1, L, d) = (1,  1, 4096, 64)
# 
# x * cos → 广播机制自动将 cos 复制到每个 batch 和 head
#
# 含义:位置编码与"是哪个样本""是哪个头"无关
#       只与"在哪个位置""是哪个维度组"有关

十一、总结

要点 说明
功能 在指定位置插入一个大小为 1 的维度
数据 不改变任何数据,只改变形状
目的 为广播(broadcasting)做准备
语法 tensor.unsqueeze(dim)tensor[None]
逆操作 squeeze(dim) 移除大小为 1 的维度
RoPE 中 (L, d) 扩展为 (1, 1, L, d),使 cos/sin 能广播到 (B, H, L, d)

一句话总结:unsqueeze 不改变数据,只是给张量"升维",让它能通过广播机制与更高维度的张量进行运算------在 RoPE 中,它让位置编码能无缝应用到每个 batch 的每个注意力头上。

后记

2026年5月15日于上海,在claude opus 4.6辅助下完成。

相关推荐
FreakStudio1 小时前
硬件版【Cursor】?aily blockly IDE尝鲜封神,实战硬伤尽显
python·单片机·嵌入式·大学生·面向对象·并行计算·电子diy·电子计算机
秦歌6662 小时前
DeepAgents框架详解和文件后端
人工智能·langchain
测试员周周3 小时前
【Appium 系列】第06节-页面对象实现 — LoginPage 实战
开发语言·前端·人工智能·python·功能测试·appium·测试用例
霸道流氓气质3 小时前
基于 Milvus Lite 的 Spring AI RAG 向量库实践方案与示例
人工智能·spring·milvus
ar01233 小时前
AR巡检平台:构筑智能巡检新模式的数字化引擎
人工智能·ar
语音之家3 小时前
【预讲会征集】ACL 2026 论文预讲会
人工智能·论文·acl
碳基硅坊3 小时前
电商场景下的商品自动识别与辅助上架
人工智能
2301_783848654 小时前
优化文本分类中堆叠模型的网格搜索性能:避免训练卡顿的实战指南
jvm·数据库·python