(超详细)7-YOLOV5改进-添加 CoTAttention注意力机制

1、在yolov5/models下面新建一个CoTAttention.py文件,在里面放入下面的代码

代码如下:

bash 复制代码
import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

2、找到yolo.py文件,进行更改内容

在29行加一个from models.CoTAttention import CoTAttention, 保存即可

3、找到自己想要更改的yaml文件,我选择的yolov5s.yaml文件(你可以根据自己需求进行选择),将刚刚写好的模块CoTAttention加入到yolov5s.yaml里面,并更改一些内容。更改如下

4、在yolo.py里面加入两行代码(335-337)

保存即可!

运行

相关推荐
羊小猪~~22 分钟前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
软工菜鸡1 小时前
预训练语言模型BERT——PaddleNLP中的预训练模型
大数据·人工智能·深度学习·算法·语言模型·自然语言处理·bert
deephub2 小时前
Tokenformer:基于参数标记化的高效可扩展Transformer架构
人工智能·python·深度学习·架构·transformer
___Dream2 小时前
【CTFN】基于耦合翻译融合网络的多模态情感分析的层次学习
人工智能·深度学习·机器学习·transformer·人机交互
极客代码2 小时前
【Python TensorFlow】入门到精通
开发语言·人工智能·python·深度学习·tensorflow
王哈哈^_^3 小时前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
写代码的小阿帆4 小时前
pytorch实现深度神经网络DNN与卷积神经网络CNN
pytorch·cnn·dnn
是瑶瑶子啦4 小时前
【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)
论文阅读·人工智能·深度学习·视觉检测·空间变换
wangyue45 小时前
c# 深度模型入门
深度学习
川石课堂软件测试5 小时前
性能测试|docker容器下搭建JMeter+Grafana+Influxdb监控可视化平台
运维·javascript·深度学习·jmeter·docker·容器·grafana