改进的mamba核心块—Hybrid SS2D Block(适用于视觉)

环境配置参考:

https://blog.csdn.net/llz19670/article/details/140359923?utm_source%20=%20uc_fansmsghttps://blog.csdn.net/llz19670/article/details/140359923?utm_source%20=%20uc_fansmsg主要的结构性创新是将 SS2D 的长程依赖捕捉能力与显式的通道混合局部卷积 结合,并加入了残差连接 ,形成一个更鲁棒的 CNN-SSM 混合块

创新描述:

方面 原始 SS2D 模块 改进后的 H-SS2D 模块 提升点
局部特征 仅使用深度卷积 (groups=d_inner),不混合通道。 引入额外的 1*1 标准卷积 (self.channel_mix) 在 SSM 运算之后。 增强了模块在捕捉长程依赖后,跨通道混合信息的能力,提供更丰富的局部特征表示。
模块连接 无残差连接(或未显式展示)。 在模块输入和最终输出之间添加残差连接 提高训练稳定性,缓解深层网络中的梯度消失问题,加速收敛。
信息流 SSM 扫描结果 y 经过 Norm,Gate,Proj。 SSM 扫描结果y经过Norm,Conv,Gate,Proj。 1*1卷积为 SSM 结果添加了一个额外的非线性通道交互层,提高了特征的表示力。

完整的改进代码 (Hybrid SS2D)

复制代码
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

# 尝试导入 Mamba SSM 核心操作
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
except ImportError:
    print("Warning: 'mamba_ssm' not found. Using alternative/placeholder if available.")
    selective_scan_fn = None
    pass

try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
except ImportError:
    selective_scan_fn_v1 = None
    pass


class HybridSS2D(nn.Module):
    """
    Hybrid SS2D Block (H-SS2D):
    An improved Mamba-based 2D block incorporating a Channel Mixing 1x1 Conv 
    and a Residual Connection for enhanced performance in vision tasks.
    
    NOTE: For residual connection to work simply, ensure d_model == d_inner (i.e., expand=1).
    """
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=3,
        expand=1, # Setting expand=1 for simple residual connection
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        # Ensure d_inner matches d_model for the residual connection
        if expand != 1:
             raise ValueError("HybridSS2D requires expand=1 for residual connection matching.")
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model) # d_inner = d_model
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        # Input Projection: Maps d_model -> 2*d_inner (x and z)
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        
        # Depthwise Conv: Local feature extraction
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        # x_proj and dt_projs for SSM parameters (A, B, C, dt)
        proj_dim = (self.dt_rank + self.d_state * 2)
        self.x_proj_weight = nn.Parameter(
            torch.stack(
                [nn.Linear(self.d_inner, proj_dim, bias=False, **factory_kwargs).weight for _ in range(4)], 
                dim=0
            )
        ) # (K=4, proj_dim, d_inner)
        
        dt_projs = [
            self._dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) 
            for _ in range(4)
        ]
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) # (K=4, d_inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) # (K=4, d_inner)
        
        # A, D parameters
        self.A_logs = self._A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K*D, N)
        self.Ds = self._D_init(self.d_inner, copies=4, merge=True) # (K*D)

        # Selective Scan core function setting
        self.forward_core = self._forward_corev0

        # Output/Gating Structure
        self.out_norm = nn.LayerNorm(self.d_inner)
        
        # === NEW FEATURE: 1x1 Conv for Channel Mixing after SSM Core ===
        self.channel_mix = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=1,
            bias=conv_bias,
            **factory_kwargs,
        )
        # =============================================================
        
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None
        
        
    # --- Helper Methods from Original SS2D (Renamed with underscore) ---
    @staticmethod
    def _dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        # Initialization logic (unchanged)
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        dt_proj.bias._no_reinit = True
        return dt_proj

    @staticmethod
    def _A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def _D_init(d_inner, copies=1, device=None, merge=True):
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)
        D._no_weight_decay = True
        return D
    
    # --- Core SSM Forward (Unchanged, using _forward_corev0) ---
    def _forward_corev0(self, x: torch.Tensor):
        # Dynamically select the best available selective scan implementation
        if selective_scan_fn is not None:
            selective_scan = selective_scan_fn
        elif selective_scan_fn_v1 is not None:
            selective_scan = selective_scan_fn_v1
        else:
            raise NotImplementedError("No selective_scan_fn implementation available.")
            
        B, C, H, W = x.shape
        L = H * W
        K = 4

        # 1. Prepare input sequences (4 directions + flip)
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        # 2. Project x to get delta, B, C
        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        
        # Reshape for selective_scan
        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1) # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        # 3. Selective Scan
        out_y = selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        
        # 4. Invert/Transpose the results back to spatial domain
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        # Summation: out_y[:, 0] (HW) + inv_y[:, 0] (InvHW) + wh_y (WH) + invwh_y (InvWH)
        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y


    # --- Improved Forward Method with 1x1 Conv and Residual ---
    def forward(self, x: torch.Tensor, **kwargs):
        # x: (B, C, H, W)
        
        # 1. Residual Connection: Store input for final addition
        x_res = x.to(torch.float32) # Ensure residual is in high precision if needed
        
        # Convert to (B, H, W, C) for Linear/LayerNorm operations
        x = x.permute(0, 2, 3, 1).contiguous()
        B, H, W, C = x.shape
        
        # Input Projection and Split (x and z for gating)
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1) # (B, H, W, d_inner)
        
        # Convert to (B, D, H, W) for Conv2d
        x = x.permute(0, 3, 1, 2).contiguous() 
        x = self.act(self.conv2d(x)) # (B, d_inner, H, W) - Depthwise Conv

        # 2. Core SSM Operation
        y1, y2, y3, y4 = self._forward_corev0(x) 
        y = y1 + y2 + y3 + y4 # (B, d_inner, L=H*W)
        
        # Convert y back to (B, H, W, d_inner)
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 

        # 3. Normalization
        y = self.out_norm(y)
        
        # --- NEW STEP: Channel Mixing 1x1 Conv ---
        # Convert to (B, D, H, W) for 1x1 Conv
        y = y.permute(0, 3, 1, 2).contiguous() 
        y = self.channel_mix(y)
        # Convert back to (B, H, W, D) for Gating
        y = y.permute(0, 2, 3, 1).contiguous() 
        # ------------------------------------------
        
        # 4. Gating and Final Projection
        y = y * F.silu(z)
        out = self.out_proj(y) # (B, H, W, d_model)
        
        if self.dropout is not None:
            out = self.dropout(out)
            
        # Convert output to (B, C, H, W)
        out = out.permute(0, 3, 1, 2).contiguous()
        
        # 5. Residual Connection (Input + Output)
        # Assuming input x_res.dtype and out.dtype can be added (e.g., both float32 or compatible)
        out = out + x_res.to(out.dtype)
        
        return out

