PyTorch 的维度变形一站式入门

目录

[🧩 先给你一个总的路线图(你先不用背,只用看懂结构)](#🧩 先给你一个总的路线图(你先不用背,只用看懂结构))

[🎯 第一部分:形状变换类操作(reshape/view/squeeze/unsqueeze)](#🎯 第一部分:形状变换类操作(reshape/view/squeeze/unsqueeze))

[① reshape ------ 任意改变形状(必要时复制)](#① reshape —— 任意改变形状(必要时复制))

[② view ------ 更严格的 reshape(内存必须连续)](#② view —— 更严格的 reshape(内存必须连续))

[③ unsqueeze ------ 增加一个维度](#③ unsqueeze —— 增加一个维度)

[④ squeeze ------ 去掉大小为 1 的维度](#④ squeeze —— 去掉大小为 1 的维度)

[🚀 第二部分:广播扩展类操作](#🚀 第二部分:广播扩展类操作)

[🎯 1. expand ------ 广播的核心工具](#🎯 1. expand —— 广播的核心工具)

[🎯 1.1 expand 中的 -1 是什么意思?](#🎯 1.1 expand 中的 -1 是什么意思?)

[🎯 2. expand_as ------ 让 A 扩展成与 B 的 shape 相同](#🎯 2. expand_as —— 让 A 扩展成与 B 的 shape 相同)

[🎯 3. broadcast_to ------ 与 NumPy 同名,效果同 expand](#🎯 3. broadcast_to —— 与 NumPy 同名,效果同 expand)

[🎯 4. 为什么广播要配合 unsqueeze 一起使用?](#🎯 4. 为什么广播要配合 unsqueeze 一起使用?)

[✔ 小测试(重要)](#✔ 小测试(重要))

[【判断 1】](#【判断 1】)

[【判断 2】](#【判断 2】)

[🌌 第三部分:维度重排(transpose / permute)](#🌌 第三部分:维度重排(transpose / permute))

[Part 3-1:transpose 基础直觉](#Part 3-1:transpose 基础直觉)

[📌案例 1:交换 seq 和 hidden](#📌案例 1:交换 seq 和 hidden)

[Part 3-2:permute ------ 彻底自由](#Part 3-2:permute —— 彻底自由)

[Part 3-3:真实应用案例(非常关键)](#Part 3-3:真实应用案例(非常关键))

🚀案例:多头注意力的标准维度变化

[Step 1:线性层后维度不变](#Step 1:线性层后维度不变)

[Step 2:reshape 把 hidden 拆成 (num_heads, head_dim)](#Step 2:reshape 把 hidden 拆成 (num_heads, head_dim))

[Step 3:permute,让头排到 batch 后面](#Step 3:permute,让头排到 batch 后面)

[Part 3-4:一个直觉测试(你试试看)](#Part 3-4:一个直觉测试(你试试看))

[Part 3-5:mT:矩阵转置(专门给二维矩阵用的)](#Part 3-5:mT:矩阵转置(专门给二维矩阵用的))

[Part 3-6:mH:矩阵共轭转置(矩阵 Hermitian transpose)](#Part 3-6:mH:矩阵共轭转置(矩阵 Hermitian transpose))

[🔥 小练习(你来做)](#🔥 小练习(你来做))

[🌟 第四部分:拼接与复制(cat、stack、repeat、tile)](#🌟 第四部分:拼接与复制(cat、stack、repeat、tile))

[🔹 1. cat:按某个维度拼接(不增加新维度)](#🔹 1. cat:按某个维度拼接(不增加新维度))

[🔹 2. stack:把多个张量叠成一个"新维度"](#🔹 2. stack:把多个张量叠成一个“新维度”)

[🔹 3. repeat:把张量按次数重复(复制数据)](#🔹 3. repeat:把张量按次数重复(复制数据))

[🔹 4. tile:与 repeat 类似,但更 numpy 风格](#🔹 4. tile:与 repeat 类似,但更 numpy 风格)

[🎯 小练习(由你来做)](#🎯 小练习(由你来做))

第五部分:高级张量操作(不常用但很有用)。

[1、flatten ------ 把多维压成一维](#1、flatten —— 把多维压成一维)

[2、unflatten ------ 和 flatten 反向,给一维数据重新分块](#2、unflatten —— 和 flatten 反向,给一维数据重新分块)

[3、swapdims / swapaxes ------ 交换两个维度](#3、swapdims / swapaxes —— 交换两个维度)

[4、moveaxis ------ 一次把一个维度移动到目标位置](#4、moveaxis —— 一次把一个维度移动到目标位置)

小总结(用一句话记住它们)


🧩 先给你一个总的路线图(你先不用背,只用看懂结构)

PyTorch 的维度变形大致分成 5 类:

  1. 形状变换 (不改变元素顺序)

    reshape、view、squeeze、unsqueeze

  2. 虚拟扩展(广播)

    expand、expand_as、broadcast_to

  3. 维度重排(换顺序)

    transpose、permute、mT、mH

  4. 拼接与切片(加减维度)

    cat、stack、repeat、tile

  5. 高级手段(不常用但很有用)

    flatten、unflatten、swapdims、moveaxis

为了不把你淹没,我们先从第一个类别:形状变换讲起。等你确认后,我们再继续往下推进。


🎯 第一部分:形状变换类操作(reshape/view/squeeze/unsqueeze)

这些操作的特点是:
不改变数据内容,只改变维度结构。

你来感受一下每个的"直觉"。


① reshape ------ 任意改变形状(必要时复制)

直觉:把一根绳子卷成不同形状,但绳子顺序不变。

例子:

python 复制代码
x = torch.arange(12)      # [12]
y = x.reshape(3, 4)       # [3, 4]

问题给你想想:

如果 reshape 的新形状里有一个 -1,它的意思是什么?(-1表示的自动推断该维度大小,使总元素数量不变)


② view ------ 更严格的 reshape(内存必须连续)

直觉:

reshape 是万能插座,

view 是廉价插座,有要求才能用。(所以日常场景中更推荐使用reshape)

python 复制代码
x = torch.arange(12)
y = x.view(3, 4)

问题给你判断:

下面这种能不能用 view?

python 复制代码
x = torch.arange(12).reshape(3, 4).transpose(0, 1)
y = x.view(12)

思考下原因,我再等你回答。(答案是不行,因为交换完维度之后数值不再连续故不能使用view)


③ unsqueeze ------ 增加一个维度

直觉:

你把一排数字装进一个盒子里 → 维度 +1。

复制代码
x = torch.tensor([1, 2, 3])  # [3]
y = x.unsqueeze(0)           # [1, 3]
z = x.unsqueeze(1)           # [3, 1]

你来判断一个实际例子:

假设 tokens.shape = [batch, seq]

要把它变成 [batch, seq, 1],你应该用哪一个 unsqueeze?(答案是unsqueeze(2))


④ squeeze ------ 去掉大小为 1 的维度

直觉:

把空盒子压扁。

复制代码
x = torch.randn(1, 3, 1, 5)
y = x.squeeze()     # 去掉所有为1的维度 → [3, 5]

问题给你:

如果我只想去掉第 2 个维度(也就是 size=1 的那个),应该怎么写?(答案是squeeze(2))


🚀 第二部分:广播扩展类操作

操作包括:

  • expand

  • expand_as

  • broadcast_to(PyTorch 2.0+)

  • 相关的 unsqueeze(常与 expand 组合)

这些操作有两个核心特征:

  1. 不会复制数据,只创建"伪扩展"视图

  2. 只能扩张 size=1 的维度

简单说:

你不能凭空把 3 扩张成 5,但你可以把 1 扩成任何数字。


🎯 1. expand ------ 广播的核心工具

语法:

python 复制代码
x.expand(new_size_1, new_size_2, ...)

规则:

  • 原维度是 1 → 可以扩成任意数

  • 原维度不是 1 → 只能保持不变(写 -1)

例子:

python 复制代码
x = torch.zeros(2, 1)      # [2, 1]
y = x.expand(2, 3)         # [2, 3]
或者写为
y = x.expand(-1,3)

这里 没有 复制 3 个数据,而是让 [1] 伪装成 [3]


🎯 1.1 expand 中的 -1 是什么意思?

你的基础已经很稳了,这里加一点:

-1 表示"保持原维度不变"

例子:

复制代码
x = torch.zeros(4, 1, 5)   # [4, 1, 5]
y = x.expand(-1, 3, -1)    # [4, 3, 5]

🎯 2. expand_as ------ 让 A 扩展成与 B 的 shape 相同

例子:

复制代码
mask = torch.zeros(2, 1, 1)
target = torch.randn(2, 3, 5)

mask2 = mask.expand_as(target)   # shape → [2, 3, 5]

这是很多注意力mask的常用写法。


🎯 3. broadcast_to ------ 与 NumPy 同名,效果同 expand

等价于:

复制代码
x.broadcast_to((...))

例子:

复制代码
x = torch.tensor([1,2,3])   # [3]
y = torch.broadcast_to(x, (2,3))   # [2,3]

它比 expand 更直观,但需要 PyTorch 新版本。


🎯 4. 为什么广播要配合 unsqueeze 一起使用?

因为广播只能扩张 size = 1 的维度。

但很多时候你需要先造出这个维度,让 expand 能执行。

例子:

你有 [batch, seq],需要广播成 [batch, seq, seq] 做 attention mask。

只靠 expand 不够。

步骤是:

复制代码
mask = (seq == pad)
mask = mask.unsqueeze(1)          # [batch, 1, seq]
mask = mask.expand(-1, seq, -1)   # [batch, seq, seq]

注意:

如果不 unsqueeze,多出来的 seq 维度根本没法扩展。


✔ 小测试(重要)

你来判断这两行各会输出什么形状:


【判断 1】

python 复制代码
x = torch.randn(2, 3, 1)
y = x.expand(2, 3, 5)
y.shape = ?

答案:(2,3,5)


【判断 2】

python 复制代码
x = torch.randn(2, 4)
y = x.unsqueeze(1).expand(-1, 4, -1)
y.shape = ?

答案:(2,4,4)


那我们顺着前面的劲头,继续闯进 第三部分:维度重排(transpose / permute)

这玩意儿看似只是"换个顺序",但在深度学习里,它的地位几乎等于"维度战场的传送门"。Transformer、BERT、CNN,多头注意力......全都离不开它。

我会继续保持我们之前那种"边讲边练"的方式,让你的直觉一步步变得硬核。


🌌 第三部分:维度重排(transpose / permute)

在 PyTorch 里,主要有两个手段:

  • tensor.transpose(dim0, dim1):交换两个维度

  • tensor.permute(dims):根据你给出的列表完全重新排列全部维度

你可以把它们理解成:

  • transpose = 小规模换座位

  • permute = 整个班级重新排布


Part 3-1:transpose 基础直觉

假设我们有:

复制代码
x = torch.randn(2, 3, 5)   # (batch, seq_len, hidden_size)

你知道这个东西在 NLP 中几乎到处都是。

📌案例 1:交换 seq 和 hidden

复制代码
y = x.transpose(1, 2)
print(y.shape)

维度变化:

(2, 3 , 5 ) → (2, 5 , 3)

维度 1 和 2 换了个位置。

你可以脑补成"把句子由 3 个时间步 × 每步 5 个特征"变成"5 个特征 × 每个特征有 3 个时间步"。


Part 3-2:permute ------ 彻底自由

permute 接受一个"新顺序",你完全控制每个维度的新位置。

复制代码
x = torch.randn(2, 3, 5)
y = x.permute(2, 0, 1)
print(y.shape)

变化:

(2, 3, 5)

→ (5, 2, 3)

你甚至可以做到 transpose 做不到的三维以上的全面换位。

Transformer 里的 Q、K、V 分头,就是用 permute 完成的。


Part 3-3:真实应用案例(非常关键)

🚀案例:多头注意力的标准维度变化

你一定见过这种张量:

复制代码
x = (batch, seq_len, hidden_size)

要变成多头注意力输入:

复制代码
(batch, num_heads, seq_len, head_dim)

假设:

复制代码
batch=2, seq=3, hidden=12, num_heads=4  -> head_dim=3

我们来一步步变换:


Step 1:线性层后维度不变

复制代码
x -> (2, 3, 12)

Step 2:reshape 把 hidden 拆成 (num_heads, head_dim)

复制代码
y = x.reshape(2, 3, 4, 3)
# (batch, seq_len, num_heads, head_dim)

Step 3:permute,让头排到 batch 后面

这是注意力的核心:

复制代码
y = y.permute(0, 2, 1, 3)
print(y.shape)

结果:

复制代码
(2, 4, 3, 3)

你已经能读懂这个维度了------

2 个 batch,每个 batch 有 4 个 head,每个 head 有 3 个 token,每个 token 的特征是 3 维。

Transformer 的"多头"的物理意义就在这一步显现。


Part 3-4:一个直觉测试(你试试看)

你一定已经对维度敏感多了,我来给你一个稍微有点花的练习,让你的直觉再上一层:

假设:

复制代码
x = torch.randn(2, 4, 3, 3)  # (batch, head, seq, dim)

我想要把它变回:

复制代码
(2, 3, 12)   # (batch, seq, hidden)

其中 hidden = 4(head) * 3(dim) = 12

想一想:

你需要先 permute 还是先 reshape?

顺序弄错,结果就完全不对。

答案是:步骤:

  1. permute(0,2,1,3)

  2. reshape 为 (2,3,-1)

这是 多头注意力把 4D 张量再折回 3D 时的典型套路

Part 3-5:mT:矩阵转置(专门给二维矩阵用的)

你已经知道 permute 是"任意维度排列",transpose 是"交换两个维度"。

现在,我们扩展这两类操作的"矩阵专属版本":

  • mT

  • mH

这些是在矩阵或注意力结构中极常出现的。

当你有二维 Tensor:

python 复制代码
x.shape = (M, N)

mT 相当于:

python 复制代码
x.mT == x.transpose(-1, -2)

举例:

python 复制代码
x: (4, 8)
x.mT → (8, 4)

重点:

  • 只能用于二维(matrix)

  • 是最常用最直观的"转置"

在注意力中,K 的转置经常写成:

python 复制代码
K.mT

比 transpose(-1,-2) 更短更安全。


Part 3-6:mH:矩阵共轭转置(矩阵 Hermitian transpose)

多用于复数张量。

和 mT 很像,但是带复数共轭:

python 复制代码
x.mH == x.transpose(-1, -2).conj()

如果张量是实数,那么 mH == mT。

注意力里一般不用它(一般不涉及复数)。


🔥 小练习(你来做)

保持我们的学习节奏,我给你一个小 tensor,你来写代码表达"我要做什么变换",我检查你是否真正理解。

假设:

python 复制代码
x.shape = (batch=2, heads=4, seq=3, dim=5)

现在我想把 x 变成:

python 复制代码
(batch=2, seq=3, heads=4, dim=5)

只问一句:

你会用 permute 怎么写?

答案是permute(0,2,1,3)

你会用 transpose怎么写?

答案是transpose(1,2)

好,那我们就按既定路线进入 第四部分:拼接与切片(cat、stack、repeat、tile)

保持学习节奏:我讲一点,你验证一点,不会一股脑塞给你。


🌟 第四部分:拼接与复制(cat、stack、repeat、tile)

这一部分的关键是:
它们都在"增加维度"或"扩展维度",但方式不一样。

为了不堆砌概念,我们一个个来。


🔹 1. cat:按某个维度拼接(不增加新维度)

cat 的本质是:
把两个张量沿着某个已有维度接在一起。

例子:

python 复制代码
x: (2, 3)
y: (2, 3)
torch.cat([x, y], dim=0) → (4, 3)
torch.cat([x, y], dim=1) → (2, 6)

要点:

  • 维度数不变

  • 被拼接的维度必须对齐


🔹 2. stack:把多个张量叠成一个"新维度"

stack 的本质是:
一定会增加一个新维度。

例子:

python 复制代码
x: (3,)
y: (3,)
torch.stack([x, y], dim=0) → (2, 3)
torch.stack([x, y], dim=1) → (3, 2)

和 cat 的区别:

  • cat:不加维度

  • stack:加一个新维度


🔹 3. repeat:把张量按次数重复(复制数据)

repeat 会直接复制数据:

复制代码
x: (2, 3)
x.repeat(2, 1) → (4, 3)
x.repeat(1, 2) → (2, 6)

写 repeat(n1, n2, ...) 的意思是:
每个维度复制几次。

注意:

  • repeat 会真实复制数据(成本更高)

  • expand 只是广播,不复制


🔹 4. tile:与 repeat 类似,但更 numpy 风格

PyTorch 也有 tile,它等价于 repeat,但写法偏向 numpy:

复制代码
x: (2, 3)
x.tile(2, 1) → (4, 3)

tile 是 repeat 的更灵活版本,本质一样。


🎯 小练习(由你来做)

给定:

复制代码
a = torch.randn(2, 3)
b = torch.randn(2, 3)

我给你 3 个目标形状,你来选"应该用 cat / stack / repeat / tile 中哪个操作"。

只需要写方法名,不用写代码。

目标 1:

复制代码
(4, 3)

答案是:torch.cat([a,b],dim=0)或者a.repeat(2,1) b.repeat(2,1)

目标 2:

复制代码
(2, 2, 3)

答案是:torch.stack([a,b],dim=0)

目标 3:

复制代码
(2, 3, 3)

答案是:torch.unsqueeze(-1).repeat(1,1, 3)

第五部分:高级张量操作(不常用但很有用)


1、flatten ------ 把多维压成一维

你可以把它理解成把一块多层千层饼压成一条长面条。

简单记法:
flatten = 展平

示例:

python 复制代码
import torch
x = torch.rand(2, 3, 4)
y = torch.flatten(x)
print(y.shape)   # torch.Size([24])

你也可以只 flatten 部分维度,比如从第 1 维 flatten 到最后一维:

python 复制代码
y = torch.flatten(x, start_dim=1)
print(y.shape)   # torch.Size([2, 12])

使用场景:
做全连接层前的数据准备------几乎所有 CNN 分类器都会在最后 flatten。


2、unflatten ------ 和 flatten 反向,给一维数据重新分块

你可以把一根长面条重新折叠成带维度的形状。

示例:

python 复制代码
x = torch.rand(10)
y = torch.unflatten(x, dim=0, sizes=(2, 5))
print(y.shape)   # torch.Size([2, 5])

更自由的:

python 复制代码
x = torch.rand(24)
y = torch.unflatten(x, dim=0, sizes=(2, 3, 4))
# torch.Size([2, 3, 4])

使用场景:

模型输出维度重建、Transformer 中按 head 分组、图像 patch 重建。


3、swapdims / swapaxes ------ 交换两个维度

像交换 Rubik's cube 的两条轴,把数据沿不同方向"旋转"。

示例:

python 复制代码
x = torch.rand(2, 3, 4)
y = x.swapdims(1, 2)   # 或 swapaxes
print(y.shape)   # torch.Size([2, 4, 3])

使用场景:

序列维度、batch 维度、channel 维度之间互换。例如:

RNN 想要 (sequence, batch, features),CNN 想要 (batch, channel, height, width)。


4、moveaxis ------ 一次把一个维度移动到目标位置

swap 是"换位置",moveaxis 是"搬家"。

示例:

python 复制代码
x = torch.rand(2, 3, 4)
y = torch.moveaxis(x, 0, -1)
print(y.shape)   # torch.Size([3, 4, 2])

如果想把某几个维度移动也行:

python 复制代码
x = torch.rand(2, 3, 4, 5)
y = torch.moveaxis(x, (0, 2), (2, 0))
print(y.shape)   # torch.Size([4, 3, 2, 5])

使用场景:

图像数据格式转换(NHWC ↔ NCHW)、头维度移动、融合不同网络结构时调维度。


小总结(用一句话记住它们)

flatten:把饼压扁

unflatten:把面条折回来

swapdims:两轴互换

moveaxis:一轴搬家

相关推荐
量子位2 小时前
Nano Banana新玩法无限套娃!“GPT-5都不会处理这种级别的递归”
人工智能·gpt
m0_650108242 小时前
PaLM:Pathways 驱动的大规模语言模型 scaling 实践
论文阅读·人工智能·palm·谷歌大模型·大规模语言模型·全面评估与行为分析·scaling效应
Ma0407132 小时前
【论文阅读19】-用于PHM的大型语言模型:优化技术与应用综述
人工智能·语言模型·自然语言处理
熊猫钓鱼>_>2 小时前
从零开始构建RPG游戏战斗系统:实战心得与技术要点
开发语言·人工智能·经验分享·python·游戏·ai·qoder
CSDN官方博客2 小时前
CSDN AI社区镜像创作者征集计划正式启动,参与即可获得奖励哦~
人工智能
iMG3 小时前
当自动驾驶技术遭遇【电车难题】,专利制度如何处理?
人工智能·科技·机器学习·自动驾驶·创业创新
BoBoZz193 小时前
TriangleStrip连续三角带
python·vtk·图形渲染·图形处理
生信大表哥3 小时前
Python单细胞分析-基于leiden算法的降维聚类
linux·python·算法·生信·数信院生信服务器·生信云服务器
swanwei3 小时前
2025年11月22-23日互联网技术热点TOP3及影响分析(AI增量训练框架开源)
网络·人工智能·程序人生·安全·百度