YOLOv8修改特征金字塔(替换SPPF模块)

1.引言

1.1 引言

修改特征金字塔模块,即SPPF模块是YOLOv8改进中非常常见的一个改进点。

以下将介绍如何在yolov8中修改SPPF模型。

2.2 常见特征金字塔模块

常见特征金字塔可以看此贴:常见特征金字塔模块代码实现

1.3 本文示例

本文使用SimSPPF模块作为示例,SimSPPF模块是美团YOLOv6提出的模块,与SPPF只相差了一个激活函数,将Silu激活函数改为了Relu激活函数,相比于SPPF模块速度更快,可以尝试一下。

2. 实验

2.1 block.py修改

以下是SimSPPF模块代码

python 复制代码
class SimConv(nn.Module):
    '''Normal Conv with ReLU activation'''
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))

class SimSPPF(nn.Module):
    '''Simplified SPPF with ReLU activation'''
    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        c_ = in_channels // 2  # hidden channels
        self.cv1 = SimConv(in_channels, c_, 1, 1)
        self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            y1 = self.m(x)
            y2 = self.m(y1)
            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))

放置到ultralytics/nn/modules/block.py代码中的最后,如下图。

然后再次文件开头__all__中添加SimSPPF

2.2 __ init__.py

修改此路径下的ultralytics/nn/modules/__ init__.py文件

如下图所示,添加相应的代码:

2.3 tasks.py

修改此路径下的ultralytics/nn/tasks.py文件

因为SimSPPF和SPPF属于同一种结构,因此,我们写到SPPF后面即可。

另外需要导包,快捷键alt+回车键即可。

2.4 模型更改

复制基础模型即可,将SPPF改为SimSPPF

以yolov8n。yaml为例,如下:

yaml 复制代码
 # Ultralytics YOLO 🚀, GPL-3.0 license

# Parameters
nc: 1  # number of classes
depth_multiple: 0.33  # scales module repeats
width_multiple: 0.25  # scales convolution channels

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SimSPPF, [1024, 5]]  # 9


# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

模型运行时只需要将模型修改为次路径即可。

相关推荐
爱技术的小伙子9 分钟前
【30天玩转python】函数式编程
开发语言·python
zhaoyushi0016 分钟前
python自学笔记
windows·笔记·python
小鹿( ﹡ˆoˆ﹡ )22 分钟前
Python中的“打开与关闭文件”:从入门到精通
linux·前端·python
西猫雷婶23 分钟前
python画图|多个填充区域
开发语言·python
阿利同学37 分钟前
基于opencv的车牌检测和识别系统(代码+教程)
人工智能·python·opencv·计算机视觉·车牌识别·pyqt5·联系 qq1309399183
洛阳泰山38 分钟前
Chainlit集成LlamaIndex实现知识库高级检索(BM25全文检索器)
python·django·全文检索·bm25·llamaindex·pythonchainlit
科研小白 新人上路1 小时前
基于python深度学习遥感影像地物分类与目标识别、分割实践技术
python·tensorflow·目标识别·遥感影像·地物分类·城市规划·林业测量
阿华的代码王国1 小时前
【JavaEE】——内存可见性问题
开发语言·python
醒了就刷牙1 小时前
55 循环神经网络RNN的实现_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习
码猿小菜鸡1 小时前
【小六壬占卜代码】
开发语言·人工智能·python·占卜