(超详细)7-YOLOV5改进-添加 CoTAttention注意力机制

1、在yolov5/models下面新建一个CoTAttention.py文件,在里面放入下面的代码

代码如下:

bash 复制代码
import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

2、找到yolo.py文件,进行更改内容

在29行加一个from models.CoTAttention import CoTAttention, 保存即可

3、找到自己想要更改的yaml文件,我选择的yolov5s.yaml文件(你可以根据自己需求进行选择),将刚刚写好的模块CoTAttention加入到yolov5s.yaml里面,并更改一些内容。更改如下

4、在yolo.py里面加入两行代码(335-337)

保存即可!

运行

相关推荐
玄同7656 分钟前
Python 自动发送邮件实战:用 QQ/163 邮箱发送大模型生成的内容
开发语言·人工智能·python·深度学习·机器学习·邮件·邮箱
逸俊晨晖6 分钟前
NVIDIA 4090的8路1080p实时YOLOv8目标检测
人工智能·yolo·目标检测·nvidia
玄同76513 分钟前
机器学习中的三大距离度量:欧式距离、曼哈顿距离、切比雪夫距离详解
人工智能·深度学习·神经网络·目标检测·机器学习·自然语言处理·数据挖掘
听麟21 分钟前
HarmonyOS 6.0+ APP AR文旅导览系统开发实战:空间定位与文物交互落地
人工智能·深度学习·华为·ar·wpf·harmonyos
盼小辉丶41 分钟前
Transformer实战——微调多语言Transformer模型
深度学习·语言模型·transformer
Tadas-Gao42 分钟前
深度学习与机器学习的知识路径:从必要基石到独立范式
人工智能·深度学习·机器学习·架构·大模型·llm
不懒不懒1 小时前
【从零开始:PyTorch实现MNIST手写数字识别全流程解析】
人工智能·pytorch·python
机器学习之心1 小时前
基于GRU门控循环单元的轴承剩余寿命预测MATLAB实现
深度学习·matlab·gru·轴承剩余寿命预测
算法狗21 小时前
大模型面试题:1B的模型和1T的数据大概要训练多久
人工智能·深度学习·机器学习·语言模型
工程师老罗1 小时前
YOLOv1数据增强
人工智能·yolo