1. unfold 是什么
在 PyTorch 中,unfold 的作用是:沿着指定维度,用滑动窗口的方式,把数据一段一段地取出来。
它不会对数据做卷积、求和、池化等计算,只是负责把局部区域提取出来。
基本格式:
x.unfold(dimension, size, step)
参数含义:
dimension:在哪个维度上切
size:每个窗口的大小
step:每次滑动的步长
可以理解为:
沿着某个维度,
每次取 size 个元素,
然后移动 step 个位置,
继续取下一段。
2. 一维例子理解 unfold
假设有一个一维张量:
x = torch.tensor([1, 2, 3, 4, 5, 6])
它的形状是
y = x.unfold(0, 3, 1)
含义是:
在第 0 个维度上,
每次取 3 个数,
每次移动 1 个位置。
结果是:
tensor([
[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[4, 5, 6]
])
输出形状是:
[4, 3]
其中:
4:表示一共切出了 4 个窗口
3:表示每个窗口内部有 3 个元素
所以,原来的长度 6 被拆成了:
4 个窗口 × 每个窗口长度 3
3. size 和 step 的关系
size 表示窗口大小,step 表示滑动步长。
它们的关系决定了切出来的窗口是否重叠。
3.1 step < size:窗口之间有重叠
例如:
x.unfold(0, 3, 1)
表示窗口大小是 3,步长是 1。
窗口形式是:
[1, 2, 3]
[2, 3, 4]
[3, 4, 5]
[4, 5, 6]
这种方式常用于滑动窗口特征提取。
3.2 step == size:窗口之间不重叠
例如:
x.unfold(0, 3, 3)
结果是:
tensor([
[1, 2, 3],
[4, 5, 6]
])
窗口形式是:
[1, 2, 3] [4, 5, 6]
这种方式就是无重叠切分。
在图像切 patch 时经常这样用:
x.unfold(1, patch_size, patch_size)
因为窗口大小和步长相等,所以每个 patch 之间不重叠。
3.3 step > size:窗口之间有间隔
例如:
x.unfold(0, 2, 3)
结果是:
tensor([
[1, 2],
[4, 5]
])
窗口形式是:
[1, 2] [4, 5]
中间的元素 3 被跳过了。
这种情况说明步长太大,会导致部分数据没有被使用。
4. unfold 能切出多少个窗口
假设原始长度是:
L
窗口大小是:
size
步长是:
step
那么切出来的窗口数量是:
floor((L - size) / step) + 1
例如:
L = 6
size = 3
step = 1
窗口数量是:
floor((6 - 3) / 1) + 1 = 4
所以可以切出 4 个窗口。
如果:
L = 6
size = 3
step = 2
窗口数量是:
floor((6 - 3) / 2) + 1 = 2
所以只能切出 2 个窗口。
注意:unfold 不会自动补边。如果剩余的数据不够一个完整窗口,就会被舍弃。
5. unfold 为什么会多出一维
unfold 的本质是把原来的某个维度拆成两个维度:
原来的维度
↓
窗口数量维度 + 窗口内部大小维度
例如原来有一个长度为 6 的维度
执行:
x.unfold(0, 3, 1)
之后变成:
[4, 3]
其中:
4:窗口数量
3:窗口内部大小
也就是说:
长度 6
被拆成:
4 个窗口,每个窗口长度 3
6. 为什么多出来的一维在最后
PyTorch 的 Tensor.unfold() 设计规则是:
窗口数量维度留在原来的位置;
窗口内部大小维度追加到最后。
也就是说,对某个维度做 unfold 后:
原来的这个维度位置,用来表示切出了多少个窗口;
新增的最后一维,用来表示每个窗口里面有多少个元素。
可以简单记成:
原位置放"窗口个数",最后一维放"窗口大小"。
7. 图像中的 unfold
图像张量通常是:
[B, C, H, W]
含义是:
B:batch size,图片数量
C:通道数
H:图像高度
W:图像宽度
如果要把图像切成 patch,就需要沿着高度和宽度两个方向分别切分。
假设输入图像是:
x.shape = [B, C, H, W]
先执行:
x = x.permute(0, 2, 3, 1)
形状从:
[B, C, H, W]
变成:
[B, H, W, C]
这样做是为了把高度 H 和宽度 W 放到中间,方便使用 unfold 沿高度和宽度方向切块。
8. 第一次 unfold:沿高度方向切
代码:
x = x.unfold(1, patch_size, patch_size)
此时张量形状是:
[B, H, W, C]
第 1 个维度是高度 H,所以这一步是在高度方向切块。
如果:
H = 32
patch_size = 16
那么高度方向可以切出:
32 / 16 = 2
个窗口。
形状变化为:
[B, H, W, C]
↓
[B, H_num, W, C, patch_H]
例如:
[B, 32, 32, C]
↓
[B, 2, 32, C, 16]
其中:
H_num = 2:高度方向切出了 2 个窗口
patch_H = 16:每个窗口内部高度为 16
9. 第二次 unfold:沿宽度方向切
继续执行:
x = x.unfold(2, patch_size, patch_size)
此时张量形状大致是:
[B, H_num, W, C, patch_H]
第 2 个维度是宽度 W,所以这一步是在宽度方向切块。
如果:
W = 32
patch_size = 16
宽度方向也可以切出:
32 / 16 = 2
个窗口。
形状变化为:
[B, H_num, W, C, patch_H]
↓
[B, H_num, W_num, C, patch_H, patch_W]
例如:
[B, 2, 32, C, 16]
↓
[B, 2, 2, C, 16, 16]
其中:
H_num:高度方向 patch 数量
W_num:宽度方向 patch 数量
C:通道数
patch_H:每个 patch 的高度
patch_W:每个 patch 的宽度
10. 完整图像切 patch 的维度变化
以这段代码为例:
x = x.permute(0, 2, 3, 1).unfold(1, self.patch_size, self.patch_size)\
.unfold(2, self.patch_size, self.patch_size).contiguous()\
.view(B, -1, C, self.patch_size, self.patch_size)
可以拆开理解:
x = x.permute(0, 2, 3, 1)
形状变化:
[B, C, H, W] → [B, H, W, C]
然后:
x = x.unfold(1, patch_size, patch_size)
形状变化:
[B, H, W, C] → [B, H_num, W, C, patch_H]
再然后:
x = x.unfold(2, patch_size, patch_size)
形状变化:
[B, H_num, W, C, patch_H] → [B, H_num, W_num, C, patch_H, patch_W]
最后:
x = x.contiguous().view(B, -1, C, patch_size, patch_size)
形状变化:
[B, H_num, W_num, C, patch_H, patch_W]
↓
[B, N, C, patch_size, patch_size]
其中:
N = H_num × W_num
也就是每张图片被切出来的 patch 数量。
11. 具体例子:32×32 图片切成 16×16 patch
假设输入:
x.shape = [1, 3, 32, 32]
表示:
1 张图片
3 个通道
图片大小为 32 × 32
设置:
patch_size = 16
第一步:
x = x.permute(0, 2, 3, 1)
形状变为:
[1, 32, 32, 3]
第二步:
x = x.unfold(1, 16, 16)
沿高度方向切。
高度 32,每块 16,步长 16,所以高度方向切出 2 个窗口。
形状变为:
[1, 2, 32, 3, 16]
第三步:
x = x.unfold(2, 16, 16)
沿宽度方向切。
宽度 32,每块 16,步长 16,所以宽度方向切出 2 个窗口。
形状变为:
[1, 2, 2, 3, 16, 16]
第四步:
x = x.contiguous().view(1, -1, 3, 16, 16)
形状变为:
[1, 4, 3, 16, 16]
因为:
4 = 2 × 2
表示这张图片被切成了 4 个 patch。
12. 具体例子:224×224 图片切成 16×16 patch
假设输入:
x.shape = [B, 3, 224, 224]
设置:
patch_size = 16
高度方向:
224 / 16 = 14
宽度方向:
224 / 16 = 14
所以每张图片可以切出:
14 × 14 = 196
个 patch。
最终输出形状是:
[B, 196, 3, 16, 16]
含义是:
B:图片数量
196:每张图片的 patch 数量
3:每个 patch 的通道数
16:每个 patch 的高度
16:每个 patch 的宽度