H-SS2D 模块通过将 SSM长程依赖性 (全局上下文)与 1*1 标准卷积通道交互能力(局部特征混合)进行结构上的深度融合,并辅以残差连接来保障性能,使其成为一个更强大、更鲁棒的 2D 视觉特征提取单元,尤其适用于需要精确边界框定位和特征识别的目标检测任务。

相关推荐
serve the people2 小时前
如何区分什么场景下用机器学习,什么场景下用深度学习
人工智能·深度学习·机器学习
xjxijd2 小时前
Serverless 3.0 混合架构:容器 + 事件驱动,AI 服务弹性伸缩响应快 3 倍
人工智能·架构·serverless
csdn_aspnet2 小时前
如何用爬虫、机器学习识别方式屏蔽恶意广告
人工智能·爬虫·机器学习
weixin_457760002 小时前
RNN(循环神经网络)原理
人工智能·rnn·深度学习
代码AI弗森2 小时前
意图识别深度原理解析:从向量空间到语义流形
人工智能
姚华军2 小时前
RagFlow、Dify部署时,端口如何调整成指定端口
人工智能·dify·ragflow
老蒋新思维2 小时前
创客匠人峰会新视角:AI 时代知识变现的 “组织化转型”—— 从个人 IP 到 “AI+IP” 组织的增长革命
大数据·人工智能·网络协议·tcp/ip·创始人ip·创客匠人·知识变现
JoannaJuanCV2 小时前
自动驾驶—CARLA仿真(0)报错记录
人工智能·机器学习·自动驾驶
小白狮ww3 小时前
Matlab 教程:基于 RFUAV 系统使用 Matlab 处理无人机信号
开发语言·人工智能·深度学习·机器学习·matlab·无人机·rfuav