(超详细)7-YOLOV5改进-添加 CoTAttention注意力机制

1、在yolov5/models下面新建一个CoTAttention.py文件,在里面放入下面的代码

代码如下:

bash 复制代码
import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

2、找到yolo.py文件,进行更改内容

在29行加一个from models.CoTAttention import CoTAttention, 保存即可

3、找到自己想要更改的yaml文件,我选择的yolov5s.yaml文件(你可以根据自己需求进行选择),将刚刚写好的模块CoTAttention加入到yolov5s.yaml里面,并更改一些内容。更改如下

4、在yolo.py里面加入两行代码(335-337)

保存即可!

运行

相关推荐
FL16238631292 小时前
芸豆叶子病害检测数据集VOC+YOLO格式1762张3类别
yolo
王哈哈^_^3 小时前
【完整源码+数据集】课堂行为数据集,yolo课堂行为检测数据集 2090 张,学生课堂行为识别数据集,目标检测课堂行为识别系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·视觉检测·毕业设计
hongjianMa5 小时前
【论文阅读】Hypercomplex Prompt-aware Multimodal Recommendation
论文阅读·python·深度学习·机器学习·prompt·推荐系统
Sunhen_Qiletian5 小时前
YOLOv2算法详解(下篇):细节打磨与性能突破的终极密码
算法·yolo
现在,此刻6 小时前
李沐深度学习笔记D3-线性回归
笔记·深度学习·线性回归
能来帮帮蒟蒻吗6 小时前
深度学习(2)—— 神经网络与训练
人工智能·深度学习·神经网络
知行力8 小时前
【GitHub每日速递 20251111】PyTorch:GPU加速、动态网络,深度学习平台的不二之选!
pytorch·深度学习·github
ifeng09189 小时前
HarmonyOS资源加载进阶:惰性加载、预加载与缓存机制
深度学习·缓存·harmonyos
Danceful_YJ10 小时前
34.来自Transformers的双向编码器表示(BERT)
人工智能·深度学习·bert
love530love10 小时前
【笔记】xFormers版本与PyTorch、CUDA对应关系及正确安装方法详解
人工智能·pytorch·windows·笔记·python·深度学习·xformers