在 PyTorch 中,unfold
是张量的一种操作方法,用于将张量的一个维度展开为多个重叠的窗口。这在处理滑动窗口操作(如卷积、时间序列分片等)时非常有用。
方法签名
Tensor.unfold(dimension, size, step)
参数说明
dimension
: 要展开的维度索引(从 0 开始)。size
: 每个窗口的大小。step
: 每个窗口之间的步幅。
返回值
返回一个新的张量,其中指定的维度被替换为一个新的维度,该维度包含了从原始维度中抽取的滑动窗口。
用法示例
示例 1: 基本用法
import torch
# 创建一个一维张量
x = torch.arange(10)
print("原始张量:", x)
# 使用 unfold 创建滑动窗口
y = x.unfold(dimension=0, size=3, step=2)
print("展开后的张量:\n", y)
输出:
原始张量: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
展开后的张量:
tensor([[0, 1, 2],
[2, 3, 4],
[4, 5, 6],
[6, 7, 8]])
解释:
- 在第 0 维(行)上展开,窗口大小为 3,步幅为 2。
- 展开的每个窗口是连续的 3 个元素,相邻窗口之间跳过 2 个元素。
示例 2: 二维张量的展开
# 创建一个二维张量
x = torch.arange(1, 17).view(4, 4)
print("原始张量:\n", x)
# 在维度 1 上展开
y = x.unfold(dimension=1, size=2, step=1)
print("展开后的张量:\n", y)
输出:
原始张量:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
展开后的张量:
tensor([[[ 1, 2],
[ 2, 3],
[ 3, 4]],
[[ 5, 6],
[ 6, 7],
[ 7, 8]],
[[ 9, 10],
[10, 11],
[11, 12]],
[[13, 14],
[14, 15],
[15, 16]]])
解释:
- 在第 1 维(列)上展开,窗口大小为 2,步幅为 1。
- 每一行的每个窗口是连续的 2 个元素,相邻窗口之间跳过 1 个元素。
示例 3: 高维张量的展开
对于高维张量,unfold
操作也类似,只需指定正确的维度即可。例如:
x = torch.randn(2, 3, 4) # 一个 3 维张量
y = x.unfold(dimension=2, size=2, step=1)
print("高维张量展开后形状:", y.shape)
输出:
高维张量展开后形状: torch.Size([2, 3, 3, 2])
注意事项
- 边界问题:如果剩余的元素不足以填满窗口,则这些元素会被忽略。
- 效率 :
unfold
是一个视图操作,不会拷贝数据,但其返回值是一个新的张量,需要额外注意内存管理。
常见应用
- 图像处理中的滑动窗口操作。
- 自定义卷积操作或池化操作。
- 时间序列数据的分片处理。