PyTorch 中 unfold 的理解笔记

1. unfold 是什么

在 PyTorch 中,unfold 的作用是:沿着指定维度,用滑动窗口的方式,把数据一段一段地取出来。

它不会对数据做卷积、求和、池化等计算,只是负责把局部区域提取出来。

基本格式:

复制代码
x.unfold(dimension, size, step)

参数含义:

复制代码
dimension:在哪个维度上切
size:每个窗口的大小
step:每次滑动的步长

可以理解为:

复制代码
沿着某个维度,
每次取 size 个元素,
然后移动 step 个位置,
继续取下一段。

2. 一维例子理解 unfold

假设有一个一维张量:

复制代码
x = torch.tensor([1, 2, 3, 4, 5, 6])

它的形状是

复制代码
y = x.unfold(0, 3, 1)

含义是:

复制代码
在第 0 个维度上,
每次取 3 个数,
每次移动 1 个位置。

结果是:

复制代码
tensor([
    [1, 2, 3],
    [2, 3, 4],
    [3, 4, 5],
    [4, 5, 6]
])

输出形状是:

复制代码
[4, 3]

其中:

复制代码
4:表示一共切出了 4 个窗口
3:表示每个窗口内部有 3 个元素

所以,原来的长度 6 被拆成了:

复制代码
4 个窗口 × 每个窗口长度 3

3. size 和 step 的关系

size 表示窗口大小,step 表示滑动步长。

它们的关系决定了切出来的窗口是否重叠。

3.1 step < size:窗口之间有重叠

例如:

复制代码
x.unfold(0, 3, 1)

表示窗口大小是 3,步长是 1。

窗口形式是:

复制代码
[1, 2, 3]
   [2, 3, 4]
      [3, 4, 5]
         [4, 5, 6]

这种方式常用于滑动窗口特征提取。


3.2 step == size:窗口之间不重叠

例如:

复制代码
x.unfold(0, 3, 3)

结果是:

复制代码
tensor([
    [1, 2, 3],
    [4, 5, 6]
])

窗口形式是:

复制代码
[1, 2, 3] [4, 5, 6]

这种方式就是无重叠切分。

在图像切 patch 时经常这样用:

复制代码
x.unfold(1, patch_size, patch_size)

因为窗口大小和步长相等,所以每个 patch 之间不重叠。


3.3 step > size:窗口之间有间隔

例如:

复制代码
x.unfold(0, 2, 3)

结果是:

复制代码
tensor([
    [1, 2],
    [4, 5]
])

窗口形式是:

复制代码
[1, 2]    [4, 5]

中间的元素 3 被跳过了。

这种情况说明步长太大,会导致部分数据没有被使用。


4. unfold 能切出多少个窗口

假设原始长度是:

复制代码
L

窗口大小是:

复制代码
size

步长是:

复制代码
step

那么切出来的窗口数量是:

复制代码
floor((L - size) / step) + 1

例如:

复制代码
L = 6
size = 3
step = 1

窗口数量是:

复制代码
floor((6 - 3) / 1) + 1 = 4

所以可以切出 4 个窗口。

如果:

复制代码
L = 6
size = 3
step = 2

窗口数量是:

复制代码
floor((6 - 3) / 2) + 1 = 2

所以只能切出 2 个窗口。

注意:unfold 不会自动补边。如果剩余的数据不够一个完整窗口,就会被舍弃。


5. unfold 为什么会多出一维

unfold 的本质是把原来的某个维度拆成两个维度:

复制代码
原来的维度
↓
窗口数量维度 + 窗口内部大小维度

例如原来有一个长度为 6 的维度

执行:

复制代码
x.unfold(0, 3, 1)

之后变成:

复制代码
[4, 3]

其中:

复制代码
4:窗口数量
3:窗口内部大小

也就是说:

复制代码
长度 6
被拆成:
4 个窗口,每个窗口长度 3

6. 为什么多出来的一维在最后

PyTorch 的 Tensor.unfold() 设计规则是:

复制代码
窗口数量维度留在原来的位置;
窗口内部大小维度追加到最后。

也就是说,对某个维度做 unfold 后:

复制代码
原来的这个维度位置,用来表示切出了多少个窗口;
新增的最后一维,用来表示每个窗口里面有多少个元素。

可以简单记成:

复制代码
原位置放"窗口个数",最后一维放"窗口大小"。

7. 图像中的 unfold

图像张量通常是:

复制代码
[B, C, H, W]

含义是:

复制代码
B:batch size,图片数量
C:通道数
H:图像高度
W:图像宽度

如果要把图像切成 patch,就需要沿着高度和宽度两个方向分别切分。

假设输入图像是:

复制代码
x.shape = [B, C, H, W]

先执行:

复制代码
x = x.permute(0, 2, 3, 1)

形状从:

复制代码
[B, C, H, W]

变成:

复制代码
[B, H, W, C]

这样做是为了把高度 H 和宽度 W 放到中间,方便使用 unfold 沿高度和宽度方向切块。


8. 第一次 unfold:沿高度方向切

代码:

复制代码
x = x.unfold(1, patch_size, patch_size)

此时张量形状是:

复制代码
[B, H, W, C]

第 1 个维度是高度 H,所以这一步是在高度方向切块。

如果:

复制代码
H = 32
patch_size = 16

那么高度方向可以切出:

复制代码
32 / 16 = 2

个窗口。

形状变化为:

