[YOLOv8] YOLOv8 Improvement Series (9): Replacing the Backbone with RepViT

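Homepage: HABUO

🍁 YOLOv8 Getting Started + Improvements Column 🍁

🍁 If I never get to see you again, I wish you good morning, good afternoon, and good night 🍁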


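[YOLOv8 Improvement Series]:

[YOLOv8] YOLOv8 Architecture Walkthrough

YOLOv8 Improvement Series (1): Replacing the Backbone with EfficientViT

YOLOv8 Improvement Series (2): Replacing the Backbone with FasterNet

YOLOv8 Improvement Series (3): Replacing the Backbone with ConvNeXt V2

YOLOv8 Improvement Series (4): Replacing C2f, Swapping the Bottleneck in C2f for FasterNet's FasterBlock

YOLOv8 Improvement Series (5): Replacing the Backbone with EfficientFormerV2

YOLOv8 Improvement Series (6): Replacing the Backbone with VanillaNet

YOLOv8 Improvement Series (7): Replacing the Backbone with LSKNet

YOLOv8 Improvement Series (8): Replacing the Backbone with Swin Transformer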


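Table of Contents

💯 I. Introduction to RepViT

[1. Overview](#1. Overview)

[2. RepViT Architecture Design](#2. RepViT Architecture Design)

Background

Method

[3. Experiments and Results](#3. Experiments and Results)

[4. Key Conclusions](#4. Key Conclusions)

💯 II. How to Add It Step by Step

Step ①: Create repvit.py

Step ②: Modify tasks.py

(1) Import the newly created repvit file

(2) Modify the _predict_once function

(3) Modify the parse_model function

Step ③: Modify the yolov8.yaml file

Step ④: Verify that the module was added successfully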


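💯 I. Introduction to RepViT

1. Overview

The paper introduces RepViT, a new lightweight convolutional neural network (CNN) architecture. It revisits the design of conventional lightweight CNNs from the perspective of Vision Transformers (ViTs) and improves on it, achieving high accuracy at low latency on mobile devices. The paper reports that RepViT outperforms existing lightweight ViTs and CNNs across several vision tasks, including ImageNet classification, object detection, instance segmentation, and semantic segmentation.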

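2. RepViT Architecture Design

Background

  • The development of lightweight CNNs and ViTs: Over the past decade, lightweight CNNs such as MobileNets and ShuffleNets have been widely deployed on mobile devices. As ViTs began to outperform CNNs on large-scale image recognition, researchers started exploring lightweight ViT designs in search of a better accuracy-efficiency trade-off on mobile hardware.

  • Strengths and challenges of lightweight ViTs: By incorporating designs such as multi-head self-attention (MHSA), lightweight ViTs have surpassed lightweight CNNs in accuracy. Deploying them on mobile devices remains difficult, however, because of limited hardware support and sensitivity to high-resolution inputs.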

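Method

Starting from a lightweight CNN (MobileNetV3), the authors incrementally incorporate the efficient architectural designs of lightweight ViTs, arriving at the RepViT architecture. The process consists of the following key steps:

  1. Block Design

    • Separating the token mixer and the channel mixer: Following the MetaFormer structure of ViTs, the depthwise (DW) convolution in the MobileNetV3 block is moved up so that the token mixer and the channel mixer become separate components. Structural re-parameterization is also introduced: a multi-branch structure is used during training and merged into a single branch for inference, which benefits both training and inference efficiency.

    • Adjusting the expansion ratio and network width: The expansion ratio of the channel mixer is reduced to 2 (MobileNetV3 uses ratios of up to 6), and the network width is increased accordingly, reducing computation while improving accuracy.

  2. Macro Design

    • Early convolutions as the stem: MobileNetV3's complex stem is replaced with two stride-2 3×3 convolutions, which lowers latency and improves accuracy.

    • Deeper downsampling layers: Separate downsampling layers, as in ViTs, are adopted, and downsampling is improved by making these layers deeper and by using RepViT blocks.

    • Simpler classifier: MobileNetV3's complex classifier is replaced with a global average pooling layer followed by a linear layer to reduce latency.

    • Adjusting the stage ratio: The ratio of block counts across stages is tuned for a better accuracy-latency trade-off.

  3. Micro Design

    • Kernel size selection: 3×3 convolutions are preferred to keep inference efficient on mobile devices.

    • Squeeze-and-excitation (SE) layer placement: SE layers are placed in a cross-block manner, that is, in alternating blocks rather than in every block, to maximize the accuracy gain without hurting latency.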
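To make the structural re-parameterization idea from the block design step concrete, here is a minimal pure-Python sketch for a single channel of a depthwise token mixer. It assumes three training-time branches (a 3×3 depthwise convolution, a 1×1 depthwise convolution, and an identity shortcut) followed by BatchNorm in inference mode. The helper names `conv3x3`, `merge_branches`, and `fold_bn` and all numeric values are illustrative assumptions, not taken from the paper or from any library.

```python
import math

def conv3x3(img, kernel, bias=0.0):
    """Single-channel 3x3 cross-correlation with zero padding 1 and stride 1."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = bias
            for di in range(3):
                for dj in range(3):
                    y, x = i + di - 1, j + dj - 1
                    if 0 <= y < h and 0 <= x < w:
                        acc += kernel[di][dj] * img[y][x]
            out[i][j] = acc
    return out

def merge_branches(k3, b3, k1, b1):
    """Merge 3x3 conv + 1x1 conv + identity into one 3x3 kernel and one bias."""
    merged = [row[:] for row in k3]
    merged[1][1] += k1 + 1.0  # the 1x1 weight and the identity both act on the centre tap
    return merged, b3 + b1

def fold_bn(kernel, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold an inference-mode BatchNorm into the preceding convolution."""
    scale = gamma / math.sqrt(var + eps)
    folded_kernel = [[v * scale for v in row] for row in kernel]
    folded_bias = beta + (bias - mean) * scale
    return folded_kernel, folded_bias

# Made-up parameters for one channel
img = [[1.0, 2.0, 0.5], [0.0, -1.0, 3.0], [2.0, 1.5, -0.5]]
k3 = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1], [0.2, 0.1, -0.3]]
b3, k1, b1 = 0.05, 0.7, -0.02
gamma, beta, mean, var = 1.2, 0.1, 0.3, 0.8

def multi_branch(img):
    """Training-time structure: sum the three branches, then apply BatchNorm."""
    c3 = conv3x3(img, k3, b3)
    scale = gamma / math.sqrt(var + 1e-5)
    return [[(c3[i][j] + k1 * img[i][j] + b1 + img[i][j] - mean) * scale + beta
             for j in range(3)] for i in range(3)]

# Inference-time structure: one 3x3 convolution with bias
kernel, bias = fold_bn(*merge_branches(k3, b3, k1, b1), gamma, beta, mean, var)
fused = conv3x3(img, kernel, bias)
reference = multi_branch(img)
max_diff = max(abs(fused[i][j] - reference[i][j]) for i in range(3) for j in range(3))
print(max_diff < 1e-9)
```

Both paths produce the same output up to floating-point rounding, which is why the extra branches cost nothing at inference time.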


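3. Experiments and Results

  • Image classification: On ImageNet-1K, RepViT shows a favorable accuracy-latency trade-off at every model size. For example, RepViT-M1.0 reaches over 80% top-1 accuracy at only 1.0 ms latency on an iPhone 12, which the paper describes as a first for a lightweight model.

  • RepViT combined with SAM: Using RepViT as the image encoder of the Segment Anything Model (SAM) yields RepViT-SAM. On an iPhone 12, RepViT-SAM runs inference successfully, whereas competitors such as MobileSAM and ViT-B-SAM fail to run. On a MacBook M1 Pro, RepViT-SAM is nearly 10× faster than MobileSAM at inference, and it outperforms both MobileSAM and ViT-B-SAM on zero-shot edge detection, zero-shot instance segmentation, and the Segmentation in the Wild (SegInW) benchmark.

  • Downstream tasks: RepViT also shows a favorable accuracy-latency trade-off on object detection and instance segmentation (with the Mask R-CNN framework) and on semantic segmentation (with the Semantic FPN framework). For example, on MS COCO 2017, RepViT-M1.1 achieves higher APbox and APmask than EfficientFormer-L1 at lower latency; on ADE20K, its mIoU is 1.7 points higher than that of EfficientFormer-L1 while also running faster.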


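4. Key Conclusions

By integrating the efficient architectural designs of lightweight ViTs, RepViT raises the performance of lightweight CNNs to a new level and delivers a favorable accuracy-latency trade-off on mobile devices. It outperforms existing lightweight ViTs and CNNs on image classification and also transfers well to downstream tasks such as object detection, instance segmentation, and semantic segmentation. In addition, RepViT-SAM shows strong efficiency and zero-shot transfer performance within the SAM framework, providing a strong baseline for deploying SAM on edge devices. The authors hope RepViT will inspire further research on lightweight models for edge deployment.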


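💯 II. How to Add It Step by Step

Step ①: Create repvit.py

Create repvit.py under ultralytics/nn/backbone/ (the location implied by the import in Step ②), then copy and paste the following code into it: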

import torch.nn as nn
import numpy as np
from timm.models.layers import SqueezeExcite
import torch

__all__ = ['repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']

def replace_batchnorm(net):
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse_self'):
            fused = child.fuse_self()
            setattr(net, child_name, fused)
            replace_batchnorm(fused)
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class Conv2d_BN(torch.nn.Sequential):
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
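    def fuse_self(self):
        # Fold the BatchNorm statistics into the convolution so that inference runs a single Conv2d with bias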
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

class Residual(torch.nn.Module):
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)
    
    @torch.no_grad()
    def fuse_self(self):
        if isinstance(self.m, Conv2d_BN):
            m = self.m.fuse_self()
            assert(m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1,1,1,1])
            m.weight += identity.to(m.weight.device)
            return m
        elif isinstance(self.m, torch.nn.Conv2d):
            m = self.m
            assert(m.groups != m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1,1,1,1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self

class RepVGGDW(torch.nn.Module):
    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
        self.bn = torch.nn.BatchNorm2d(ed)
    
    def forward(self, x):
        return self.bn((self.conv(x) + self.conv1(x)) + x)
    
    @torch.no_grad()
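    def fuse_self(self):
        # Merge the 3x3 branch, the 1x1 branch and the identity shortcut into one 3x3 depthwise conv,
        # then fold the trailing BatchNorm into that conv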
        conv = self.conv.fuse_self()
        conv1 = self.conv1
        
        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias
        
        conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])

        identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)

        bn = self.bn
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = conv.weight * w[:, None, None, None]
        b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        conv.weight.data.copy_(w)
        conv.bias.data.copy_(b)
        return conv

class RepViTBlock(nn.Module):
    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(RepViTBlock, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup
        assert(hidden_dim == 2 * inp)

        if stride == 2:
            self.token_mixer = nn.Sequential(
                Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
                Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(oup, 2 * oup, 1, 1, 0),
                    nn.GELU(),  # GELU is used regardless of use_hs
                    # pw-linear
                    Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
                ))
        else:
            assert(self.identity)
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(inp, hidden_dim, 1, 1, 0),
                    nn.GELU(),  # GELU is used regardless of use_hs
                    # pw-linear
                    Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
                ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))

class RepViT(nn.Module):
    def __init__(self, cfgs):
        super(RepViT, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs

        # building first layer
        input_channel = self.cfgs[0][2]
        patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
                           Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
        layers = [patch_embed]
        # building inverted residual blocks
        block = RepViTBlock
        for k, t, c, use_se, use_hs, s in self.cfgs:
            output_channel = _make_divisible(c, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features = nn.ModuleList(layers)
        # Run one dummy forward pass to record the channel count of each returned feature map
        self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
        
    def forward(self, x):
        input_size = x.size(2)
        # Keep the last feature map produced at each of the strides 4, 8, 16 and 32
        scale = [4, 8, 16, 32]
        features = [None, None, None, None]
        for f in self.features:
            x = f(x)
            if input_size // x.size(2) in scale:
                features[scale.index(input_size // x.size(2))] = x
        return features
    
    def switch_to_deploy(self):
        replace_batchnorm(self)

def update_weight(model_dict, weight_dict):
    """Copy into model_dict only the pretrained tensors whose name and shape both match."""
    idx, temp_dict = 0, {}
    for k, v in weight_dict.items():
        # k = k[9:]
        if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
            temp_dict[k] = v
            idx += 1
    model_dict.update(temp_dict)
    print(f'loading weights... {idx}/{len(model_dict)} items')
    return model_dict

def repvit_m0_9(weights=''):
    """
    Constructs a RepViT-M0.9 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  48, 1, 0, 1],
        [3,   2,  48, 0, 0, 1],
        [3,   2,  48, 0, 0, 1],
        [3,   2,  96, 0, 0, 2],
        [3,   2,  96, 1, 0, 1],
        [3,   2,  96, 0, 0, 1],
        [3,   2,  96, 0, 0, 1],
        [3,   2,  192, 0, 1, 2],
        [3,   2,  192, 1, 1, 1],
        [3,   2,  192, 0, 1, 1],
        [3,   2,  192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 384, 0, 1, 2],
        [3,   2, 384, 1, 1, 1],
        [3,   2, 384, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m1_0(weights=''):
    """
    Constructs a RepViT-M1.0 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  56, 1, 0, 1],
        [3,   2,  56, 0, 0, 1],
        [3,   2,  56, 0, 0, 1],
        [3,   2,  112, 0, 0, 2],
        [3,   2,  112, 1, 0, 1],
        [3,   2,  112, 0, 0, 1],
        [3,   2,  112, 0, 0, 1],
        [3,   2,  224, 0, 1, 2],
        [3,   2,  224, 1, 1, 1],
        [3,   2,  224, 0, 1, 1],
        [3,   2,  224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 448, 0, 1, 2],
        [3,   2, 448, 1, 1, 1],
        [3,   2, 448, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m1_1(weights=''):
    """
    Constructs a RepViT-M1.1 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  64, 1, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  128, 0, 0, 2],
        [3,   2,  128, 1, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  256, 0, 1, 2],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 512, 0, 1, 2],
        [3,   2, 512, 1, 1, 1],
        [3,   2, 512, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m1_5(weights=''):
    """
    Constructs a RepViT-M1.5 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  64, 1, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  64, 1, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  128, 0, 0, 2],
        [3,   2,  128, 1, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  128, 1, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  256, 0, 1, 2],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 512, 0, 1, 2],
        [3,   2, 512, 1, 1, 1],
        [3,   2, 512, 0, 1, 1],
        [3,   2, 512, 1, 1, 1],
        [3,   2, 512, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m2_3(weights=''):
    """
    Constructs a RepViT-M2.3 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  80, 1, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  80, 1, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  80, 1, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  160, 0, 0, 2],
        [3,   2,  160, 1, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  160, 1, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  160, 1, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  320, 0, 1, 2],
        [3,   2,  320, 1, 1, 1],
        [3,   2,  320, 0, 1, 1],
        [3,   2,  320, 1, 1, 1],
        [3,   2,  320, 0, 1, 1],
        [3,   2,  320, 1, 1, 1],
        [3,   2,  320, 0, 1, 1],
        [3,   2,  320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        # [3,   2, 320, 1, 1, 1],
        # [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 640, 0, 1, 2],
        [3,   2, 640, 1, 1, 1],
        [3,   2, 640, 0, 1, 1],
        # [3,   2, 640, 1, 1, 1],
        # [3,   2, 640, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

if __name__ == '__main__':
    model = repvit_m2_3('repvit_m2_3_distill_450e.pth')
    inputs = torch.randn((1, 3, 640, 640))
    res = model(inputs)
    for i in res:
        print(i.size())
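
The `forward` method above returns one feature map per stride (4, 8, 16, 32), and `self.channel` records their channel counts, which is the information the detection head needs. The following pure-Python sketch traces that bookkeeping without PyTorch. It assumes a 640×640 input and uses a compressed `(channels, num_blocks)` summary of the `repvit_m0_9` table, with the first block of each later stage using stride 2. `conv_out` and `trace_features` are hypothetical helpers for illustration, not part of repvit.py.

```python
def conv_out(size, kernel=3, stride=1, pad=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * pad - kernel) // stride + 1

def trace_features(stages, input_size=640):
    """Return {stride: (channels, spatial_size)} for the last block seen at each stride."""
    scale = [4, 8, 16, 32]
    features = [None] * 4
    # Stem: two 3x3 stride-2 convolutions, i.e. 4x downsampling overall
    size = conv_out(conv_out(input_size, stride=2), stride=2)
    for stage_idx, (channels, num_blocks) in enumerate(stages):
        for block_idx in range(num_blocks):
            # The first block of every stage after the first downsamples by 2
            stride = 2 if (stage_idx > 0 and block_idx == 0) else 1
            size = conv_out(size, stride=stride)
            ratio = input_size // size
            if ratio in scale:
                # Later blocks at the same stride overwrite earlier ones
                features[scale.index(ratio)] = (channels, size)
    return dict(zip(scale, features))

# Stage summary read off the repvit_m0_9 cfgs table: (channels, number of blocks)
m0_9 = [(48, 3), (96, 4), (192, 16), (384, 3)]
feats = trace_features(m0_9)
channel = [c for c, _ in feats.values()]
print(feats)    # → {4: (48, 160), 8: (96, 80), 16: (192, 40), 32: (384, 20)}
print(channel)  # → [48, 96, 192, 384]
```

Under these assumptions, the list printed last corresponds to what `self.channel` would hold for `repvit_m0_9`.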

Step ②: Modify tasks.py

(1) Import the newly created repvit file

from ultralytics.nn.backbone.repvit import *

(2) Modify the _predict_once function

    def _predict_once(self, x, profile=False, visualize=False, embed=None):
        """
        Perform a forward pass through the network.
        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool):  Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.
        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt, embeddings = [], [], []  # outputs
        for idx, m in enumerate(self.model):
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
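            if hasattr(m, 'backbone'):
                # A backbone module returns a list of feature maps instead of a single tensor
                x = m(x)
                # Left-pad to five entries so the outputs line up with layer indices 0-4
                for _ in range(5 - len(x)):
                    x.insert(0, None)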
                for i_idx, i in enumerate(x):
                    if i_idx in self.save:
                        y.append(i)
                    else:
                        y.append(None)
                # print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x if x_ is not None])}')
                x = x[-1]
            else:
                x = m(x)  # run
                y.append(x if m.i in self.save else None)  # save output
            
            # if type(x) in {list, tuple}:
            #     if idx == (len(self.model) - 1):
            #         if type(x[1]) is dict:
            #             print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x[1]["one2one"]])}')
            #         else:
            #             print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x[1]])}')
            #     else:
            #         print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x if x_ is not None])}')
            # elif type(x) is dict:
            #     print(f'layer id:{idx:>2} {m.type:>50} output shape:{", ".join([str(x_.size()) for x_ in x["one2one"]])}')
            # else:
            #     if not hasattr(m, 'backbone'):
            #         print(f'layer id:{idx:>2} {m.type:>50} output shape:{x.size()}')
            
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x
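
The backbone branch is the key change in this function: a backbone module returns a list of feature maps rather than a single tensor, so `_predict_once` pads that list to five entries and stores each one in `y` as if it had come from its own layer. A minimal pure-Python sketch of that bookkeeping follows. The strings stand in for tensors, `store_backbone_outputs` is a hypothetical helper, and the `save` set is a made-up example rather than the value Ultralytics would compute for a particular YAML file.

```python
def store_backbone_outputs(outputs, save):
    """Mimic the backbone branch: pad to 5 slots, keep only the slots listed in `save`."""
    outputs = list(outputs)
    for _ in range(5 - len(outputs)):
        outputs.insert(0, None)  # left-pad so the deepest map lands at index 4
    y = [out if idx in save else None for idx, out in enumerate(outputs)]
    return y, outputs[-1]        # cached outputs, and the tensor passed to the next layer

backbone_out = ["P2 (stride 4)", "P3 (stride 8)", "P4 (stride 16)", "P5 (stride 32)"]
y, x = store_backbone_outputs(backbone_out, save={3, 4})
print(y)  # → [None, None, None, 'P4 (stride 16)', 'P5 (stride 32)']
print(x)  # → P5 (stride 32)
```

With four returned maps, indices 1 to 4 of `y` hold strides 4 to 32, so under this layout later layers that refer back to the backbone would address the stride-8, stride-16 and stride-32 maps as layers 2, 3 and 4.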

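(3) Modify the parse_model function

You can paste the code below directly into the corresponding location. Later improvements in this series will not need to change this function again; wherever a change is required, it will be pointed out separately.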

def parse_model(d, ch, verbose=True, warehouse_manager=None):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast
 
    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        if len(scales[scale]) == 3:
            depth, width, max_channels = scales[scale]
        elif len(scales[scale]) == 4:
            depth, width, max_channels, threshold = scales[scale]
 
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print
 
    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<60}{'arguments':<50}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    is_backbone = False
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        try:
            if m == 'node_mode':
                m = d[m]
                if len(args) > 0:
                    if args[0] == 'head_channel':
                        args[0] = int(d[args[0]])
            t = m
            m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        except:
            pass
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    try:
                        args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
                    except:
                        args[j] = a
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in {
            Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, C1, C2, C2f, ELAN1, AConv, SPPELAN, C2fAttn, C3, C3TR, 
            C3Ghost, nn.Conv2d, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3, PSA, SCDown, C2fCIB, C2f_Faster, C2f_ODConv,
            C2f_Faster_EMA, C2f_DBB, GSConv, GSConvns, VoVGSCSP, VoVGSCSPns, VoVGSCSPC, C2f_CloAtt, C3_CloAtt, SCConv, C2f_SCConv, C3_SCConv, C2f_ScConv, C3_ScConv,
            C3_EMSC, C3_EMSCP, C2f_EMSC, C2f_EMSCP, RCSOSA, KWConv, C2f_KW, C3_KW, DySnakeConv, C2f_DySnakeConv, C3_DySnakeConv,
            DCNv2, C3_DCNv2, C2f_DCNv2, DCNV3_YOLO, C3_DCNv3, C2f_DCNv3, C3_Faster, C3_Faster_EMA, C3_ODConv,
            OREPA, OREPA_LargeConv, RepVGGBlock_OREPA, C3_OREPA, C2f_OREPA, C3_DBB, C3_REPVGGOREPA, C2f_REPVGGOREPA,
            C3_DCNv2_Dynamic, C2f_DCNv2_Dynamic, C3_ContextGuided, C2f_ContextGuided, C3_MSBlock, C2f_MSBlock,
            C3_DLKA, C2f_DLKA, CSPStage, SPDConv, RepBlock, C3_EMBC, C2f_EMBC, SPPF_LSKA, C3_DAttention, C2f_DAttention,
            C3_Parc, C2f_Parc, C3_DWR, C2f_DWR, RFAConv, RFCAConv, RFCBAMConv, C3_RFAConv, C2f_RFAConv,
            C3_RFCBAMConv, C2f_RFCBAMConv, C3_RFCAConv, C2f_RFCAConv, C3_FocusedLinearAttention, C2f_FocusedLinearAttention,
            C3_AKConv, C2f_AKConv, AKConv, C3_MLCA, C2f_MLCA,
            C3_UniRepLKNetBlock, C2f_UniRepLKNetBlock, C3_DRB, C2f_DRB, C3_DWR_DRB, C2f_DWR_DRB, CSP_EDLAN,
            C3_AggregatedAtt, C2f_AggregatedAtt, DCNV4_YOLO, C3_DCNv4, C2f_DCNv4, HWD, SEAM,
            C3_SWC, C2f_SWC, C3_iRMB, C2f_iRMB, C3_iRMB_Cascaded, C2f_iRMB_Cascaded, C3_iRMB_DRB, C2f_iRMB_DRB, C3_iRMB_SWC, C2f_iRMB_SWC,
            C3_VSS, C2f_VSS, C3_LVMB, C2f_LVMB, RepNCSPELAN4, DBBNCSPELAN4, OREPANCSPELAN4, DRBNCSPELAN4, ADown, V7DownSampling,
            C3_DynamicConv, C2f_DynamicConv, C3_GhostDynamicConv, C2f_GhostDynamicConv, C3_RVB, C2f_RVB, C3_RVB_SE, C2f_RVB_SE, C3_RVB_EMA, C2f_RVB_EMA, DGCST,
            C3_RetBlock, C2f_RetBlock, C3_PKIModule, C2f_PKIModule, RepNCSPELAN4_CAA, C3_FADC, C2f_FADC, C3_PPA, C2f_PPA, SRFD, DRFD, RGCSPELAN,
            C3_Faster_CGLU, C2f_Faster_CGLU, C3_Star, C2f_Star, C3_Star_CAA, C2f_Star_CAA, C3_KAN, C2f_KAN, C3_EIEM, C2f_EIEM, C3_DEConv, C2f_DEConv,
            C3_SMPCGLU, C2f_SMPCGLU, C3_Heat, C2f_Heat, CSP_PTB, SimpleStem, VisionClueMerge, VSSBlock_YOLO, XSSBlock, GLSA, C2f_WTConv, WTConv2d, FeaturePyramidSharedConv,
            C2f_FMB, LDConv, C2f_gConv, C2f_WDBB, C2f_DeepDBB, C2f_AdditiveBlock, C2f_AdditiveBlock_CGLU, CSP_MSCB, C2f_MSMHSA_CGLU, CSP_PMSFA, C2f_MogaBlock,
            C2f_SHSA, C2f_SHSA_CGLU, C2f_SMAFB, C2f_SMAFB_CGLU, C2f_IdentityFormer, C2f_RandomMixing, C2f_PoolingFormer, C2f_ConvFormer, C2f_CaFormer,
            C2f_IdentityFormerCGLU, C2f_RandomMixingCGLU, C2f_PoolingFormerCGLU, C2f_ConvFormerCGLU, C2f_CaFormerCGLU, CSP_MutilScaleEdgeInformationEnhance, C2f_FFCM,
            C2f_SFHF, CSP_FreqSpatial, C2f_MSM, C2f_RAB, C2f_HDRAB, C2f_LFE, CSP_MutilScaleEdgeInformationSelect, C2f_SFA, C2f_CTA, C2f_CAMixer, MANet,
            MANet_FasterBlock, MANet_FasterCGLU, MANet_Star, C2f_HFERB, C2f_DTAB, C2f_ETB, C2f_JDPM, C2f_AP, PSConv, C2f_Kat, C2f_Faster_KAN, C2f_Strip, C2f_StripCGLU
        }:
            if args[0] == 'head_channel':
                args[0] = d[args[0]]
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)  # embed channels
                args[2] = int(
                    max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
                )  # num heads
 
            args = [c1, c2, *args[1:]]
            if m in (KWConv, C2f_KW, C3_KW):
                args.insert(2, f'layer{i}')
                args.insert(2, warehouse_manager)
            if m in (DySnakeConv,):
                c2 = c2 * 3
            if m in (RepNCSPELAN4, DBBNCSPELAN4, OREPANCSPELAN4, DRBNCSPELAN4, RepNCSPELAN4_CAA):
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
                args[3] = make_divisible(min(args[3], max_channels) * width, 8)
            if m in {
                     BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB, C2f_Faster, C2f_ODConv, C2f_Faster_EMA, C2f_DBB,
                     VoVGSCSP, VoVGSCSPns, VoVGSCSPC, C2f_CloAtt, C3_CloAtt, C2f_SCConv, C3_SCConv, C2f_ScConv, C3_ScConv,
                     C3_EMSC, C3_EMSCP, C2f_EMSC, C2f_EMSCP, RCSOSA, C2f_KW, C3_KW, C2f_DySnakeConv, C3_DySnakeConv,
                     C3_DCNv2, C2f_DCNv2, C3_DCNv3, C2f_DCNv3, C3_Faster, C3_Faster_EMA, C3_ODConv, C3_OREPA, C2f_OREPA, C3_DBB,
                     C3_REPVGGOREPA, C2f_REPVGGOREPA, C3_DCNv2_Dynamic, C2f_DCNv2_Dynamic, C3_ContextGuided, C2f_ContextGuided, 
                     C3_MSBlock, C2f_MSBlock, C3_DLKA, C2f_DLKA, CSPStage, RepBlock, C3_EMBC, C2f_EMBC, C3_DAttention, C2f_DAttention,
                     C3_Parc, C2f_Parc, C3_DWR, C2f_DWR, C3_RFAConv, C2f_RFAConv, C3_RFCBAMConv, C2f_RFCBAMConv, C3_RFCAConv, C2f_RFCAConv,
                     C3_FocusedLinearAttention, C2f_FocusedLinearAttention, C3_AKConv, C2f_AKConv, C3_MLCA, C2f_MLCA,
                     C3_UniRepLKNetBlock, C2f_UniRepLKNetBlock, C3_DRB, C2f_DRB, C3_DWR_DRB, C2f_DWR_DRB, CSP_EDLAN,
                     C3_AggregatedAtt, C2f_AggregatedAtt, C3_DCNv4, C2f_DCNv4, C3_SWC, C2f_SWC,
                     C3_iRMB, C2f_iRMB, C3_iRMB_Cascaded, C2f_iRMB_Cascaded, C3_iRMB_DRB, C2f_iRMB_DRB, C3_iRMB_SWC, C2f_iRMB_SWC,
                     C3_VSS, C2f_VSS, C3_LVMB, C2f_LVMB, C3_DynamicConv, C2f_DynamicConv, C3_GhostDynamicConv, C2f_GhostDynamicConv,
                     C3_RVB, C2f_RVB, C3_RVB_SE, C2f_RVB_SE, C3_RVB_EMA, C2f_RVB_EMA, C3_RetBlock, C2f_RetBlock, C3_PKIModule, C2f_PKIModule,
                     C3_FADC, C2f_FADC, C3_PPA, C2f_PPA, RGCSPELAN, C3_Faster_CGLU, C2f_Faster_CGLU, C3_Star, C2f_Star, C3_Star_CAA, C2f_Star_CAA,
                     C3_KAN, C2f_KAN, C3_EIEM, C2f_EIEM, C3_DEConv, C2f_DEConv, C3_SMPCGLU, C2f_SMPCGLU, C3_Heat, C2f_Heat, CSP_PTB, XSSBlock, C2f_WTConv,
                     C2f_FMB, C2f_gConv, C2f_WDBB, C2f_DeepDBB, C2f_AdditiveBlock, C2f_AdditiveBlock_CGLU, CSP_MSCB, C2f_MSMHSA_CGLU, CSP_PMSFA, C2f_MogaBlock,
                     C2f_SHSA, C2f_SHSA_CGLU, C2f_SMAFB, C2f_SMAFB_CGLU, C2f_IdentityFormer, C2f_RandomMixing, C2f_PoolingFormer, C2f_ConvFormer, C2f_CaFormer,
                     C2f_IdentityFormerCGLU, C2f_RandomMixingCGLU, C2f_PoolingFormerCGLU, C2f_ConvFormerCGLU, C2f_CaFormerCGLU, CSP_MutilScaleEdgeInformationEnhance, C2f_FFCM,
                     C2f_SFHF, CSP_FreqSpatial, C2f_MSM, C2f_RAB, C2f_HDRAB, C2f_LFE, CSP_MutilScaleEdgeInformationSelect, C2f_SFA, C2f_CTA, C2f_CAMixer, MANet,
                     MANet_FasterBlock, MANet_FasterCGLU, MANet_Star, C2f_HFERB, C2f_DTAB, C2f_ETB, C2f_JDPM, C2f_AP, C2f_Kat, C2f_Faster_KAN, C2f_Strip, C2f_StripCGLU
                     }:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m in {AIFI, AIFI_RepBN}:
            args = [ch[f], *args]
            c2 = args[0]
        elif m in (HGStem, HGBlock, Ghost_HGBlock, Rep_HGBlock, Dynamic_HGBlock, EIEStem):
            c1, cm, c2 = ch[f], args[0], args[1]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
                cm = make_divisible(min(cm, max_channels) * width, 8)
            args = [c1, cm, c2, *args[2:]]
            if m in (HGBlock, Ghost_HGBlock, Rep_HGBlock, Dynamic_HGBlock):
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in ((WorldDetect, ImagePoolingAttn) + DETECT_CLASS + V10_DETECT_CLASS + SEGMENT_CLASS + POSE_CLASS + OBB_CLASS):
            args.append([ch[x] for x in f])
            if m in SEGMENT_CLASS:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
                if m in (Segment_LSCD, Segment_TADDH, Segment_LSCSBD, Segment_LSDECD, Segment_RSCD):
                    args[3] = make_divisible(min(args[3], max_channels) * width, 8)
            if m in (Detect_LSCD, Detect_TADDH, Detect_LSCSBD, Detect_LSDECD, Detect_RSCD, v10Detect_LSCD, v10Detect_TADDH, v10Detect_RSCD, v10Detect_LSDECD):
                args[1] = make_divisible(min(args[1], max_channels) * width, 8)
            if m in (Pose_LSCD, Pose_TADDH, Pose_LSCSBD, Pose_LSDECD, Pose_RSCD, OBB_LSCD, OBB_TADDH, OBB_LSCSBD, OBB_LSDECD, OBB_RSCD):
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is Fusion:
            args[0] = d[args[0]]
            c1, c2 = [ch[x] for x in f], (sum([ch[x] for x in f]) if args[0] == 'concat' else ch[f[0]])
            args = [c1, args[0]]
        elif m is CBLinear:
            c2 = make_divisible(min(args[0][-1], max_channels) * width, 8)
            c1 = ch[f]
            args = [c1, [make_divisible(min(c2_, max_channels) * width, 8) for c2_ in args[0]], *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif isinstance(m, str):
            t = m
            if len(args) == 2:        
                m = timm.create_model(m, pretrained=args[0], pretrained_cfg_overlay={'file':args[1]}, features_only=True)
            elif len(args) == 1:
                m = timm.create_model(m, pretrained=args[0], features_only=True)
            c2 = m.feature_info.channels()
        elif m in {convnextv2_atto, convnextv2_femto, convnextv2_pico, convnextv2_nano, convnextv2_tiny, convnextv2_base, convnextv2_large, convnextv2_huge,
                   fasternet_t0, fasternet_t1, fasternet_t2, fasternet_s, fasternet_m, fasternet_l,
                   EfficientViT_M0, EfficientViT_M1, EfficientViT_M2, EfficientViT_M3, EfficientViT_M4, EfficientViT_M5,
                   efficientformerv2_s0, efficientformerv2_s1, efficientformerv2_s2, efficientformerv2_l,
                   vanillanet_5, vanillanet_6, vanillanet_7, vanillanet_8, vanillanet_9, vanillanet_10, vanillanet_11, vanillanet_12, vanillanet_13, vanillanet_13_x1_5, vanillanet_13_x1_5_ada_pool,
                   RevCol,
                   lsknet_t, lsknet_s,
                   SwinTransformer_Tiny,
                   repvit_m0_9, repvit_m1_0, repvit_m1_1, repvit_m1_5, repvit_m2_3,
                   CSWin_tiny, CSWin_small, CSWin_base, CSWin_large,
                   unireplknet_a, unireplknet_f, unireplknet_p, unireplknet_n, unireplknet_t, unireplknet_s, unireplknet_b, unireplknet_l, unireplknet_xl,
                   transnext_micro, transnext_tiny, transnext_small, transnext_base,
                   RMT_T, RMT_S, RMT_B, RMT_L,
                   PKINET_T, PKINET_S, PKINET_B,
                   MobileNetV4ConvSmall, MobileNetV4ConvMedium, MobileNetV4ConvLarge, MobileNetV4HybridMedium, MobileNetV4HybridLarge,
                   starnet_s050, starnet_s100, starnet_s150, starnet_s1, starnet_s2, starnet_s3, starnet_s4
                   }:
            if m is RevCol:
                args[1] = [make_divisible(min(k, max_channels) * width, 8) for k in args[1]]
                args[2] = [max(round(k * depth), 1) for k in args[2]]
            m = m(*args)
            c2 = m.channel
        elif m in {EMA, SpatialAttention, BiLevelRoutingAttention, BiLevelRoutingAttention_nchw,
                   TripletAttention, CoordAtt, CBAM, BAMBlock, LSKBlock, ScConv, LAWDS, EMSConv, EMSConvP,
                   SEAttention, CPCA, Partial_conv3, FocalModulation, EfficientAttention, MPCA, deformable_LKA,
                   EffectiveSEModule, LSKA, SegNext_Attention, DAttention, MLCA, TransNeXt_AggregatedAttention,
                   FocusedLinearAttention, LocalWindowAttention, ChannelAttention_HSFPN, ELA_HSFPN, CA_HSFPN, CAA_HSFPN, 
                   DySample, CARAFE, CAA, ELA, CAFM, AFGCAttention, EUCB, ContrastDrivenFeatureAggregation, FSA}:
            c2 = ch[f]
            args = [c2, *args]
            # print(args)
        elif m in {SimAM, SpatialGroupEnhance}:
            c2 = ch[f]
        elif m is ContextGuidedBlock_Down:
            c2 = ch[f] * 2
            args = [ch[f], c2, *args]
        elif m is BiFusion:
            c1 = [ch[x] for x in f]
            c2 = make_divisible(min(args[0], max_channels) * width, 8)
            args = [c1, c2]
        # --------------GOLD-YOLO--------------
        elif m in {SimFusion_4in, AdvPoolFusion}:
            c2 = sum(ch[x] for x in f)
        elif m is SimFusion_3in:
            c2 = args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [[ch[f_] for f_ in f], c2]
        elif m is IFM:
            c1 = ch[f]
            c2 = sum(args[0])
            args = [c1, *args]
        elif m is InjectionMultiSum_Auto_pool:
            c1 = ch[f[0]]
            c2 = args[0]
            args = [c1, *args]
        elif m is PyramidPoolAgg:
            c2 = args[0]
            args = [sum([ch[f_] for f_ in f]), *args]
        elif m is TopBasicLayer:
            c2 = sum(args[1])
        # --------------GOLD-YOLO--------------
        # --------------ASF--------------
        elif m is Zoom_cat:
            c2 = sum(ch[x] for x in f)
        elif m is Add:
            c2 = ch[f[-1]]
        elif m in {ScalSeq, DynamicScalSeq}:
            c1 = [ch[x] for x in f]
            c2 = make_divisible(args[0] * width, 8)
            args = [c1, c2]
        elif m is asf_attention_model:
            args = [ch[f[-1]]]
        # --------------ASF--------------
        elif m is SDI:
            args = [[ch[x] for x in f]]
        elif m is Multiply:
            c2 = ch[f[0]]
        elif m is FocusFeature:
            c1 = [ch[x] for x in f]
            c2 = int(c1[1] * 0.5 * 3)
            args = [c1, *args]
        elif m is DASI:
            c1 = [ch[x] for x in f]
            args = [c1, c2]
        elif m is CSMHSA:
            c1 = [ch[x] for x in f]
            c2 = ch[f[-1]]
            args = [c1, c2]
        elif m is CFC_CRB:
            c1 = ch[f]
            c2 = c1 // 2
            args = [c1, *args]
        elif m is SFC_G2:
            c1 = [ch[x] for x in f]
            c2 = c1[0]
            args = [c1]
        elif m in {CGAFusion, CAFMFusion, SDFM, PSFM}:
            c2 = ch[f[1]]
            args = [c2, *args]
        elif m in {ContextGuideFusionModule}:
            c1 = [ch[x] for x in f]
            c2 = 2 * c1[1]
            args = [c1]
        # elif m in {PSA}:
        #     c2 = ch[f]
        #     args = [c2, *args]
        elif m in {SBA}:
            c1 = [ch[x] for x in f]
            c2 = c1[-1]
            args = [c1, c2]
        elif m in {WaveletPool}:
            c2 = ch[f] * 4
        elif m in {WaveletUnPool}:
            c2 = ch[f] // 4
        elif m in {CSPOmniKernel}:
            c2 = ch[f]
            args = [c2]
        elif m in {ChannelTransformer, PyramidContextExtraction}:
            c1 = [ch[x] for x in f]
            c2 = c1
            args = [c1]
        elif m in {RCM}:
            c2 = ch[f]
            args = [c2, *args]
        elif m in {DynamicInterpolationFusion}:
            c2 = ch[f[0]]
            args = [[ch[x] for x in f]]
        elif m in {FuseBlockMulti}:
            c2 = ch[f[0]]
            args = [c2]
        elif m in {CrossLayerChannelAttention, CrossLayerSpatialAttention}:
            c2 = [ch[x] for x in f]
            args = [c2[0], *args]
        elif m in {FreqFusion}:
            c2 = ch[f[0]]
            args = [[ch[x] for x in f], *args]
        elif m in {DynamicAlignFusion}:
            c2 = args[0]
            args = [[ch[x] for x in f], c2]
        elif m in {ConvEdgeFusion}:
            c2 = make_divisible(min(args[0], max_channels) * width, 8)
            args = [[ch[x] for x in f], c2]
        elif m in {MutilScaleEdgeInfoGenetator}:
            c1 = ch[f]
            c2 = [make_divisible(min(i, max_channels) * width, 8) for i in args[0]]
            args = [c1, c2]
        elif m in {MultiScaleGatedAttn}:
            c1 = [ch[x] for x in f]
            c2 = min(c1)
            args = [c1]
        elif m in {WFU, MultiScalePCA, MultiScalePCA_Down}:
            c1 = [ch[x] for x in f]
            c2 = c1[0]
            args = [c1]
        elif m in {GetIndexOutput}:
            c2 = ch[f][args[0]]
        elif m is HyperComputeModule:
            c1, c2 = ch[f], args[0]
            c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, threshold]
        else:
            c2 = ch[f]
 
        if isinstance(c2, list) and m not in {ChannelTransformer, PyramidContextExtraction, CrossLayerChannelAttention, CrossLayerSpatialAttention, MutilScaleEdgeInfoGenetator}:
            is_backbone = True
            m_ = m
            m_.backbone = True
        else:
            m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
            t = str(m)[8:-2].replace('__main__.', '')  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i + 4 if is_backbone else i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<60}{str(args):<50}")  # print
        save.extend(x % (i + 4 if is_backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        if isinstance(c2, list) and m not in {ChannelTransformer, PyramidContextExtraction, CrossLayerChannelAttention, CrossLayerSpatialAttention, MutilScaleEdgeInfoGenetator}:
            ch.extend(c2)
            for _ in range(5 - len(ch)):
                ch.insert(0, 0)
        else:
            ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
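
The tail of parse_model above treats a backbone specially: when c2 is a list, the channel list is left-padded to five entries and every later layer index is shifted by 4. The following is a minimal sketch of that bookkeeping only, not the real function. The name simulate_bookkeeping is hypothetical, the RepViT channel widths [48, 96, 192, 384] are an assumption about repvit_m0_9, and the exclusion set for modules such as ChannelTransformer is ignored.

```python
def simulate_bookkeeping(layers, in_ch=3):
    """Replay only the index/channel bookkeeping at the end of parse_model.

    `layers` is a list of (name, c2) pairs, where c2 is an int for ordinary
    modules and a list of ints for a multi-output backbone.
    """
    ch = [in_ch]          # channel list, starts with the input image channels
    is_backbone = False   # becomes True once a backbone has been seen
    layout = []           # (layer index used by the YAML 'from' field, name)
    for i, (name, c2) in enumerate(layers):
        if isinstance(c2, list):
            is_backbone = True
        # Layers at or after the backbone are shifted by 4 slots.
        layout.append((i + 4 if is_backbone else i, name))
        if i == 0:
            ch = []       # drop the input-channel placeholder
        if isinstance(c2, list):
            ch.extend(c2)
            # Left-pad with zeros so the feature maps end at index 4 (P5).
            for _ in range(5 - len(ch)):
                ch.insert(0, 0)
        else:
            ch.append(c2)
    return layout, ch


# Assumed widths for repvit_m0_9; SPPF gives 256 channels at scale 'n'.
layout, channels = simulate_bookkeeping(
    [("repvit_m0_9", [48, 96, 192, 384]), ("SPPF", 256)]
)
print(layout)    # [(4, 'repvit_m0_9'), (5, 'SPPF')]
print(channels)  # [0, 48, 96, 192, 384, 256]
```

Under these assumptions, index 0 is only a zero placeholder, and indices 1 to 4 hold the backbone feature maps that the head refers to.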

Step ③: Modify the yolov8.yaml file

Create yolov8-repvit.yaml in the folder shown below.

yaml
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# 0-P1/2
# 1-P2/4
# 2-P3/8
# 3-P4/16
# 4-P5/32

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, repvit_m0_9, []]  # 4
  - [-1, 1, SPPF, [1024, 5]]  # 5

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 6
  - [[-1, 3], 1, Concat, [1]]  # 7 cat backbone P4
  - [-1, 3, C2f, [512]]  # 8

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 9
  - [[-1, 2], 1, Concat, [1]]  # 10 cat backbone P3
  - [-1, 3, C2f, [256]]  # 11 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]] # 12
  - [[-1, 8], 1, Concat, [1]]  # 13 cat head P4
  - [-1, 3, C2f, [512]]  # 14 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]] # 15
  - [[-1, 5], 1, Concat, [1]]  # 16 cat head P5
  - [-1, 3, C2f, [1024]]  # 17 (P5/32-large)

  - [[11, 14, 17], 1, Detect, [nc]]  # Detect(P3, P4, P5)
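
To check that the from indices in the head point at the intended feature maps, the channel counts can be traced by hand. The sketch below does this for scale n (width 0.25, max_channels 1024). The function trace_channels is a hypothetical helper that mirrors only the Conv, C2f, SPPF, Concat and Upsample rules, and the backbone widths [48, 96, 192, 384] are an assumption about repvit_m0_9.

```python
import math


def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def trace_channels(backbone_ch, layers, width=0.25, max_channels=1024):
    """Return the per-index output channels for the YAML layers after the backbone."""
    # Indices 0..4 belong to the backbone; missing levels are padded with 0.
    ch = [0] * (5 - len(backbone_ch)) + list(backbone_ch)
    for f, module, args in layers:
        if module == "Concat":
            c2 = sum(ch[x] for x in f)          # channels add up
        elif module in {"Conv", "C2f", "SPPF"}:
            c2 = make_divisible(min(args[0], max_channels) * width)
        elif module == "Detect":
            c2 = [ch[x] for x in f]             # channels fed to each detect branch
        else:                                   # nn.Upsample keeps the channel count
            c2 = ch[f]
        ch.append(c2)
    return ch


LAYERS = [
    (-1, "SPPF", [1024, 5]),                      # 5
    (-1, "nn.Upsample", [None, 2, "nearest"]),    # 6
    ([-1, 3], "Concat", [1]),                     # 7
    (-1, "C2f", [512]),                           # 8
    (-1, "nn.Upsample", [None, 2, "nearest"]),    # 9
    ([-1, 2], "Concat", [1]),                     # 10
    (-1, "C2f", [256]),                           # 11
    (-1, "Conv", [256, 3, 2]),                    # 12
    ([-1, 8], "Concat", [1]),                     # 13
    (-1, "C2f", [512]),                           # 14
    (-1, "Conv", [512, 3, 2]),                    # 15
    ([-1, 5], "Concat", [1]),                     # 16
    (-1, "C2f", [1024]),                          # 17
    ([11, 14, 17], "Detect", ["nc"]),             # 18
]

ch = trace_channels([48, 96, 192, 384], LAYERS)
print(ch[7], ch[10], ch[13], ch[16])  # 448 224 192 384
print(ch[18])                         # [64, 128, 256]
```

With these assumed widths, layer 7 concatenates the upsampled SPPF output (256) with backbone index 3 (192), and layer 10 concatenates layer 9 (128) with backbone index 2 (96). A different RepViT variant changes the backbone widths but not the index layout.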

Step ④: Verify that the integration succeeded

Point the model configuration in train.py at yolov8-repvit.yaml, then run the script.
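
A train.py for this check could look roughly like the sketch below. It relies on the standard Ultralytics calls YOLO(...), model.info() and model.train(...). The config path, the coco128.yaml dataset and the one-epoch settings are assumptions for a quick smoke test, and train_settings is a hypothetical helper; adjust all of them to your project.

```python
from pathlib import Path

# Assumed location of the YAML created in step ③; adjust to your checkout.
MODEL_CFG = Path("ultralytics/cfg/models/v8/yolov8-repvit.yaml")


def train_settings(data="coco128.yaml", epochs=1, imgsz=640, batch=8):
    """Collect keyword arguments for a short smoke-test run (assumed values)."""
    return {"data": data, "epochs": epochs, "imgsz": imgsz, "batch": batch}


def main():
    if not MODEL_CFG.exists():
        print(f"Config not found: {MODEL_CFG} (edit MODEL_CFG to match your project)")
        return False
    # Imported lazily so the path check above runs even without the package.
    from ultralytics import YOLO

    model = YOLO(str(MODEL_CFG))   # building the model runs parse_model on the YAML
    model.info()                   # prints a layer and parameter summary
    model.train(**train_settings())
    return True


if __name__ == "__main__":
    main()
```

If the model builds and the printed layer table lists repvit_m0_9 followed by SPPF, the backbone has been wired in; a failure at the YOLO(...) call usually points back to the task.py edits or the YAML.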


🏋Not every seed will bloom, but sowing seeds is a hundred times better than leaving the field barren🏋