(超详细)7-YOLOV5改进-添加 CoTAttention注意力机制

1、在yolov5/models下面新建一个CoTAttention.py文件,在里面放入下面的代码

代码如下:

bash 复制代码
import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

2、找到yolo.py文件,进行更改内容

在29行加一个from models.CoTAttention import CoTAttention, 保存即可

3、找到自己想要更改的yaml文件,我选择的yolov5s.yaml文件(你可以根据自己需求进行选择),将刚刚写好的模块CoTAttention加入到yolov5s.yaml里面,并更改一些内容。更改如下

4、在yolo.py里面加入两行代码(335-337)

保存即可!

运行

相关推荐
ZCXZ12385296a1 小时前
YOLOv26在水果图像识别与分类中的应用:苹果、猕猴桃、橙子和红毛丹的检测研究
yolo·分类·数据挖掘
林深现海2 小时前
【刘二大人】PyTorch深度学习实践笔记 —— 第一集:深度学习全景概述(超详细版)
pytorch·笔记·深度学习
szxinmai主板定制专家3 小时前
基于 PC 的控制技术+ethercat+linux实时系统,助力追踪标签规模化生产,支持国产化
arm开发·人工智能·嵌入式硬件·yolo·fpga开发
阿狸OKay3 小时前
einops 库和 PyTorch 的 einsum 的语法
人工智能·pytorch·python
莱茶荼菜4 小时前
yolo26 阅读笔记
人工智能·笔记·深度学习·ai·yolo26
Dingdangcat864 小时前
【YOLOv8改进实战】使用Ghost模块优化P2结构提升涂胶缺陷检测精度_1
人工智能·yolo·目标跟踪
阿正的梦工坊7 小时前
Megatron中--train-iters和--max_epochs两个参数介绍
人工智能·深度学习·自然语言处理
哥布林学者7 小时前
吴恩达深度学习课程五:自然语言处理 第三周:序列模型与注意力机制(四)语音识别和触发字检测
深度学习·ai
青瓷程序设计9 小时前
【交通标志识别系统】python+深度学习+算法模型+Resnet算法+人工智能+2026计算机毕设项目
人工智能·python·深度学习