复制代码
[B, H, W, C]
↓
[B, H_num, W, C, patch_H]

例如:

复制代码
[B, 32, 32, C]
↓
[B, 2, 32, C, 16]

其中:

复制代码
H_num = 2:高度方向切出了 2 个窗口
patch_H = 16:每个窗口内部高度为 16

9. 第二次 unfold:沿宽度方向切

继续执行:

复制代码
x = x.unfold(2, patch_size, patch_size)

此时张量形状大致是:

复制代码
[B, H_num, W, C, patch_H]

第 2 个维度是宽度 W,所以这一步是在宽度方向切块。

如果:

复制代码
W = 32
patch_size = 16

宽度方向也可以切出:

复制代码
32 / 16 = 2

个窗口。

形状变化为:

复制代码
[B, H_num, W, C, patch_H]
↓
[B, H_num, W_num, C, patch_H, patch_W]

例如:

复制代码
[B, 2, 32, C, 16]
↓
[B, 2, 2, C, 16, 16]

其中:

复制代码
H_num:高度方向 patch 数量
W_num:宽度方向 patch 数量
C:通道数
patch_H:每个 patch 的高度
patch_W:每个 patch 的宽度

10. 完整图像切 patch 的维度变化

以这段代码为例:

复制代码
x = x.permute(0, 2, 3, 1).unfold(1, self.patch_size, self.patch_size)\
    .unfold(2, self.patch_size, self.patch_size).contiguous()\
    .view(B, -1, C, self.patch_size, self.patch_size)

可以拆开理解:

复制代码
x = x.permute(0, 2, 3, 1)

形状变化:

复制代码
[B, C, H, W] → [B, H, W, C]

然后:

复制代码
x = x.unfold(1, patch_size, patch_size)

形状变化:

复制代码
[B, H, W, C] → [B, H_num, W, C, patch_H]

再然后:

复制代码
x = x.unfold(2, patch_size, patch_size)

形状变化:

复制代码
[B, H_num, W, C, patch_H] → [B, H_num, W_num, C, patch_H, patch_W]

最后:

复制代码
x = x.contiguous().view(B, -1, C, patch_size, patch_size)

形状变化:

复制代码
[B, H_num, W_num, C, patch_H, patch_W]
↓
[B, N, C, patch_size, patch_size]

其中:

复制代码
N = H_num × W_num

也就是每张图片被切出来的 patch 数量。


11. 具体例子:32×32 图片切成 16×16 patch

假设输入:

复制代码
x.shape = [1, 3, 32, 32]

表示:

复制代码
1 张图片
3 个通道
图片大小为 32 × 32

设置:

复制代码
patch_size = 16

第一步:

复制代码
x = x.permute(0, 2, 3, 1)

形状变为:

复制代码
[1, 32, 32, 3]

第二步:

复制代码
x = x.unfold(1, 16, 16)

沿高度方向切。

高度 32,每块 16,步长 16,所以高度方向切出 2 个窗口。

形状变为:

复制代码
[1, 2, 32, 3, 16]

第三步:

复制代码
x = x.unfold(2, 16, 16)

沿宽度方向切。

宽度 32,每块 16,步长 16,所以宽度方向切出 2 个窗口。

形状变为:

复制代码
[1, 2, 2, 3, 16, 16]

第四步:

复制代码
x = x.contiguous().view(1, -1, 3, 16, 16)

形状变为:

复制代码
[1, 4, 3, 16, 16]

因为:

复制代码
4 = 2 × 2

表示这张图片被切成了 4 个 patch。


12. 具体例子:224×224 图片切成 16×16 patch

假设输入:

复制代码
x.shape = [B, 3, 224, 224]

设置:

复制代码
patch_size = 16

高度方向:

复制代码
224 / 16 = 14

宽度方向:

复制代码
224 / 16 = 14

所以每张图片可以切出:

复制代码
14 × 14 = 196

个 patch。

最终输出形状是:

复制代码
[B, 196, 3, 16, 16]

含义是:

复制代码
B:图片数量
196:每张图片的 patch 数量
3:每个 patch 的通道数
16:每个 patch 的高度
16:每个 patch 的宽度

相关推荐
IT_陈寒1 小时前
Vue组件通信这个坑我跳了两次才知道怎么爬出来
前端·人工智能·后端
张哈大1 小时前
MCP:重塑AI工具调用的统一标准,告别重复造轮子的时代
人工智能·python·ai·prompt
K姐研究社1 小时前
美图设计室实测 – 输入1张商品图,AI批量生成带货视频
人工智能·aigc
HackTwoHub1 小时前
WEB扫描器Invicti-Professional-V26.50.0(自动化爬虫扫描)更新
前端·人工智能·chrome·爬虫·web安全·网络安全·自动化
李二。1 小时前
AI翻译通(鸿蒙原生)—— 鸿蒙Next声明式UI翻译工具实战
人工智能·ui·harmonyos
咖啡星人k1 小时前
用 MonkeyCode 构建全栈应用:从需求到部署的AI自动化实践
运维·人工智能·自动化
keykey6.1 小时前
PyTorch 入门实战:从张量到训练循环
开发语言·人工智能·深度学习·机器学习
智者知已应修善业1 小时前
【51单片机0.1秒计时到21.0时点亮LED】2024-1-5
c++·经验分享·笔记·算法·51单片机
AeeeSs1 小时前
web shell
笔记