YOLOv5白皮书-第Y6周:模型改进

本文为365天深度学习训练营 中的学习记录博客

原作者:K同学啊|接辅导、项目定制

本次训练是在前文《YOLOv5白皮书-第Y2周:训练自己的数据集》的基础上进行的,并参考了《YOLOv5白皮书-第Y5周:yolo.py文件解读》。

任务:修改了YOLOv5s的网络结构图,请根据网络结构图以及第Y1~Y5周的内容修改对应代码,并跑通程序。

原YOLOv5s的网络结构图:

修改后的YOLOv5s的网络结构图:

通过比较上面两张图,可知,本次的任务:

把索引为4的层从C3 * 2修改为C2 * 2

把索引为6的层从C3 * 3修改为C3 * 1

去除索引为7、8的层

其实还有隐藏任务,就是去除了索引为7、8的层后,原来8层后面的索引也会改变的,这也是要注意的,特别是有concat的地方,更要注意了,这会在后面的代码修改中有解释。

一、代码修改

1、把索引为4的层从C32修改为C22

参考前文《YOLOv5白皮书-第Y5周:yolo.py文件解读》,在common.py中根据下图写出C2的代码。

就是在C3模块代码的基础上修改为C2模块,把C3模块的代码去掉concat后面的Conv,就变成了C2模块。

详见下面的代码,这里把C3模块的代码也写出来,方便和C2模块的代码比较:

python 复制代码
class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Initializes C3 module with options for channel count, bottleneck repetition, shortcut usage, group
        convolutions, and expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        """Performs forward propagation using concatenated outputs from two convolutions and a Bottleneck sequence."""
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


# 这是在C3模块的基础上修改的,就是把C3模块的self.cv3去掉,主要是把forward(self, x)中的self.cv3去掉
class C2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Initializes C2 module with options for channel count, bottleneck repetition, shortcut usage, group
        convolutions, and expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        # self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        """Performs forward propagation using concatenated outputs from two convolutions and a Bottleneck sequence."""
        return torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)

在yolo.py中,找到parse_model函数,并在这个函数里的两个地方添加C2,详见注释,如下所示:

python 复制代码
if m in {
            Conv,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            MixConv2d,
            Focus,
            CrossConv,
            BottleneckCSP,
            C3,
            C2,          #原代码中没有C2,现在添加了C2
            C3TR,
            C3SPP,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
        }:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, ch_mul)

            # 下面的代码中原来是没有C2,现在添加了C2(不是指小写的c2)
            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C2, C3TR, C3Ghost, C3x}:    #添加C2
                args.insert(2, n)  # number of repeats
                n = 1

在yolo.py的开头,还要把C2添加上去,详见注释,如下所示:

python 复制代码
from models.common import (
    C3,
    C2,     #原来没有C2,现在添加C2
    C3SPP,
    C3TR,
    SPP,
    SPPF,
    Bottleneck,
    BottleneckCSP,
    C3Ghost,
    C3x,
    Classify,
    Concat,
    Contract,
    Conv,
    CrossConv,
    DetectMultiBackend,
    DWConv,
    DWConvTranspose2d,
    Expand,
    Focus,
    GhostBottleneck,
    GhostConv,
    Proto,
)

然后在yolov5s.yaml中,在backbone把索引为4的层从C3 * 2修改为C2 * 2 ,因为yolov5s.yaml已经写明depth_multiple为0.33,0.33 * 6 约等于 2,所以第4层为:[-1, 6, C2, [256]]

如下所示:

python 复制代码
# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]],     # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]],       # 1-P2/4
    [-1, 3, C3, [128]],               #2
    [-1, 1, Conv, [256, 3, 2]],       # 3-P3/8
    [-1, 6, C2, [256]],               #4  原来是C3*2,现在修改为C2*2
    [-1, 1, Conv, [512, 3, 2]],       # 5-P4/16
    [-1, 3, C3, [512]],               #6   原来是C3*3,现在修改为C3*1
    #[ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 7-P5/32,要注释掉或者删除
    #[ -1, 3, C3, [ 1024 ] ],          #8,要注释掉或者删除
    [-1, 1, SPPF, [1024, 5]],         # 9 现在是第7层 ,此层之后的层索引都发生了改变
  ]

2、把索引为6的层从C3 * 3修改为C3 * 1

因为yolov5s.yaml已经写明depth_multiple为0.33,0.33 * 3 约等于 1,所以第6层为: [-1, 3, C3, [512]],详见下面的代码:

python 复制代码
# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]],     # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]],       # 1-P2/4
    [-1, 3, C3, [128]],               #2
    [-1, 1, Conv, [256, 3, 2]],       # 3-P3/8
    [-1, 6, C2, [256]],               #4  原来是C3*2,现在修改为C2*2
    [-1, 1, Conv, [512, 3, 2]],       # 5-P4/16
    [-1, 3, C3, [512]],               #6   原来是C3*3,现在修改为C3*1
    #[ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 7-P5/32,要注释掉或者删除
    #[ -1, 3, C3, [ 1024 ] ],          #8,要注释掉或者删除
    [-1, 1, SPPF, [1024, 5]],         # 9 现在是第7层 ,此层之后的层索引都发生了改变
  ]

