环境配置参考:
https://blog.csdn.net/llz19670/article/details/140359923

主要的结构性创新是：将 SS2D 的长程依赖捕捉能力与显式的逐点（1×1）通道混合卷积相结合，并加入残差连接，形成一个更鲁棒的 CNN-SSM 混合块。
创新描述:
| 方面 | 原始 SS2D 模块 | 改进后的 H-SS2D 模块 | 提升点 |
|---|---|---|---|
| 局部特征 | 卷积部分仅使用深度卷积（groups=d_inner），不在通道间混合信息。 | 在 SSM 运算之后引入额外的 1×1 标准卷积（self.channel_mix）。 | 增强模块在捕捉长程依赖后跨通道混合信息的能力，提供更丰富的特征表示。 |
| 模块连接 | 无残差连接(或未显式展示)。 | 在模块输入和最终输出之间添加残差连接。 | 提高训练稳定性,缓解深层网络中的梯度消失问题,加速收敛。 |
| 信息流 | SSM 扫描结果 y 依次经过 Norm → Gate → Proj。 | SSM 扫描结果 y 依次经过 Norm → 1×1 Conv → Gate → Proj。 | 1×1 卷积在门控之前为 SSM 输出增加了一层可学习的通道交互；该卷积本身是线性的，非线性来自随后的 SiLU 门控，二者结合提高了特征的表示能力。 |
完整的改进代码 (Hybrid SS2D)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
# Optional selective-scan kernels. Both names are always defined (None when the
# package is missing) so HybridSS2D can pick whichever one is available at call
# time. A single warning is emitted only when *neither* kernel can be imported.
import warnings

try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
except ImportError:
    selective_scan_fn = None

try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
except ImportError:
    selective_scan_fn_v1 = None

if selective_scan_fn is None and selective_scan_fn_v1 is None:
    # Modules can still be constructed, but the forward pass raises
    # NotImplementedError without a kernel.
    warnings.warn(
        "Neither 'mamba_ssm' nor 'selective_scan' could be imported; "
        "HybridSS2D.forward will raise NotImplementedError.",
        stacklevel=2,
    )
