【Block总结】CSAM,包含分割、关键点、切分等均适用!|即插即用

论文信息

创新点

CSAM(跨切片注意力模块)旨在解决传统3D和2D医学图像分割方法在处理各向异性体积数据时的不足。其主要创新包括:

  • 跨切片注意力机制: 通过在不同尺度的深度特征图上应用语义、位置和切片注意力,CSAM能够有效捕捉体积数据中不同切片之间的关系。

  • 参数优化: CSAM设计了最小可训练参数的结构,减少了模型的复杂性,同时保持了良好的性能。

  • 2.5D方法的应用: 该模块结合了2D卷积与体积信息,填补了3D和2D方法之间的空白,特别适用于MRI等各向异性数据。

方法

CSAM的实现方法包括以下几个步骤:

  1. 特征提取: 使用卷积神经网络(CNN)提取输入的体积数据特征。

  2. 注意力机制: 在提取的特征图上应用跨切片注意力机制,分别关注语义信息、位置关系和切片信息,以增强特征的表达能力。

  3. 模型训练 : 通过最小化损失函数来训练模型,确保模型能够有效学习到各向异性体积数据的特征。

效果

实验结果表明,CSAM在多个医学图像分割任务中表现出色,尤其是在处理各向异性数据时,其性能优于传统的3D和2D方法。具体效果包括:

  • 分割精度: CSAM在分割精度上达到了新的状态,能够更好地识别和分割复杂的医学图像结构。

  • 训练效率: 由于参数较少,CSAM的训练时间显著低于其他复杂模型。

实验结果

研究者进行了广泛的实验,以验证CSAM的有效性和通用性。实验包括:

  • 数据集: 使用多个公开的医学图像数据集进行测试,涵盖不同的医学成像技术(如MRI)。

  • 对比实验: 将CSAM与现有的3D和2D分割模型进行比较,结果显示CSAM在多个指标上均优于对比模型。

  • 泛化能力: CSAM在不同任务和数据集上的表现一致,证明了其良好的泛化能力。

总结

CSAM作为一种新颖的2.5D跨切片注意力模块,为各向异性体积医学图像分割提供了有效的解决方案。通过引入跨切片注意力机制,CSAM不仅提高了分割精度,还减少了模型的复杂性和训练时间。实验结果验证了其在医学图像处理中的广泛适用性和优越性能,为未来的研究提供了新的思路和方法。

代码

python 复制代码
import torch
import torch.nn.functional
from torch import nn
import torch.distributions as td
def custom_max(x,dim,keepdim=True):
    temp_x=x
    for i in dim:
        temp_x=torch.max(temp_x,dim=i,keepdim=True)[0]
    if not keepdim:
        temp_x=temp_x.squeeze()
    return temp_x

class PositionalAttentionModule(nn.Module):
    def __init__(self):
        super(PositionalAttentionModule,self).__init__()
        self.conv=nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(7,7),padding=3)
    def forward(self,x):
        max_x=custom_max(x,dim=(0,1),keepdim=True)
        avg_x=torch.mean(x,dim=(0,1),keepdim=True)
        att=torch.cat((max_x,avg_x),dim=1)
        att=self.conv(att)
        att=torch.sigmoid(att)
        return x*att

