YOLOv8修改特征金字塔(替换SPPF模块)

1.引言

1.1 引言

修改特征金字塔模块,即SPPF模块是YOLOv8改进中非常常见的一个改进点。

以下将介绍如何在yolov8中修改SPPF模型。

2.2 常见特征金字塔模块

常见特征金字塔可以看此贴:常见特征金字塔模块代码实现

1.3 本文示例

本文使用SimSPPF模块作为示例,SimSPPF模块是美团YOLOv6提出的模块,与SPPF只相差了一个激活函数,将Silu激活函数改为了Relu激活函数,相比于SPPF模块速度更快,可以尝试一下。

2. 实验

2.1 block.py修改

以下是SimSPPF模块代码

python 复制代码
class SimConv(nn.Module):
    '''Normal Conv with ReLU activation'''
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))

class SimSPPF(nn.Module):
    '''Simplified SPPF with ReLU activation'''
    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        c_ = in_channels // 2  # hidden channels
        self.cv1 = SimConv(in_channels, c_, 1, 1)
        self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            y1 = self.m(x)
            y2 = self.m(y1)
            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))

放置到ultralytics/nn/modules/block.py代码中的最后,如下图。

然后再次文件开头__all__中添加SimSPPF

2.2 __ init__.py

修改此路径下的ultralytics/nn/modules/__ init__.py文件

如下图所示,添加相应的代码:

2.3 tasks.py

修改此路径下的ultralytics/nn/tasks.py文件

因为SimSPPF和SPPF属于同一种结构,因此,我们写到SPPF后面即可。

另外需要导包,快捷键alt+回车键即可。

2.4 模型更改

复制基础模型即可,将SPPF改为SimSPPF

以yolov8n。yaml为例,如下:

yaml 复制代码
 # Ultralytics YOLO 🚀, GPL-3.0 license

# Parameters
nc: 1  # number of classes
depth_multiple: 0.33  # scales module repeats
width_multiple: 0.25  # scales convolution channels

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SimSPPF, [1024, 5]]  # 9


# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

模型运行时只需要将模型修改为次路径即可。

相关推荐
databook1 小时前
Manim实现闪光轨迹特效
后端·python·动效
Juchecar3 小时前
解惑:NumPy 中 ndarray.ndim 到底是什么?
python
用户8356290780513 小时前
Python 删除 Excel 工作表中的空白行列
后端·python
Json_3 小时前
使用python-fastApi框架开发一个学校宿舍管理系统-前后端分离项目
后端·python·fastapi
CoovallyAIHub5 小时前
开源的消逝与新生:从 TensorFlow 的落幕到开源生态的蜕变
pytorch·深度学习·llm
数据智能老司机9 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机10 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机10 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机10 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i11 小时前
drf初步梳理
python·django