UNet+Transformer Brain Tumor Segmentation

Following on from the previous installment's UNet brain tumor segmentation, we adopt a Transformer+UNet hybrid here for more accurate detection.

Here I integrate Transformer modules into U-Net; compared with a plain U-Net, this hybrid differs in the following ways.

UNet-Transformer/net.py builds on the classic UNet structure and introduces a full set of Transformer components (PatchEmbedding, MultiHeadSelfAttention, TransformerBlock, and TransformerModule) that convert high-level 2D features into a patch sequence for global modeling, then map them back to convolutional feature maps. UNet-Medical-master/net.py, by contrast, is a purely convolutional encoder-decoder with no attention modules.

In our UNet+Transformer, the Transformer is inserted flexibly into the skip branches of down4 and down5 as well as into the bottleneck. Placement is controlled by the use_transformer_at parameter, which allows enabling it only at the deep layers, only at the bottleneck, or at all positions; the plain U-Net has no such optional branches.
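
For example, placement could be selected per experiment like this (a minimal sketch using the UNetTransformer class defined later in this post):

python
# Choose where the Transformer is enabled via use_transformer_at
model_deep = UNetTransformer(use_transformer_at=['deep'])        # only the skip4/skip5 branches
model_neck = UNetTransformer(use_transformer_at=['bottleneck'])  # only the bottleneck
model_all = UNetTransformer(use_transformer_at=['all'])          # all supported positions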

To stay compatible with the spatial sizes produced by the Transformer, the improved UpBlock applies F.interpolate to align dimensions before concatenating the skip connection when needed; the plain U-Net calls torch.cat directly and assumes the sizes match by construction. (This step is also important.)
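
As a minimal sketch of that alignment step (the actual UpBlock in UNet-Transformer/net.py may differ in its convolution stack; this only illustrates the interpolate-before-concat idea):

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpBlock(nn.Module):
    """Upsample, align to the skip connection's spatial size, then concatenate."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        # The key difference from plain U-Net: align sizes before torch.cat
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return self.conv(torch.cat([x, skip], dim=1))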

The tunable transformer_embed_dim/num_heads/num_layers give the model greater expressive power, letting it capture the tumor's long-range dependencies globally; the plain U-Net relies only on local convolutions and skip connections, so its receptive field is limited.
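
To see what these knobs buy, one rough check is the parameter count at a single insertion point (a sketch assuming the TransformerModule class from the listing below is in scope; the configurations are illustrative, not from the repo):

python
# Compare the capacity of a narrow/shallow vs. wide/deep TransformerModule
small = TransformerModule(in_channels=256, embed_dim=128, img_size=32,
                          patch_size=2, num_heads=4, num_layers=1)
large = TransformerModule(in_channels=256, embed_dim=256, img_size=32,
                          patch_size=2, num_heads=8, num_layers=2)
n_params = lambda m: sum(p.numel() for p in m.parameters())
print(f"small: {n_params(small):,} parameters")
print(f"large: {n_params(large):,} parameters")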

We can also see that the tumors detected by this UNet+Transformer are more accurate and more complete than those from the plain UNet.
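
One common way to quantify that accuracy and completeness is the Dice coefficient over the predicted masks; a minimal sketch follows (the repo's own evaluation code may differ):

python
import torch

def dice_coefficient(pred, target, threshold=0.5, eps=1e-6):
    """Dice overlap between a sigmoid prediction map and a binary ground-truth mask."""
    pred = (pred > threshold).float()
    intersection = (pred * target).sum()
    # Closer to 1.0 means better overlap with the annotated tumor
    return (2 * intersection + eps) / (pred.sum() + target.sum() + eps)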

The concrete code implementation:

python
# Transformer building blocks
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """Convert a 2D feature map into a sequence of patches."""
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # Implement patch embedding with a strided convolution
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        # If the input size does not match, resize adaptively
        if H != self.img_size or W != self.img_size:
            x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        
        x = self.proj(x)  # [B, embed_dim, H/patch_size, W/patch_size]
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, n_patches, embed_dim]
        return x

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention."""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, N, C = x.shape
        
        # Produce Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        
        # Apply the attention weights to the values
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        
        return x

class TransformerBlock(nn.Module):
    """Transformer encoder block."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class TransformerModule(nn.Module):
    """Transformer module: 2D features -> sequence -> Transformer -> back to 2D."""
    def __init__(self, in_channels, embed_dim, img_size, patch_size, num_heads=8, num_layers=2, dropout=0.1):
        super(TransformerModule, self).__init__()
        self.embed_dim = embed_dim
        self.img_size = img_size
        self.patch_size = patch_size
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        
        # Learnable position embedding
        self.n_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        
        # Project embed_dim back to in_channels
        self.proj_back = nn.Linear(embed_dim, in_channels)
        
        # Initialize the position embedding
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
    def forward(self, x):
        B, C, H, W = x.shape
        
        # Remember the original spatial size
        orig_size = (H, W)
        
        # If the input size does not match the expected size, resize first
        if H != self.img_size or W != self.img_size:
            x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        
        # Patch embedding
        x = self.patch_embed(x)  # [B, n_patches, embed_dim]
        
        # Add the position embedding
        x = x + self.pos_embed
        
        # Run the Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        
        # Project back to the channel dimension
        x = self.proj_back(x)  # [B, n_patches, in_channels]
        
        # Reshape back into a 2D feature map
        patch_h = patch_w = self.img_size // self.patch_size
        x = x.transpose(1, 2).reshape(B, C, patch_h, patch_w)
        
        # Upsample back to the original size (patchifying shrinks the map by a factor of patch_size)
        if (patch_h, patch_w) != orig_size:
            x = F.interpolate(x, size=orig_size, mode='bilinear', align_corners=False)
        
        return x

In the four modules above (a quick shape check follows the list):

1. PatchEmbedding - patch embedding for images

Purpose: split a 2D image/feature map into small patches and convert them into a sequence the Transformer can process.

2. MultiHeadSelfAttention - multi-head self-attention

Core idea: let the model attend to different representation subspaces of the input simultaneously.

3. TransformerBlock - Transformer encoder block

This block implements the core operations of a Transformer encoder, self-attention plus a feed-forward network, each wrapped with a residual connection and layer normalization.

4. TransformerModule - the complete Transformer module

This wraps the whole round trip: 2D features to patch sequence, Transformer encoding, and back to a 2D feature map.
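
Here is the shape check exercising these modules, a sketch that assumes the listing above is in scope (MultiHeadSelfAttention runs inside TransformerBlock):

python
import torch

x = torch.randn(2, 256, 32, 32)  # a batch of deep CNN features

patches = PatchEmbedding(img_size=32, patch_size=2, in_channels=256, embed_dim=256)(x)
print(patches.shape)   # torch.Size([2, 256, 256]) -> [B, n_patches, embed_dim]

tokens = TransformerBlock(embed_dim=256, num_heads=8)(patches)
print(tokens.shape)    # unchanged: [B, n_patches, embed_dim]

out = TransformerModule(in_channels=256, embed_dim=256, img_size=32, patch_size=2)(x)
print(out.shape)       # torch.Size([2, 256, 32, 32]) -> same shape as the input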

After that, we simply assemble the actual model in our UNetTransformer class.

python
# Define the UNet+Transformer hybrid model
# (DownBlock and UpBlock are the convolutional encoder/decoder blocks defined
# elsewhere in net.py; each DownBlock returns (pooled_x, skip).)
class UNetTransformer(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, n_filters=32, 
                 transformer_embed_dim=256, transformer_num_heads=8, 
                 transformer_num_layers=2, use_transformer_at=['bottleneck', 'deep']):
        """
        UNet+Transformer hybrid model.
        
        Args:
            n_channels: number of input channels
            n_classes: number of output classes
            n_filters: base number of filters
            transformer_embed_dim: embedding dimension of the Transformer
            transformer_num_heads: number of attention heads
            transformer_num_layers: number of Transformer layers
            use_transformer_at: where to enable the Transformer, from ['bottleneck', 'deep', 'all']
        """
        super(UNetTransformer, self).__init__()
        
        self.use_transformer_at = use_transformer_at
        
        # Encoder path
        self.down1 = DownBlock(n_channels, n_filters)
        self.down2 = DownBlock(n_filters, n_filters * 2)
        self.down3 = DownBlock(n_filters * 2, n_filters * 4)
        self.down4 = DownBlock(n_filters * 4, n_filters * 8)
        self.down5 = DownBlock(n_filters * 8, n_filters * 16)
        
        # Add Transformers on the deep features (after down4 and down5)
        if 'deep' in use_transformer_at or 'all' in use_transformer_at:
            # skip4 is down4's skip connection, taken before max pooling, at 32x32
            self.transformer_down4 = TransformerModule(
                in_channels=n_filters * 8,
                embed_dim=transformer_embed_dim,
                img_size=32,
                patch_size=2,
                num_heads=transformer_num_heads,
                num_layers=transformer_num_layers
            )
            # skip5 is down5's skip connection, taken before max pooling, at 16x16
            self.transformer_down5 = TransformerModule(
                in_channels=n_filters * 16,
                embed_dim=transformer_embed_dim,
                img_size=16,
                patch_size=2,
                num_heads=transformer_num_heads,
                num_layers=transformer_num_layers
            )
        
        # Bottleneck
        self.bottleneck = DownBlock(n_filters * 16, n_filters * 32, dropout_prob=0.4, max_pooling=False)
        
        # Add a Transformer at the bottleneck
        if 'bottleneck' in use_transformer_at or 'all' in use_transformer_at:
            self.transformer_bottleneck = TransformerModule(
                in_channels=n_filters * 32,
                embed_dim=transformer_embed_dim,
                img_size=8,
                patch_size=2,
                num_heads=transformer_num_heads,
                num_layers=transformer_num_layers * 2  # use more layers at the bottleneck
            )
        
        # Decoder path
        self.up1 = UpBlock(n_filters * 32, n_filters * 16)
        self.up2 = UpBlock(n_filters * 16, n_filters * 8)
        self.up3 = UpBlock(n_filters * 8, n_filters * 4)
        self.up4 = UpBlock(n_filters * 4, n_filters * 2)
        self.up5 = UpBlock(n_filters * 2, n_filters)
        
        # Output layer
        self.outc = nn.Conv2d(n_filters, n_classes, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder path
        x1, skip1 = self.down1(x)      # 256 -> 128, skip1: 256x256
        x2, skip2 = self.down2(x1)     # 128 -> 64, skip2: 128x128
        x3, skip3 = self.down3(x2)     # 64 -> 32, skip3: 64x64
        x4, skip4 = self.down4(x3)     # 32 -> 16, skip4: 32x32
        
        # Apply the Transformer at the deep layers
        if 'deep' in self.use_transformer_at or 'all' in self.use_transformer_at:
            skip4 = self.transformer_down4(skip4)
            # Defensive check that skip4 is back at 32x32
            if skip4.shape[2:] != (32, 32):
                skip4 = F.interpolate(skip4, size=(32, 32), mode='bilinear', align_corners=False)
        
        x5, skip5 = self.down5(x4)     # 16 -> 8, skip5: 16x16
        
        # Apply the Transformer at the deep layers
        if 'deep' in self.use_transformer_at or 'all' in self.use_transformer_at:
            skip5 = self.transformer_down5(skip5)
            # Defensive check that skip5 is back at 16x16
            if skip5.shape[2:] != (16, 16):
                skip5 = F.interpolate(skip5, size=(16, 16), mode='bilinear', align_corners=False)
        
        # Bottleneck
        x6, skip6 = self.bottleneck(x5)  # stays at 8 (no downsampling), x6: 8x8
        
        # Apply the Transformer at the bottleneck
        if 'bottleneck' in self.use_transformer_at or 'all' in self.use_transformer_at:
            x6 = self.transformer_bottleneck(x6)
            # Defensive check that x6 is back at 8x8
            if x6.shape[2:] != (8, 8):
                x6 = F.interpolate(x6, size=(8, 8), mode='bilinear', align_corners=False)
        
        # Decoder path
        x = self.up1(x6, skip5)    # 8 -> 16
        x = self.up2(x, skip4)     # 16 -> 32
        x = self.up3(x, skip3)     # 32 -> 64
        x = self.up4(x, skip2)     # 64 -> 128
        x = self.up5(x, skip1)     # 128 -> 256
        
        x = self.outc(x)
        x = self.sigmoid(x)
        return x
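
A quick end-to-end smoke test, sketched under the assumption that the DownBlock/UpBlock definitions from net.py are in scope:

python
import torch

model = UNetTransformer(n_channels=3, n_classes=1,
                        use_transformer_at=['bottleneck', 'deep'])
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)  # the spatial comments in forward() assume 256x256 input
    y = model(x)

print(y.shape)  # expected: torch.Size([1, 1, 256, 256]), values in (0, 1) after the sigmoid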

Fusion placement strategy

Fusion at the deep layers (skip4, skip5)

After down4, the feature map is 32x32:

self.transformer_down4 = TransformerModule(
    in_channels=n_filters * 8,            # 256 channels
    embed_dim=transformer_embed_dim,      # 256
    img_size=32,
    patch_size=2,                         # splits the 32x32 map into 16x16 patches
    num_heads=transformer_num_heads,
    num_layers=transformer_num_layers
)

Fusion at the bottleneck

At the bottleneck, the feature map is 8x8:

self.transformer_bottleneck = TransformerModule(
    in_channels=n_filters * 32,           # 1024 channels
    embed_dim=transformer_embed_dim,
    img_size=8,
    patch_size=2,                         # splits the 8x8 map into 4x4 patches
    num_layers=transformer_num_layers * 2 # more layers are used here
)
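
The sequence length the Transformer actually sees at each placement follows directly from n_patches = (img_size // patch_size) ** 2:

python
# Token counts at each insertion point (patch_size = 2 everywhere)
for name, img_size in [('down4 skip', 32), ('down5 skip', 16), ('bottleneck', 8)]:
    print(f"{name}: {(img_size // 2) ** 2} tokens")
# down4 skip: 256 tokens / down5 skip: 64 tokens / bottleneck: 16 tokens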

The concrete fusion flow:

Input → CNN downsampling → skip4 (32x32) → Transformer → enhanced skip4
                         → skip5 (16x16) → Transformer → enhanced skip5
                         → bottleneck (8x8) → Transformer → enhanced bottleneck

Effect

  • CNN: good at extracting local features and texture information

  • Transformer: good at modeling long-range dependencies and global context

  • Combined: local detail + global understanding
