PyTorch nn.Conv2d 空洞卷积

torch.nn.Conv2d() 中 dilation 参数控制卷积核的间隔

dilation controls the spacing between the kernel points

  • 当 dilation=1 时, 表示卷积核没有额外的空白间距, 也就是标准卷积
  • 当 dilation>1 时, 表示空洞卷积(dilated convolution)

动画演示:

手动计算

以 2*2 的卷积核和 dilation=2 为例, 等效卷积核的大小为:

左上角区域卷积: 1 * 2 + 3 * 0 + 3 * 1 + 1 * 3 = 8, 卷积核中的空白间隔不参与运算, 当然也可以将其置为 0, 等效为 3 * 3 的卷积运算

结果:

使用 PyTorch 计算

python 复制代码
import torch
from torch import nn

data = [
    [1, 2, 3, 0],
    [0, 1, 2, 3],
    [3, 0, 1, 2],
    [2, 3, 0, 1]
]
# 单通道 4*4 图片
# minibatch=1
inp = torch.tensor(data).reshape(1, 1, 4, 4).to(torch.float32)

conv = nn.Conv2d(1, 1, kernel_size=2, dilation=2, bias=False)
conv.weight.data = torch.tensor(
    [[2, 0], [1, 3]]
).reshape(1, 1, 2, 2).to(torch.float32)

oup = conv(inp)
print(oup)

输出

python 复制代码
tensor([[[[ 8., 10.],
          [ 2.,  8.]]]], grad_fn=<ConvolutionBackward0>)

空洞卷积可以扩大感受野, 2*2 的卷积核, dilation 参数设为 2, 可以提取特征图中 3*3 的内容, 却只有 2*2 的卷积运算量

空洞卷积会丢失局部信息

相关推荐
拓朗工控6 小时前
工业视觉AI边缘计算解决方案
人工智能·深度学习·边缘计算·工控机·工业电脑·拓朗工控
YOLO数据集集合7 小时前
无人机航拍巡检数据集|城市乡镇港口工业区|高分辨率旋转目标检测|深度学习训练基准
深度学习·目标检测·无人机
【建模先锋】7 小时前
强噪声故障诊断新思路!从频域降噪到故障分类:FusADFaultClassifier 自适应谱降噪分类模型详解
人工智能·深度学习·分类·数据挖掘·信号处理·故障诊断·降噪算法
承渊政道7 小时前
【从零开始大模型开发与微调:基于PyTorch与ChatGLM】(新时代的曙光之大模型与人工智能)
人工智能·pytorch·python·深度学习·机器学习·语言模型·自然语言处理
llfjfz7 小时前
TensorFlow花卉图片分类器模型训练
tensorflow·卷积神经网络
钓了猫的鱼儿7 小时前
基于深度学习+AI的无人机麦苗目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·无人机
神州数码云基地8 小时前
DSPy + Parlant:从手动调优到自动编译的效率加速器
人工智能·深度学习·机器学习
武子康20 小时前
调查研究-151 Slack vs Jira:区别、使用指南与团队选择方法
人工智能·科技·深度学习·ai·职场和发展·jira·slack
z小猫不吃鱼1 天前
05 Transformer Encoder 详解:BERT 为什么使用 Encoder?
深度学习·bert·transformer