class HybridSS2D(nn.Module):
    """Hybrid SS2D block (H-SS2D).

    A four-direction 2D selective-scan (SS2D) block extended with

    * a 1x1 convolution (``channel_mix``) applied to the normalized SSM output
      before the SiLU gate, and
    * a residual connection from the block input to the block output.

    Input and output are channel-first tensors of shape ``(B, d_model, H, W)``.
    ``out_proj`` always maps ``d_inner`` back to ``d_model``, so the residual
    connection is valid for any ``expand`` factor (``expand=1`` stays the
    default).

    Args:
        d_model: Number of input/output channels.
        d_state: SSM state size.
        d_conv: Kernel size of the depthwise convolution. Padding is
            ``(d_conv - 1) // 2``, so it should be odd for H and W to be
            preserved (required by the residual connection).
        expand: Inner width multiplier, ``d_inner = int(expand * d_model)``.
        dt_rank: Rank of the low-rank dt projection; ``"auto"`` uses
            ``ceil(d_model / 16)``.
        dt_min, dt_max, dt_init, dt_scale, dt_init_floor: Initialization of the
            dt projections (see ``_dt_init``).
        dropout: Dropout probability applied after ``out_proj``; 0 disables it.
        conv_bias: Whether the depthwise and 1x1 convolutions have a bias.
        bias: Whether ``in_proj`` and ``out_proj`` have a bias.
        device, dtype: Factory arguments forwarded to submodules/parameters.
            ``A_logs`` and ``Ds`` follow ``device`` but stay float32.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=3,
        expand=1,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        # Produces the SSM branch (x) and the gate branch (z) in one matmul.
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        # Depthwise conv: per-channel spatial mixing before the scan.
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        # One x_proj / dt_proj per scan direction (K = 4), stored stacked so all
        # four directions are handled by a single einsum in _forward_corev0.
        proj_dim = self.dt_rank + self.d_state * 2
        self.x_proj_weight = nn.Parameter(
            torch.stack(
                [nn.Linear(self.d_inner, proj_dim, bias=False, **factory_kwargs).weight for _ in range(4)],
                dim=0,
            )
        )  # (K=4, proj_dim, d_inner)
        dt_projs = [
            self._dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
            for _ in range(4)
        ]
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))  # (K=4, d_inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))  # (K=4, d_inner)
        # _dt_init flags each temporary Linear's bias with _no_reinit; those
        # tensors are discarded after stacking, so re-attach the flag here.
        self.dt_projs_bias._no_reinit = True

        # A and D for all four directions, merged along the channel axis.
        self.A_logs = self._A_log_init(self.d_state, self.d_inner, copies=4, device=device, merge=True)  # (K*D, N)
        self.Ds = self._D_init(self.d_inner, copies=4, device=device, merge=True)  # (K*D,)

        # Hook point for the scan implementation; forward() calls through it.
        self.forward_core = self._forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner, **factory_kwargs)
        # 1x1 conv mixing channels of the normalized SSM output before gating.
        self.channel_mix = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=1,
            bias=conv_bias,
            **factory_kwargs,
        )
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def _dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        """Build a ``Linear(dt_rank, d_inner)`` dt projection.

        The weight is constant or uniform in ``±dt_rank ** -0.5 * dt_scale``.
        The bias is the inverse softplus of step sizes drawn log-uniformly from
        ``[dt_min, dt_max)`` and floored at ``dt_init_floor``, so that
        ``softplus(bias)`` reproduces those step sizes.

        Raises:
            NotImplementedError: if ``dt_init`` is neither "constant" nor "random".
        """
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse softplus: dt + log(1 - exp(-dt)) == log(exp(dt) - 1).
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        dt_proj.bias._no_reinit = True
        return dt_proj

    @staticmethod
    def _A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        """Return ``log(A)`` as a float32 Parameter with ``A[d, n] = n + 1``.

        Shape is ``(d_inner, d_state)``. With ``copies > 1`` it is
        ``(copies, d_inner, d_state)``, or ``(copies * d_inner, d_state)`` when
        ``merge`` is true. The Parameter is flagged ``_no_weight_decay``.
        """
        A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).expand(d_inner, -1).contiguous()
        A_log = torch.log(A)
        if copies > 1:
            # Tensor.repeat materializes every copy (no shared/aliased storage).
            A_log = A_log.unsqueeze(0).repeat(copies, 1, 1)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def _D_init(d_inner, copies=1, device=None, merge=True):
        """Return the skip coefficients ``D`` (all ones) as a Parameter.

        Shape is ``(d_inner,)``. With ``copies > 1`` it is ``(copies, d_inner)``,
        or ``(copies * d_inner,)`` when ``merge`` is true. The Parameter is
        flagged ``_no_weight_decay``.
        """
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = D.unsqueeze(0).repeat(copies, 1)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)
        D._no_weight_decay = True
        return D

    def _forward_corev0(self, x: torch.Tensor):
        """Run the four-direction selective scan over a feature map.

        Args:
            x: Feature map of shape ``(B, d_inner, H, W)``.

        Returns:
            Four float32 tensors of shape ``(B, d_inner, H * W)``, all in
            row-major (H, W) order: the row-major scan, the reversed row-major
            scan, and the column-major and reversed column-major scans (both
            transposed back to row-major). The caller sums them.

        Raises:
            NotImplementedError: if no selective-scan kernel was imported.
        """
        # Prefer mamba_ssm's kernel, then the standalone `selective_scan` one.
        if selective_scan_fn is not None:
            selective_scan = selective_scan_fn
        elif selective_scan_fn_v1 is not None:
            selective_scan = selective_scan_fn_v1
        else:
            raise NotImplementedError("No selective_scan_fn implementation available.")

        B, C, H, W = x.shape
        L = H * W
        K = 4

        # 1. Four sequences: row-major, column-major, and both reversed.
        x_hwwh = torch.stack(
            [x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
            dim=1,
        ).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (B, K, d_inner, L)

        # 2. Per-direction projection to (dt, B, C); dt is lifted from dt_rank to d_inner.
        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        # Directions are folded into the channel axis; Bs/Cs keep a K axis.
        xs = xs.float().view(B, -1, L)  # (B, K * d_inner, L)
        dts = dts.contiguous().float().view(B, -1, L)  # (B, K * d_inner, L)
        Bs = Bs.float().view(B, K, -1, L)  # (B, K, d_state, L)
        Cs = Cs.float().view(B, K, -1, L)  # (B, K, d_state, L)
        Ds = self.Ds.float().view(-1)  # (K * d_inner,)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (K * d_inner, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (K * d_inner,)

        # 3. Selective scan; the dt bias and softplus are delegated to the kernel.
        out_y = selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)

        # 4. Undo the reversals and transpose the column-major scans back.
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        """Apply the block with its residual connection.

        Args:
            x: Input of shape ``(B, d_model, H, W)``.
            **kwargs: Accepted and ignored.

        Returns:
            ``x + block(x)`` of shape ``(B, d_model, H, W)``, in the dtype
            produced by ``out_proj``.
        """
        # Keep the input as-is; it is cast once at the final addition, which
        # avoids an extra full-size copy and keeps float64 inputs exact.
        shortcut = x

        x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, d_model)
        B, H, W, _ = x.shape

        x, z = self.in_proj(x).chunk(2, dim=-1)  # each (B, H, W, d_inner)
        x = x.permute(0, 3, 1, 2).contiguous()  # (B, d_inner, H, W)
        x = self.act(self.conv2d(x))

        y1, y2, y3, y4 = self.forward_core(x)
        y = y1 + y2 + y3 + y4  # (B, d_inner, H * W)
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)

        # Channel mixing with the 1x1 conv (needs channel-first layout).
        y = self.channel_mix(y.permute(0, 3, 1, 2).contiguous())
        y = y.permute(0, 2, 3, 1).contiguous()  # back to (B, H, W, d_inner)

        y = y * F.silu(z)
        out = self.out_proj(y)  # (B, H, W, d_model)
        if self.dropout is not None:
            out = self.dropout(out)
        out = out.permute(0, 3, 1, 2).contiguous()  # (B, d_model, H, W)
        return out + shortcut.to(out.dtype)
H-SS2D 模块将 SSM 的长程依赖建模能力（全局上下文）与 1×1 标准卷积的通道交互能力（逐位置的通道混合）在结构上融合，并辅以残差连接来稳定训练，旨在成为一个更强大、更鲁棒的 2D 视觉特征提取单元，尤其适用于需要精确边界框定位和特征识别的目标检测任务。