【pytorch】torch.nn.Unfold操作

说明

一个代码里涉及到了unfold的操作,看了半天官网都没整明白维度怎么变化的,参考这个链接搞明白了:

https://blog.csdn.net/ViatorSun/article/details/119940759

https://zhuanlan.zhihu.com/p/361140988

维度计算

输入( N, C, H, W)

输出(N,C×∏(kernel_size),L)

L 是通过卷积核 滑动裁剪 后,得到的区块的数量。

C×∏(kernel_size)是怎么来的?

从第一个参考链接的图可以看到,就是窗口滑动的时候,把窗口同一个位置的值放在了一个通道,所以窗口有多少个像素,就变成了多少个通道。

用参考博客的代码做一个示例:

python 复制代码
inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold  = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())
print(patches)

输出结果

python 复制代码
torch.Size([1, 2, 4, 4])
tensor([[[[ 0.4448, -1.8525, -1.8243, -1.0243],
          [ 0.0224, -1.2402, -0.7154, -1.2538],
          [-0.6515, -0.6022,  0.2263, -1.6286],
          [ 0.2067,  0.8257, -1.9318,  1.0372]],

         [[ 2.4799, -0.5248, -0.3170,  1.5934],
          [-0.3643,  1.1624, -1.5762, -0.1827],
          [-0.0553,  0.1629, -1.3280, -0.8468],
          [ 0.0671,  1.6328,  1.1706,  1.7891]]]])
torch.Size([1, 8, 4])
tensor([[[ 0.4448, -1.8243, -0.6515,  0.2263],
         [-1.8525, -1.0243, -0.6022, -1.6286],
         [ 0.0224, -0.7154,  0.2067, -1.9318],
         [-1.2402, -1.2538,  0.8257,  1.0372],
         [ 2.4799, -0.3170, -0.0553, -1.3280],
         [-0.5248,  1.5934,  0.1629, -0.8468],
         [-0.3643, -1.5762,  0.0671,  1.1706],
         [ 1.1624, -0.1827,  1.6328,  1.7891]]])

用两个窗口的情况来举例,每个位置对应的结果情况如下:

相关推荐
闲人编程5 分钟前
用Python分析你的Spotify/网易云音乐听歌数据
开发语言·python·ai·数据分析·spotify·网易云·codecapsule
“负拾捌”20 分钟前
LangChain 中 ChatPromptTemplate 的几种使用方式
python·langchain·prompt
thorn_r23 分钟前
MCP驱动的AI角色扮演游戏
人工智能·游戏·机器学习·ai·自然语言处理·agent·mcp
得贤招聘官24 分钟前
智能招聘革新:破解校招低效困局的核心方案
人工智能
乌恩大侠38 分钟前
【Spark】操作记录
人工智能·spark·usrp
一水鉴天42 分钟前
整体设计 全面梳理复盘 之27 九宫格框文法 Type 0~Ⅲ型文法和 bnf/abnf/ebnf 之1
人工智能·状态模式·公共逻辑
极客BIM工作室1 小时前
GAN vs. VAE:生成对抗网络 vs. 变分自编码机
人工智能·神经网络·生成对抗网络
咋吃都不胖lyh1 小时前
小白零基础教程:安装 Conda + VSCode 配置 Python 开发环境
人工智能·python·conda
minhuan1 小时前
构建AI智能体:八十九、Encoder-only与Decoder-only模型架构:基于ModelScope小模型的实践解析
人工智能·模型架构·encoder-only架构·decoder-only架构
rit84324991 小时前
基于MATLAB的PCA+SVM人脸识别系统实现
人工智能·算法