class SemanticAttentionModule(nn.Module):
    def __init__(self,in_features,reduction_rate=16):
        super(SemanticAttentionModule,self).__init__()
        self.linear=[]
        self.linear.append(nn.Linear(in_features=in_features,out_features=in_features//reduction_rate))
        self.linear.append(nn.ReLU())
        self.linear.append(nn.Linear(in_features=in_features//reduction_rate,out_features=in_features))
        self.linear=nn.Sequential(*self.linear)
    def forward(self,x):
        max_x=custom_max(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
        avg_x=torch.mean(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
        max_x=self.linear(max_x)
        avg_x=self.linear(avg_x)
        att=max_x+avg_x
        att=torch.sigmoid(att).unsqueeze(-1).unsqueeze(-1)
        return x*att

class SliceAttentionModule(nn.Module):
    def __init__(self,in_features,rate=4,uncertainty=True,rank=5):
        super(SliceAttentionModule,self).__init__()
        self.uncertainty=uncertainty
        self.rank=rank
        self.linear=[]
        self.linear.append(nn.Linear(in_features=in_features,out_features=int(in_features*rate)))
        self.linear.append(nn.ReLU())
        self.linear.append(nn.Linear(in_features=int(in_features*rate),out_features=in_features))
        self.linear=nn.Sequential(*self.linear)
        if uncertainty:
            self.non_linear=nn.ReLU()
            self.mean=nn.Linear(in_features=in_features,out_features=in_features)
            self.log_diag=nn.Linear(in_features=in_features,out_features=in_features)
            self.factor=nn.Linear(in_features=in_features,out_features=in_features*rank)
    def forward(self,x):
        max_x=custom_max(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
        avg_x=torch.mean(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
        max_x=self.linear(max_x)
        avg_x=self.linear(avg_x)
        att=max_x+avg_x
        if self.uncertainty:
            temp=self.non_linear(att)
            mean=self.mean(temp)
            diag=self.log_diag(temp).exp()
            factor=self.factor(temp)
            factor=factor.view(1,-1,self.rank)
            dist=td.LowRankMultivariateNormal(loc=mean,cov_factor=factor,cov_diag=diag)
            att=dist.sample()
        att=torch.sigmoid(att).squeeze().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        return x*att


class CSAM(nn.Module):
    def __init__(self,num_slices,num_channels,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):
        super(CSAM,self).__init__()
        self.semantic=semantic
        self.positional=positional
        self.slice=slice
        if semantic:
            self.semantic_att=SemanticAttentionModule(num_channels)
        if positional:
            self.positional_att=PositionalAttentionModule()
        if slice:
            self.slice_att=SliceAttentionModule(num_slices,uncertainty=uncertainty,rank=rank)
    def forward(self,x):
        if self.semantic:
            x=self.semantic_att(x)
        if self.positional:
            x=self.positional_att(x)
        if self.slice:
            x=self.slice_att(x)
        return x


if __name__ == "__main__":
    dim=64
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, channels,height, width)
    x = torch.randn(2,dim,40,40).to(device)
    # 初始化 FullyAttentionalBlock 模块

    block = CSAM(2,dim,) # kernel_size为height或者width
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)
相关推荐
佛州小李哥1 小时前
在亚马逊云科技上用Stable Diffusion 3.5 Large生成赛博朋克风图片(上)
人工智能·科技·ai·语言模型·stable diffusion·aws·亚马逊云科技
东锋1.31 小时前
深度解析近期爆火的 DeepSeek
人工智能·深度学习
爱研究的小牛2 小时前
讯飞智作 AI 配音技术浅析(二):深度学习与神经网络
人工智能·深度学习·神经网络·机器学习·aigc
Luzem03192 小时前
使用PyTorch实现逻辑回归:从训练到模型保存与性能评估
人工智能·pytorch·逻辑回归
灵封~2 小时前
自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数
人工智能·深度学习
辞落山2 小时前
使用 Numpy 自定义数据集,使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数
人工智能·pytorch·逻辑回归
nnerddboy2 小时前
深度学习查漏补缺:2. 三个指标和注意力机制
人工智能·神经网络·cnn
新加坡内哥谈技术2 小时前
Deepseek-R1 和 OpenAI o1 这样的推理模型普遍存在“思考不足”的问题
人工智能·科技·深度学习·语言模型·机器人
goomind2 小时前
深度卷积神经网络实战无人机视角目标识别
人工智能·神经网络·yolo·cnn·无人机·pyqt5·目标识别
終不似少年遊*2 小时前
国产之光DeepSeek架构理解与应用分析
人工智能·深度学习·神经网络·架构·deepseek·分析解读