【Block总结】CSAM,包含分割、关键点、切分等均适用!|即插即用

论文信息

创新点

CSAM(跨切片注意力模块)旨在解决传统3D和2D医学图像分割方法在处理各向异性体积数据时的不足。其主要创新包括:

  • 跨切片注意力机制: 通过在不同尺度的深度特征图上应用语义、位置和切片注意力,CSAM能够有效捕捉体积数据中不同切片之间的关系。

  • 参数优化: CSAM设计了最小可训练参数的结构,减少了模型的复杂性,同时保持了良好的性能。

  • 2.5D方法的应用: 该模块结合了2D卷积与体积信息,填补了3D和2D方法之间的空白,特别适用于MRI等各向异性数据。

方法

CSAM的实现方法包括以下几个步骤:

  1. 特征提取: 使用卷积神经网络(CNN)提取输入的体积数据特征。

  2. 注意力机制: 在提取的特征图上应用跨切片注意力机制,分别关注语义信息、位置关系和切片信息,以增强特征的表达能力。

  3. 模型训练 : 通过最小化损失函数来训练模型,确保模型能够有效学习到各向异性体积数据的特征。

效果

实验结果表明,CSAM在多个医学图像分割任务中表现出色,尤其是在处理各向异性数据时,其性能优于传统的3D和2D方法。具体效果包括:

  • 分割精度: CSAM在分割精度上达到了新的状态,能够更好地识别和分割复杂的医学图像结构。

  • 训练效率: 由于参数较少,CSAM的训练时间显著低于其他复杂模型。

实验结果

研究者进行了广泛的实验,以验证CSAM的有效性和通用性。实验包括:

  • 数据集: 使用多个公开的医学图像数据集进行测试,涵盖不同的医学成像技术(如MRI)。

  • 对比实验: 将CSAM与现有的3D和2D分割模型进行比较,结果显示CSAM在多个指标上均优于对比模型。

  • 泛化能力: CSAM在不同任务和数据集上的表现一致,证明了其良好的泛化能力。

总结

CSAM作为一种新颖的2.5D跨切片注意力模块,为各向异性体积医学图像分割提供了有效的解决方案。通过引入跨切片注意力机制,CSAM不仅提高了分割精度,还减少了模型的复杂性和训练时间。实验结果验证了其在医学图像处理中的广泛适用性和优越性能,为未来的研究提供了新的思路和方法。

代码

python 复制代码
import torch
import torch.nn.functional
from torch import nn
import torch.distributions as td
def custom_max(x,dim,keepdim=True):
    temp_x=x
    for i in dim:
        temp_x=torch.max(temp_x,dim=i,keepdim=True)[0]
    if not keepdim:
        temp_x=temp_x.squeeze()
    return temp_x

class PositionalAttentionModule(nn.Module):
    def __init__(self):
        super(PositionalAttentionModule,self).__init__()
        self.conv=nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(7,7),padding=3)
    def forward(self,x):
        max_x=custom_max(x,dim=(0,1),keepdim=True)
        avg_x=torch.mean(x,dim=(0,1),keepdim=True)
        att=torch.cat((max_x,avg_x),dim=1)
        att=self.conv(att)
        att=torch.sigmoid(att)
        return x*att

class SemanticAttentionModule(nn.Module):
    def __init__(self,in_features,reduction_rate=16):
        super(SemanticAttentionModule,self).__init__()
        self.linear=[]
        self.linear.append(nn.Linear(in_features=in_features,out_features=in_features//reduction_rate))
        self.linear.append(nn.ReLU())
        self.linear.append(nn.Linear(in_features=in_features//reduction_rate,out_features=in_features))
        self.linear=nn.Sequential(*self.linear)
    def forward(self,x):
        max_x=custom_max(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
        avg_x=torch.mean(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
        max_x=self.linear(max_x)
        avg_x=self.linear(avg_x)
        att=max_x+avg_x
        att=torch.sigmoid(att).unsqueeze(-1).unsqueeze(-1)
        return x*att

class SliceAttentionModule(nn.Module):
    def __init__(self,in_features,rate=4,uncertainty=True,rank=5):
        super(SliceAttentionModule,self).__init__()
        self.uncertainty=uncertainty
        self.rank=rank
        self.linear=[]
        self.linear.append(nn.Linear(in_features=in_features,out_features=int(in_features*rate)))
        self.linear.append(nn.ReLU())
        self.linear.append(nn.Linear(in_features=int(in_features*rate),out_features=in_features))
        self.linear=nn.Sequential(*self.linear)
        if uncertainty:
            self.non_linear=nn.ReLU()
            self.mean=nn.Linear(in_features=in_features,out_features=in_features)
            self.log_diag=nn.Linear(in_features=in_features,out_features=in_features)
            self.factor=nn.Linear(in_features=in_features,out_features=in_features*rank)
    def forward(self,x):
        max_x=custom_max(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
        avg_x=torch.mean(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
        max_x=self.linear(max_x)
        avg_x=self.linear(avg_x)
        att=max_x+avg_x
        if self.uncertainty:
            temp=self.non_linear(att)
            mean=self.mean(temp)
            diag=self.log_diag(temp).exp()
            factor=self.factor(temp)
            factor=factor.view(1,-1,self.rank)
            dist=td.LowRankMultivariateNormal(loc=mean,cov_factor=factor,cov_diag=diag)
            att=dist.sample()
        att=torch.sigmoid(att).squeeze().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        return x*att


class CSAM(nn.Module):
    def __init__(self,num_slices,num_channels,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):
        super(CSAM,self).__init__()
        self.semantic=semantic
        self.positional=positional
        self.slice=slice
        if semantic:
            self.semantic_att=SemanticAttentionModule(num_channels)
        if positional:
            self.positional_att=PositionalAttentionModule()
        if slice:
            self.slice_att=SliceAttentionModule(num_slices,uncertainty=uncertainty,rank=rank)
    def forward(self,x):
        if self.semantic:
            x=self.semantic_att(x)
        if self.positional:
            x=self.positional_att(x)
        if self.slice:
            x=self.slice_att(x)
        return x


if __name__ == "__main__":
    dim=64
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, channels,height, width)
    x = torch.randn(2,dim,40,40).to(device)
    # 初始化 FullyAttentionalBlock 模块

    block = CSAM(2,dim,) # kernel_size为height或者width
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)
相关推荐
风象南16 小时前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
Mintopia17 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮17 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬17 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia18 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区18 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两21 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪21 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat2325521 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源