YOLOv5怎么做改进?大佬手把手带你在YOLOv5中添加4种注意力机制,训练自己的数据集!

在YOLOv5中添加注意力机制可显著提升模型性能,以下是4种主流注意力机制的改进方法及训练数据集的详细步骤:

一、添加4种注意力机制

1. CBAM(卷积块注意力模块)

原理 :结合通道注意力和空间注意力,增强特征提取能力。
实现步骤

  1. 修改common.py:添加CBAM模块代码。

    python 复制代码
    class CBAM(nn.Module):
        def __init__(self, channels, reduction=16):
            super(CBAM, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
            self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
            self.relu = nn.ReLU(inplace=True)
            self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
            self.sigmoid_channel = nn.Sigmoid()
            self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
            self.sigmoid_spatial = nn.Sigmoid()
    
        def forward(self, x):
            # 通道注意力
            avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
            max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
            channel_att = self.sigmoid_channel(avg_out + max_out)
            x = x * channel_att
    
            # 空间注意力
            avg_pool = torch.mean(x, dim=1, keepdim=True)
            max_pool = torch.max(x, dim=1, keepdim=True)[0]
            concat = torch.cat([avg_pool, max_pool], dim=1)
            spatial_att = self.sigmoid_spatial(self.conv1(concat))
            x = x * spatial_att
            return x
  2. 修改yolo.py :在parse_model函数中注册CBAM模块。

  3. 修改配置文件 :在yolov5s.yaml中指定CBAM的插入位置(如Backbone的C3模块后)。

  4. 调整超参数:根据实验调整CBAM的缩放因子和卷积核大小。

2. SE(Squeeze-and-Excitation)

原理 :通过全局平均池化学习通道权重,增强重要通道特征。
实现步骤

  1. 修改common.py:添加SE模块代码。

    python 复制代码
    class SELayer(nn.Module):
        def __init__(self, channel, r=16):
            super(SELayer, self).__init__()
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.l1 = nn.Linear(channel, channel // r, bias=False)
            self.relu = nn.ReLU(inplace=True)
            self.l2 = nn.Linear(channel // r, channel, bias=False)
            self.sig = nn.Sigmoid()
    
        def forward(self, x):
            b, c, _, _ = x.size()
            y = self.avgpool(x).view(b, c)
            y = self.l2(self.relu(self.l1(y)))
            y = self.sig(y).view(b, c, 1, 1)
            return x * y.expand_as(x)
  2. 修改yolo.py:注册SE模块并调整模型结构。

  3. 修改配置文件 :在yolov5s.yaml中插入SE模块(如替换C3模块中的部分卷积层)。

3. ECA(高效通道注意力)

原理 :通过1D卷积实现跨通道交互,避免维度缩减。
实现步骤

  1. 修改common.py:添加ECA模块代码。

    python 复制代码
    class ECA(nn.Module):
        def __init__(self, channel, k_size=3):
            super(ECA, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            y = self.avg_pool(x)
            y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
            y = self.sigmoid(y)
            return x * y.expand_as(x)
  2. 修改yolo.py:注册ECA模块并调整模型结构。

  3. 修改配置文件 :在yolov5s.yaml中插入ECA模块(如Backbone的SPPF层前)。

4. CA(坐标注意力)

原理 :通过全局平均池化捕捉宽度和高度方向的特征,增强位置感知能力。
实现步骤

  1. 修改common.py:添加CA模块代码。

    python 复制代码
    class CoordAtt(nn.Module):
        def __init__(self, inp, oup, reduction=32):
            super(CoordAtt, self).__init__()
            self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
            self.pool_w = nn.AdaptiveAvgPool2d((1, None))
            mip = max(8, inp // reduction)
            self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
            self.bn1 = nn.BatchNorm2d(mip)
            self.act = h_swish()
            self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
            self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
    
        def forward(self, x):
            identity = x
            n, c, h, w = x.size()
            x_h = self.pool_h(x)
            x_w = self.pool_w(x).permute(0, 1, 3, 2)
            y = torch.cat([x_h, x_w], dim=2)
            y = self.conv1(y)
            y = self.bn1(y)
            y = self.act(y)
            x_h, x_w = torch.split(y, [h, w], dim=2)
            x_w = x_w.permute(0, 1, 3, 2)
            a_h = self.conv_h(x_h).sigmoid()
            a_w = self.conv_w(x_w).sigmoid()
            out = identity * a_w * a_h
            return out
  2. 修改yolo.py:注册CA模块并调整模型结构。

  3. 修改配置文件 :在yolov5s.yaml中插入CA模块(如Backbone的C3模块后)。

二、训练自己的数据集

1. 准备数据集

  • 标注工具:使用LabelImg或精灵标注助手标注数据,生成VOC格式的XML文件。

  • 数据集结构

    bash 复制代码
    data/
    ├── images/       # 存放图片
    ├── labels/       # 存放标注文件(.txt格式)
    ├── train.txt     # 训练集路径
    ├── val.txt       # 验证集路径
    └── data.yaml     # 数据集配置文件
  • 生成标注文件 :使用voc_label.py将XML转换为YOLO格式的TXT文件。

  • 配置data.yaml

    yaml 复制代码
    train: ./train.txt
    val: ./val.txt
    nc: 3  # 类别数
    names: ['class1', 'class2', 'class3']  # 类别名称

2. 训练模型

  • 下载预训练模型 :从YOLOv5官方GitHub下载对应版本的预训练权重(如yolov5s.pt)。

  • 修改训练命令

    bash 复制代码
    python train.py --img 640 --batch 32 --epochs 300 --data data/data.yaml --cfg models/yolov5s.yaml --weights weights/yolov5s.pt --device 0
    • --img:输入图片尺寸。
    • --batch:批量大小。
    • --epochs:训练轮数。
    • --data:数据集配置文件路径。
    • --cfg:模型配置文件路径。
    • --weights:预训练权重路径。
    • --device:指定GPU设备。

3. 模型测试与推理

  • 测试模型

    bash 复制代码
    python test.py --data data/data.yaml --weights runs/train/exp/weights/best.pt --augment
    • --weights:使用训练好的最佳模型权重。
    • --augment:启用数据增强。
  • 模型推理

    bash 复制代码
    python detect.py --weights runs/train/exp/weights/best.pt --source inference/images/ --device 0
    • --source:指定测试图片文件夹路径。
    • 推理结果保存在inference/output文件夹中。
相关推荐
程序员蜗牛3 小时前
微信登录之OpenID与UnionID获取全流程解析
后端
SimonKing3 小时前
SpringBoot多模板引擎整合难题?一篇搞定JSP、Freemarker与Thymeleaf!
java·后端·程序员
rannn_1113 小时前
【LeetCode hot100|Week4】链表
后端·算法·leetcode·链表
SYC_MORE4 小时前
多线程环境下处理Flask上下文问题的文档
后端·python·flask
Craaaayon4 小时前
【数据结构】二叉树-图解深度优先搜索(递归法、迭代法)
java·数据结构·后端·算法·leetcode·深度优先
ChinaRainbowSea4 小时前
5. Prompt 提示词
java·人工智能·后端·spring·prompt·ai编程
IT_陈寒4 小时前
Vue3性能优化实战:这5个技巧让我的应用加载速度提升70%
前端·人工智能·后端
Apifox4 小时前
Apifox 9 月更新| AI 生成接口测试用例、在线文档调试能力全面升级、内置更多 HTTP 状态码、支持将目录转换为模块
前端·后端·测试
databook5 小时前
Manim实现闪电特效
后端·python·动效