3、去除索引为7、8的层

在yolov5s.yaml中,在backbone中把索引为7、8的层注释就可以了,详见下面的代码:

python 复制代码
# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]],     # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]],       # 1-P2/4
    [-1, 3, C3, [128]],               #2
    [-1, 1, Conv, [256, 3, 2]],       # 3-P3/8
    [-1, 6, C2, [256]],               #4  原来是C3*2,现在修改为C2*2
    [-1, 1, Conv, [512, 3, 2]],       # 5-P4/16
    [-1, 3, C3, [512]],               #6   原来是C3*3,现在修改为C3*1
    #[ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 7-P5/32,要注释掉或者删除
    #[ -1, 3, C3, [ 1024 ] ],          #8,要注释掉或者删除
    [-1, 1, SPPF, [1024, 5]],         # 9 现在是第7层 ,此层之后的层索引都发生了改变
  ]

但把第7、8层注释掉后,后面的层的索引就变了,在yolov5s.yaml中,head的索引也受到了影响,特别是Concat中,这是用层的索引的,这就是隐藏的任务,所以head的代码修改如下,并参考注释内容:

python 复制代码
# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 3, 2]],                     #8,原10  现在是[512, 3, 2],原来是[512, 1, 1]
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],     #9,原11
    [[-1, 6], 1, Concat, [1]], # cat backbone P4     10,原12,  原12层是cat第6层的,变为10层后无变化
    [-1, 3, C3, [512, False]],                      # 现在是11,原来是13

    [-1, 1, Conv, [256, 1, 1]],                     #12,原14
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],     #13,原15
    [[-1, 4], 1, Concat, [1]], # cat backbone P3      14,原16  原16层是cat第4层的,变为14层后无变化
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)      15,原17

    [-1, 1, Conv, [256, 3, 2]],                      #16,原18
    [[-1, 12], 1, Concat, [1]], # cat head P4         17,原19  原19层是cat第14层的,但原14层变为第12层,所以要把14修改为12
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)    18,原20

    [-1, 1, Conv, [512, 3, 2]],                      #19,原21
    [[-1, 8], 1, Concat, [1]], # cat head P5        #20,原22,原22层是cat原第10层,但原第10层已经变为第8层了,所以要把10修改为8
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)   #21,原23

    [[15, 18, 21], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)    22,原24,原来是[17, 20, 23],层索引已经有变化了,要改为[15, 18, 21]
  ]

4、补充

因为训练数据的类别是["banana", "snake fruit", "dragon fruit", "pineapple"],只有4种,所以在yolov5s.yaml中还要把nc修改为4,如下所示:

python 复制代码
# Parameters
nc: 4 # number of classes,原文是80,因为要训练的数据集只有4种类别,所以把80修改为4
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple

二、代码运行

开始用自己的数据集训练模型,在项目目录中打开cmd。

如果电脑有GPU,则在cmd中输入命令:

python train.py--img 900 --batch 2 --epoch 50 --data data/ab.yaml --cfg models/yolov5s.yaml --weights weights/yolov5s.pt --device '0'

如果电脑没有GPU,则在cmd中输入命令:

python train.py --img 900 --batch 2 --epoch 50 --data data/ab.yaml --cfg models/yolov5s.yaml --weights weights/yolov5s.pt --device cpu

就可以直接训练自己的数据集啦,训练结果如下所示:


三、总结

本次训练需要修改的内容比较多,看起来有三个点,其实是四个点,在yolov5s.yaml中把backbone的索引7、8的层删除了,还要修改head的Concat的索引。

相关推荐
YINWA AI2 分钟前
胤娲科技:谷歌DeepMind祭出蛋白质设计新AI——癌症治疗迎来曙光
人工智能·科技·ai
会飞的Anthony7 分钟前
基于Python的自然语言处理系列(14):TorchText + biGRU + Attention + Teacher Forcing
人工智能·自然语言处理
jun7788959 分钟前
机器学习-监督学习:朴素贝叶斯分类器
人工智能·学习·机器学习
FL162386312910 分钟前
基于yolov5的混凝土缺陷检测系统python源码+onnx模型+评估指标曲线+精美GUI界面
人工智能·python·yolo
Kenneth風车14 分钟前
【第十三章:Sentosa_DSML社区版-机器学习聚类】
人工智能·低代码·机器学习·数据分析·聚类
jndingxin21 分钟前
OpenCV运动分析和目标跟踪(4)创建汉宁窗函数createHanningWindow()的使用
人工智能·opencv·目标跟踪
机器之心23 分钟前
o1 带火的 CoT 到底行不行?新论文引发了论战
android·人工智能
机器之心29 分钟前
从架构、工艺到能效表现,全面了解 LLM 硬件加速,这篇综述就够了
android·人工智能
jndingxin1 小时前
OpenCV特征检测(1)检测图像中的线段的类LineSegmentDe()的使用
人工智能·opencv·计算机视觉
@月落1 小时前
alibaba获得店铺的所有商品 API接口
java·大数据·数据库·人工智能·学习