Following up on the previous post's U-Net brain tumor segmentation, we now use a Transformer+U-Net hybrid to get more accurate detections.

Here I insert Transformer modules into the U-Net. Compared with the plain U-Net, the main differences are:

UNet-Transformer/net.py builds on the classic UNet structure and adds a full set of Transformer components (PatchEmbedding, MultiHeadSelfAttention, TransformerBlock, TransformerModule) that turn high-level 2D features into a patch sequence for global modeling and then map them back to a convolutional feature map; UNet-Medical-master/net.py has only a convolutional encoder-decoder with no attention modules.
In our UNet+Transformer, the Transformer can be inserted flexibly into the skip branches of down4 and down5 and into the bottleneck. This is controlled by the use_transformer_at parameter, which allows enabling it only at the deep layers, only at the bottleneck, or at every position; the plain U-Net has no such optional branches.
To stay compatible with the spatial sizes produced by the Transformer, the improved UpBlock calls F.interpolate to align the tensors before concatenating the skip feature (sketched below); the plain U-Net calls torch.cat directly and assumes the sizes already match. (This step matters.)
The tunable transformer_embed_dim / num_heads / num_layers give the model more capacity and let it capture long-range dependencies of the tumor globally; the plain U-Net relies only on local convolutions and skip connections, so its receptive field is limited.
As the results show, this UNet+Transformer produces tumor predictions that are more accurate and more complete than the plain UNet.
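The improved UpBlock itself is not reproduced in this post; below is a minimal sketch of what the alignment step described above can look like, assuming a transposed-convolution upsample followed by two 3x3 convolutions (the actual UNet-Transformer/net.py may differ in details):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpBlock(nn.Module):
    """Sketch of a decoder block: upsample, align with the skip feature, concatenate, convolve."""
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        # Align spatial sizes before concatenation: the Transformer branch may have
        # resized the skip feature, so interpolate if the shapes disagree.
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)
```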
The concrete code implementation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Transformer-related modules
class PatchEmbedding(nn.Module):
    """Convert a 2D feature map into a sequence of patches."""
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # Implement patch embedding with a strided convolution
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        # If the input size does not match, resize it adaptively
        if H != self.img_size or W != self.img_size:
            x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        x = self.proj(x)  # [B, embed_dim, H/patch_size, W/patch_size]
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, n_patches, embed_dim]
        return x


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention."""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # Produce Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        # Apply the attention weights to the values
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer encoder block."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class TransformerModule(nn.Module):
    """Transformer module: flatten 2D features into a sequence, apply Transformer blocks, map back to 2D."""
    def __init__(self, in_channels, embed_dim, img_size, patch_size, num_heads=8, num_layers=2, dropout=0.1):
        super(TransformerModule, self).__init__()
        self.embed_dim = embed_dim
        self.img_size = img_size
        self.patch_size = patch_size
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        # Position embedding
        self.n_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Project embed_dim back to in_channels
        self.proj_back = nn.Linear(embed_dim, in_channels)
        # Initialize the position embedding
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        B, C, H, W = x.shape
        # Remember the original spatial size
        orig_size = (H, W)
        # If the input size does not match, resize to the expected size first
        if H != self.img_size or W != self.img_size:
            x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        # Patch embedding
        x = self.patch_embed(x)  # [B, n_patches, embed_dim]
        # Add position embedding
        x = x + self.pos_embed
        # Run through the Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Project back to the channel dimension
        x = self.proj_back(x)  # [B, n_patches, in_channels]
        # Reshape back to a 2D feature map
        patch_h = patch_w = self.img_size // self.patch_size
        x = x.transpose(1, 2).reshape(B, C, patch_h, patch_w)
        # Upsample back to the original size (the patch grid is patch_size times smaller)
        if x.shape[2:] != orig_size:
            x = F.interpolate(x, size=orig_size, mode='bilinear', align_corners=False)
        return x
```
Among the four modules above:
1. PatchEmbedding - patch embedding
Splits a 2D image/feature map into small patches and turns them into a sequence that the Transformer can process.
2. MultiHeadSelfAttention - multi-head self-attention
Core idea: let the model attend to different representation subspaces of the input at the same time.
3. TransformerBlock - Transformer encoder block
Implements the core operations of a Transformer encoder, self-attention plus a feed-forward network, with residual connections and layer normalization.
4. TransformerModule - the complete Transformer module
Flattens the 2D features into a sequence, applies the Transformer blocks, and maps the result back to a 2D feature map.
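To make the data flow concrete, here is a small shape check I added (not part of the repository code), continuing from the code above; the dimensions mimic the skip4 branch used later:

```python
# Hypothetical shape check: a deep feature map of 256 channels at 32x32
feat = torch.randn(1, 256, 32, 32)               # [B, C, H, W]
trans = TransformerModule(in_channels=256, embed_dim=256,
                          img_size=32, patch_size=2,
                          num_heads=8, num_layers=2)
out = trans(feat)                                # 32x32 -> 16x16 patch grid -> Transformer -> back to 2D
print(out.shape)                                 # expected: torch.Size([1, 256, 32, 32])
```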
After that, all that remains is to assemble these pieces into the actual model, our UNetTransformer:
```python
# Define the UNet+Transformer hybrid model
class UNetTransformer(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, n_filters=32,
                 transformer_embed_dim=256, transformer_num_heads=8,
                 transformer_num_layers=2, use_transformer_at=['bottleneck', 'deep']):
        """
        UNet+Transformer hybrid model

        Args:
            n_channels: number of input channels
            n_classes: number of output classes
            n_filters: number of base filters
            transformer_embed_dim: Transformer embedding dimension
            transformer_num_heads: number of attention heads
            transformer_num_layers: number of Transformer layers
            use_transformer_at: where to insert Transformers ['bottleneck', 'deep', 'all']
        """
        super(UNetTransformer, self).__init__()
        self.use_transformer_at = use_transformer_at

        # Encoder path
        self.down1 = DownBlock(n_channels, n_filters)
        self.down2 = DownBlock(n_filters, n_filters * 2)
        self.down3 = DownBlock(n_filters * 2, n_filters * 4)
        self.down4 = DownBlock(n_filters * 4, n_filters * 8)
        self.down5 = DownBlock(n_filters * 8, n_filters * 16)

        # Add Transformers on the deep features (after down4 and down5)
        if 'deep' in use_transformer_at or 'all' in use_transformer_at:
            # skip4 is the skip connection of down4 (before max pooling), size 32x32
            self.transformer_down4 = TransformerModule(
                in_channels=n_filters * 8,
                embed_dim=transformer_embed_dim,
                img_size=32,
                patch_size=2,
                num_heads=transformer_num_heads,
                num_layers=transformer_num_layers
            )
            # skip5 is the skip connection of down5 (before max pooling), size 16x16
            self.transformer_down5 = TransformerModule(
                in_channels=n_filters * 16,
                embed_dim=transformer_embed_dim,
                img_size=16,
                patch_size=2,
                num_heads=transformer_num_heads,
                num_layers=transformer_num_layers
            )

        # Bottleneck
        self.bottleneck = DownBlock(n_filters * 16, n_filters * 32, dropout_prob=0.4, max_pooling=False)

        # Add a Transformer at the bottleneck
        if 'bottleneck' in use_transformer_at or 'all' in use_transformer_at:
            self.transformer_bottleneck = TransformerModule(
                in_channels=n_filters * 32,
                embed_dim=transformer_embed_dim,
                img_size=8,
                patch_size=2,
                num_heads=transformer_num_heads,
                num_layers=transformer_num_layers * 2  # use more layers at the bottleneck
            )

        # Decoder path
        self.up1 = UpBlock(n_filters * 32, n_filters * 16)
        self.up2 = UpBlock(n_filters * 16, n_filters * 8)
        self.up3 = UpBlock(n_filters * 8, n_filters * 4)
        self.up4 = UpBlock(n_filters * 4, n_filters * 2)
        self.up5 = UpBlock(n_filters * 2, n_filters)

        # Output layer
        self.outc = nn.Conv2d(n_filters, n_classes, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder path
        x1, skip1 = self.down1(x)   # 256 -> 128, skip1: 256x256
        x2, skip2 = self.down2(x1)  # 128 -> 64,  skip2: 128x128
        x3, skip3 = self.down3(x2)  # 64 -> 32,   skip3: 64x64
        x4, skip4 = self.down4(x3)  # 32 -> 16,   skip4: 32x32

        # Apply the Transformer on the deep skip connection skip4
        if 'deep' in self.use_transformer_at or 'all' in self.use_transformer_at:
            skip4 = self.transformer_down4(skip4)
            # Make sure skip4 has the expected size (32x32)
            if skip4.shape[2:] != (32, 32):
                skip4 = F.interpolate(skip4, size=(32, 32), mode='bilinear', align_corners=False)

        x5, skip5 = self.down5(x4)  # 16 -> 8, skip5: 16x16

        # Apply the Transformer on the deep skip connection skip5
        if 'deep' in self.use_transformer_at or 'all' in self.use_transformer_at:
            skip5 = self.transformer_down5(skip5)
            # Make sure skip5 has the expected size (16x16)
            if skip5.shape[2:] != (16, 16):
                skip5 = F.interpolate(skip5, size=(16, 16), mode='bilinear', align_corners=False)

        # Bottleneck
        x6, skip6 = self.bottleneck(x5)  # 8x8 (no downsampling), x6: 8x8

        # Apply the Transformer at the bottleneck
        if 'bottleneck' in self.use_transformer_at or 'all' in self.use_transformer_at:
            x6 = self.transformer_bottleneck(x6)
            # Make sure x6 has the expected size (8x8)
            if x6.shape[2:] != (8, 8):
                x6 = F.interpolate(x6, size=(8, 8), mode='bilinear', align_corners=False)

        # Decoder path
        x = self.up1(x6, skip5)  # 8 -> 16
        x = self.up2(x, skip4)   # 16 -> 32
        x = self.up3(x, skip3)   # 32 -> 64
        x = self.up4(x, skip2)   # 64 -> 128
        x = self.up5(x, skip1)   # 128 -> 256

        x = self.outc(x)
        x = self.sigmoid(x)
        return x
```
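For reference, a quick smoke test I added (assuming DownBlock and UpBlock behave as in the original UNet-Medical-master/net.py, i.e. DownBlock returns a (downsampled_output, skip) pair and UpBlock upsamples, aligns, and concatenates):

```python
if __name__ == "__main__":
    # Dummy 3-channel 256x256 MRI slice, batch size 2
    model = UNetTransformer(n_channels=3, n_classes=1,
                            use_transformer_at=['bottleneck', 'deep'])
    x = torch.randn(2, 3, 256, 256)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([2, 1, 256, 256]), values in (0, 1) after the sigmoid
```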
Fusion placement strategy

Fusion at the deep layers (skip4, skip5)

After down4 the feature map is 32x32:

```python
self.transformer_down4 = TransformerModule(
    in_channels=n_filters * 8,        # 256 channels
    embed_dim=transformer_embed_dim,  # 256
    img_size=32,
    patch_size=2,                     # splits 32x32 into a 16x16 grid of patches
    num_heads=transformer_num_heads,
    num_layers=transformer_num_layers
)
```
Fusion at the bottleneck

The bottleneck feature map is 8x8:

```python
self.transformer_bottleneck = TransformerModule(
    in_channels=n_filters * 32,       # 1024 channels
    embed_dim=transformer_embed_dim,
    img_size=8,
    patch_size=2,                     # splits 8x8 into a 4x4 grid of patches
    num_heads=transformer_num_heads,
    num_layers=transformer_num_layers * 2  # more layers here
)
```
The concrete fusion flow:

Input → CNN downsampling → skip4 (32x32) → Transformer → enhanced skip4
                         → skip5 (16x16) → Transformer → enhanced skip5
                         → bottleneck (8x8) → Transformer → enhanced bottleneck
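To experiment with where the fusion happens, the use_transformer_at argument of the constructor above can be varied; a small illustrative sketch (my own, not from the repo):

```python
# Hypothetical configuration variants of the UNetTransformer defined above
model_bottleneck_only = UNetTransformer(use_transformer_at=['bottleneck'])  # Transformer only at the 8x8 bottleneck
model_deep_only       = UNetTransformer(use_transformer_at=['deep'])        # Transformer only on skip4 / skip5
model_all             = UNetTransformer(use_transformer_at=['all'])         # Transformer at skip4, skip5 and the bottleneck
```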
Effect
- CNN: good at extracting local features and texture information
- Transformer: good at modeling long-range dependencies and global context
- Combined: local detail + global understanding