RTDETR融合[WACV 2025]SEM-Net中的模块


RT-DETR使用教程: RT-DETR使用教程

RT-DETR改进汇总贴:RT-DETR更新汇总贴


《SEM-Net: Efficient Pixel Modelling for image inpainting with Spatially Enhanced SSM》

一、 模块介绍

论文链接: https://arxiv.org/pdf/2411.06318

代码链接: https://github.com/ChrisChen1023/SEM-Net

论文速览:

图像修复旨在根据图像中已知区域的信息修复部分受损的图像。实现语义上合理的修复结果尤其具有挑战性,因为这要求修复区域展现出与语义一致区域相似的模式。这需要一个具有强大能力来捕捉长程依赖关系的模型。现有的模型在这方面存在困难,因为基于卷积神经网络(CNN)的方法感受野增长缓慢,而基于Transformer的方法仅在补丁级别进行交互,这些都不利于捕捉长程依赖关系。受此启发,我们提出了SEM-Net,这是一种新颖的视觉状态空间模型(SSM)视觉网络,在像素级别对受损图像进行建模,同时在状态空间中捕捉长程依赖关系(LRDs),实现了线性计算复杂度。为了解决SSM固有的空间感知不足问题,我们引入了蛇曼巴块(SMB)和空间增强前馈网络。这些创新使SEM-Net在两个不同的数据集上超越了最先进的修复方法,显示出在捕捉长程依赖关系方面的显著改进。

总结:本文更新其中的FeedForward结构。


⭐⭐本文二创模块仅更新于付费群中,往期免费教程可看下方链接⭐⭐

RT-DETR更新汇总贴(含免费教程)文章浏览阅读264次。RT-DETR使用教程:缝合教程: RT-DETR中的yaml文件详解:labelimg使用教程:_rt-deterhttps://xy2668825911.blog.csdn.net/article/details/143696113

二、二创融合模块

2.1 相关代码

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from mamba_ssm import Mamba

import math



from pdb import set_trace as stx
import numbers
from einops import rearrange

class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)

                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

            elif classname.find('BatchNorm2d') != -1:
                nn.init.normal_(m.weight.data, 1.0, gain)
                nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

def spectral_norm(module, mode=True):
    if mode:
        return nn.utils.spectral_norm(module)

    return module

class Discriminator(BaseNetwork):
    def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid

        self.conv1 = self.features = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv3 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv4 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv5 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
        )

        if init_weights:
            self.init_weights()

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        outputs = conv5
        if self.use_sigmoid:
            outputs = torch.sigmoid(conv5)

        return outputs, [conv1, conv2, conv3, conv4, conv5]

############# Restormer-inpainting
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


##########################################################################
## Spatially Enhanced Feed-Forward Network (SEFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        
        self.fusion = nn.Conv2d(hidden_features + dim, hidden_features, kernel_size=1, bias=bias)
        self.dwconv_afterfusion = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
                                groups=hidden_features, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
                                groups=hidden_features * 2, bias=bias)    


        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3,stride=1,padding=1,bias=True),
            LayerNorm(dim, 'WithBias'),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3,stride=1,padding=1,bias=True),
            LayerNorm(dim, 'WithBias'),
            nn.ReLU(inplace=True)
        )
        self.upsample = nn.Upsample(scale_factor=2)
        

    def forward(self, x, spatial):
        
        
        x = self.project_in(x)
        
        
        #### Spatial branch
        y = self.avg_pool(spatial)
        y = self.conv(y)
        y = self.upsample(y)  
        ####
        

        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x1 = self.fusion(torch.cat((x1, y),dim=1))
        x1 = self.dwconv_afterfusion(x1)
        
        
        
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

2.2 更改yaml文件 (以自研模型加入为例)

yam文件解读:YOLO系列 ".yaml"文件解读_yolo yaml文件-CSDN博客

打开更改ultralytics/cfg/models/rt-detr路径下的rtdetr-l.yaml文件,替换原有模块。

​​

复制代码
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# ⭐⭐Powered by https://blog.csdn.net/StopAndGoyyy,  技术指导QQ:2668825911⭐⭐

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 512]
#  n: [ 0.33, 0.25, 1024 ]
#  s: [ 0.33, 0.50, 1024 ]
#  m: [ 0.67, 0.75, 768 ]
#  l: [ 1.00, 1.00, 512 ]
#  x: [ 1.00, 1.25, 512 ]
# ⭐⭐Powered by https://blog.csdn.net/StopAndGoyyy,  技术指导QQ:2668825911⭐⭐

backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, CCRI, [128, 5, True, False]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 4, CCRI, [256, 3, True, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 4, CCRI, [512, 3, True, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, CCRI, [1024, 3, True, False]]

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [32, 1, 1]] # 11, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [6, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 13 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 2, RepC4, [256]] # 15, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 16, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [4, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 18 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, RepC4, [128]] # X3 (20), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
  - [[-1, 16], 1, Concat, [1]] # cat Y4
  - [-1, 2, RepC4, [128]] # F4 (23), pan_blocks.0

  - [-1, 1, Conv, [32, 3, 2]] # 24, downsample_convs.1
  - [[-1, 11], 1, FeedForward_SEF, [1]] # cat Y5
  - [-1, 2, RepC4, [128]] # F5 (26), pan_blocks.1

  - [[20, 23, 26], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
# ⭐⭐Powered by https://blog.csdn.net/StopAndGoyyy,  技术指导QQ:2668825911⭐⭐

2.2 修改train.py文件

创建Train_RT脚本用于训练。

复制代码
from ultralytics.models import RTDETR
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

if __name__ == '__main__':
    model = RTDETR(model='ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    # model.load('yolov8n.pt')
    model.train(data='./data.yaml', epochs=2, batch=1, device='0', imgsz=640, workers=2, cache=False,
                amp=True, mosaic=False, project='runs/train', name='exp')

​​

在train.py脚本中填入修改好的yaml路径,运行即可训。

​​


相关推荐
点云SLAM2 分钟前
PyTorch张量(Tensor)创建的方式汇总详解和代码示例
人工智能·pytorch·python·深度学习·机器学习·张量创建方式
九.九21 分钟前
【Python】基础语法
python
AndrewHZ28 分钟前
【图像处理基石】什么是色盲仿真技术?
图像处理·人工智能·pytorch·深度学习·计算机视觉·颜色科学·hvs
Eiceblue1 小时前
用Python向PDF添加文本:精确插入文本到PDF文档
开发语言·python·pdf
Python×CATIA工业智造1 小时前
Pycaita二次开发基础代码解析:特征识别、参数化建模与可视化控制
python·pycharm·pycatia
TY-20252 小时前
七、深度学习——RNN
人工智能·rnn·深度学习
nightunderblackcat2 小时前
进阶向:Python图像处理,使用PIL库实现圆形裁剪
开发语言·图像处理·python
站大爷IP2 小时前
动态HTTP隧道代理IP:从配置到实战的完整指南
python
补三补四2 小时前
RNN(循环神经网络)
人工智能·rnn·深度学习·神经网络·算法
婪苏2 小时前
Python 面向对象(二):继承与封装的深度探索
后端·python