Reproducing Peking University's MuMo Multimodal Tumor Classification Model with Transfer Learning
1. Project Background and Significance
Tumor classification is one of the core tasks in medical image analysis and is important for early cancer diagnosis, treatment planning, and prognosis assessment. Traditional single-modality image analysis is constrained by its single information source and struggles to capture the full complexity of tumors. Multimodal learning integrates the complementary information provided by different imaging techniques (such as CT, MRI, and PET) and can substantially improve classification accuracy.
MuMo (Multimodal tumor classification model), proposed by a team at Peking University, is an advanced multimodal tumor classification framework that effectively integrates medical imaging data across modalities and performs strongly on a range of tumor classification tasks. This project aims to reproduce the model and apply transfer learning on a specific client-provided dataset to meet real clinical application needs.
2. Environment Setup and Dependency Installation
First, we need to set up a suitable Python environment and install the required dependencies:
```python
# Requirements
# Python 3.8+
# PyTorch 1.9+
# CUDA 11.1+ (for GPU acceleration)

# Create and activate a conda environment
# conda create -n mumo python=3.8
# conda activate mumo

# Install the core dependencies
# pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# pip install torchaudio

# Install the other required libraries
# pip install numpy pandas scikit-learn matplotlib seaborn opencv-python pillow
# pip install scipy tqdm tensorboard monai nibabel SimpleITK

# Pin specific versions for compatibility
# pip install einops==0.3.2 timm==0.4.12

# Verify the installation
import torch
import torchvision
import numpy as np
import pandas as pd

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")
```
3. Data Preprocessing and Loading
Multimodal medical imaging data usually requires a dedicated preprocessing pipeline to ensure consistency and comparability across modalities.
```python
import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import nibabel as nib
import SimpleITK as sitk
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from monai.transforms import Compose, LoadImage, AddChannel, ScaleIntensity, Resize, ToTensor
class MultiModalTumorDataset(Dataset):
    """
    Multimodal tumor dataset.
    Handles medical images from multiple modalities.
    """
    def __init__(self, data_dir, metadata_file, modalities=['CT', 'PET'], transform=None, phase='train'):
        """
        Initialize the dataset.

        Args:
            data_dir: path to the data directory
            metadata_file: CSV file with sample information and labels
            modalities: list of modalities to use
            transform: augmentation / preprocessing transforms
            phase: one of 'train', 'val', 'test'
        """
        self.data_dir = data_dir
        self.metadata = pd.read_csv(metadata_file)
        self.modalities = modalities
        self.transform = transform
        self.phase = phase

        # Keep only samples for which every modality file exists
        self.samples = []
        for idx, row in self.metadata.iterrows():
            paths_exist = True
            for modality in modalities:
                img_path = os.path.join(data_dir, row['patient_id'], f"{modality}.nii.gz")
                if not os.path.exists(img_path):
                    paths_exist = False
                    break
            if paths_exist:
                self.samples.append({
                    'patient_id': row['patient_id'],
                    'label': row['label'],
                    'paths': {modality: os.path.join(data_dir, row['patient_id'], f"{modality}.nii.gz")
                              for modality in modalities}
                })
        print(f"{phase} dataset size: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        label = sample['label']

        # Load the data for each modality
        images = {}
        for modality, path in sample['paths'].items():
            # Load the medical image with MONAI (image_only avoids the (image, metadata) tuple)
            img = LoadImage(image_only=True)(path)
            # Add a channel dimension: (H, W, D) -> (C, H, W, D)
            img = AddChannel()(img)
            # Intensity normalization
            img = ScaleIntensity()(img)
            # Resize to the input size expected by the model
            img = Resize((128, 128, 128))(img)
            images[modality] = img

        # Apply augmentation (training phase only)
        if self.transform and self.phase == 'train':
            for modality in images:
                images[modality] = self.transform(images[modality])

        # Convert to tensors
        to_tensor = ToTensor()
        for modality in images:
            images[modality] = to_tensor(images[modality])

        return images, torch.tensor(label, dtype=torch.long)
def get_data_loaders(data_dir, metadata_file, modalities, batch_size=4, num_workers=4):
    """
    Build the train, validation, and test data loaders.
    """
    # Read the metadata
    metadata = pd.read_csv(metadata_file)

    # One label per patient, so the stratification labels stay aligned
    # with the patient IDs being split
    patient_labels = metadata.groupby('patient_id')['label'].first()

    # Split into train/val and test sets (80/20), stratified by label
    train_val_ids, test_ids = train_test_split(
        patient_labels.index.values,
        test_size=0.2,
        random_state=42,
        stratify=patient_labels.values
    )
    # Split train/val into train and val (60/20 of the full set)
    train_ids, val_ids = train_test_split(
        train_val_ids,
        test_size=0.25,
        random_state=42,
        stratify=patient_labels.loc[train_val_ids].values
    )

    # Build the per-split metadata
    train_metadata = metadata[metadata['patient_id'].isin(train_ids)]
    val_metadata = metadata[metadata['patient_id'].isin(val_ids)]
    test_metadata = metadata[metadata['patient_id'].isin(test_ids)]

    # Save temporary metadata files
    train_metadata.to_csv('train_metadata.csv', index=False)
    val_metadata.to_csv('val_metadata.csv', index=False)
    test_metadata.to_csv('test_metadata.csv', index=False)

    # Data augmentation (training only)
    train_transform = Compose([
        # Add medical-imaging-specific augmentations here,
        # e.g. random rotations and flips
    ])

    # Build the datasets
    train_dataset = MultiModalTumorDataset(
        data_dir, 'train_metadata.csv', modalities, train_transform, 'train'
    )
    val_dataset = MultiModalTumorDataset(
        data_dir, 'val_metadata.csv', modalities, None, 'val'
    )
    test_dataset = MultiModalTumorDataset(
        data_dir, 'test_metadata.csv', modalities, None, 'test'
    )

    # Build the data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader, test_loader
```
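Wiring the loaders together then looks like the sketch below. The paths are placeholders; the directory layout (`<data_dir>/<patient_id>/<MODALITY>.nii.gz`) and the `metadata.csv` columns (`patient_id`, `label`) follow the assumptions baked into the dataset class above.

```python
# Minimal sketch of building the loaders (paths are placeholders)
train_loader, val_loader, test_loader = get_data_loaders(
    data_dir='/path/to/data',               # contains <patient_id>/CT.nii.gz etc.
    metadata_file='/path/to/metadata.csv',  # columns: patient_id, label
    modalities=['CT', 'PET'],
    batch_size=4,
    num_workers=4,
)

# Each batch is a (dict of modality tensors, label tensor) pair
images, labels = next(iter(train_loader))
print({m: t.shape for m, t in images.items()})  # e.g. {'CT': (4, 1, 128, 128, 128), ...}
print(labels.shape)                             # (4,)
```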
4. MuMo Model Architecture
The core idea of MuMo is to integrate feature information from different modalities through a multimodal fusion mechanism. The full model implementation follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import timm
class ModalitySpecificEncoder(nn.Module):
    """
    Modality-specific encoder.
    Each modality gets its own encoder for feature extraction.
    """
    def __init__(self, modality, backbone_name='resnet50', pretrained=True, num_slices=16):
        super(ModalitySpecificEncoder, self).__init__()
        self.modality = modality
        self.num_slices = num_slices

        # Pretrained 2D CNN backbone
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            in_chans=1,     # medical images are typically single-channel
            num_classes=0   # no classification head
        )
        # Feature dimension of the backbone
        self.feature_dim = self.backbone.num_features

        # 3D convolutional adapter: refines the volume before it is reduced
        # to a stack of single-channel 2D slices for the 2D backbone
        self.volume_adapter = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 1, kernel_size=3, padding=1),  # back to one channel
            nn.BatchNorm3d(1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        Input:  (B, 1, D, H, W)
        Output: (B, feature_dim)
        """
        batch_size, _, depth, height, width = x.shape

        # Refine the volume with the 3D adapter
        x = self.volume_adapter(x)  # (B, 1, D, H, W)
        # Collapse the depth axis to a fixed number of slices,
        # keeping the spatial resolution for the 2D backbone
        x = F.adaptive_avg_pool3d(x, (self.num_slices, height, width))  # (B, 1, num_slices, H, W)

        # Encode each 2D slice with the backbone
        slice_features = []
        for i in range(self.num_slices):
            slice_img = x[:, :, i, :, :]         # (B, 1, H, W)
            features = self.backbone(slice_img)  # (B, feature_dim)
            slice_features.append(features)

        # Aggregate the slice features
        slice_features = torch.stack(slice_features, dim=1)      # (B, num_slices, feature_dim)
        aggregated_features = torch.mean(slice_features, dim=1)  # (B, feature_dim)
        return aggregated_features
class CrossModalAttention(nn.Module):
    """
    Cross-modal attention.
    Implements feature interaction and fusion across modalities.
    """
    def __init__(self, feature_dim, num_heads=8, dropout=0.1):
        super(CrossModalAttention, self).__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        assert self.head_dim * num_heads == feature_dim, "feature_dim must be divisible by num_heads"

        # Linear projections for queries, keys, and values
        self.q_linear = nn.Linear(feature_dim, feature_dim)
        self.k_linear = nn.Linear(feature_dim, feature_dim)
        self.v_linear = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_linear = nn.Linear(feature_dim, feature_dim)

    def forward(self, queries, keys, values, modality_mask=None):
        """
        Args:
            queries: (B, N, D) query vectors, N = number of modalities
            keys:    (B, M, D) key vectors, M = number of modalities
            values:  (B, M, D) value vectors
            modality_mask: (B, N, M) attention mask between modalities

        Returns:
            attended_features: (B, N, D) attention-weighted features
        """
        batch_size, num_queries, _ = queries.shape
        _, num_keys, _ = keys.shape

        # Project and split into heads
        Q = self.q_linear(queries).view(batch_size, num_queries, self.num_heads, self.head_dim)
        K = self.k_linear(keys).view(batch_size, num_keys, self.num_heads, self.head_dim)
        V = self.v_linear(values).view(batch_size, num_keys, self.num_heads, self.head_dim)

        # Transpose to (B, num_heads, num_queries, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, num_heads, num_queries, num_keys)

        # Apply the attention mask if provided
        if modality_mask is not None:
            modality_mask = modality_mask.unsqueeze(1)  # (B, 1, num_queries, num_keys)
            scores = scores.masked_fill(modality_mask == 0, -1e9)

        # Attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of the values
        attended = torch.matmul(attention_weights, V)  # (B, num_heads, num_queries, head_dim)

        # Merge the heads
        attended = attended.transpose(1, 2).contiguous().view(
            batch_size, num_queries, self.feature_dim
        )  # (B, num_queries, D)

        # Output projection
        output = self.out_linear(attended)
        return output
class MultimodalFusionBlock(nn.Module):
    """
    Multimodal fusion block.
    Integrates the features from the different modalities.
    """
    def __init__(self, feature_dim, num_modalities, num_heads=8, dropout=0.1):
        super(MultimodalFusionBlock, self).__init__()
        self.feature_dim = feature_dim
        self.num_modalities = num_modalities

        # Cross-modal attention
        self.cross_attention = CrossModalAttention(feature_dim, num_heads, dropout)

        # Layer normalization
        self.norm1 = nn.LayerNorm(feature_dim)
        self.norm2 = nn.LayerNorm(feature_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(feature_dim, feature_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim * 4, feature_dim),
            nn.Dropout(dropout)
        )

    def forward(self, modality_features, modality_mask=None):
        """
        Args:
            modality_features: (B, N, D) per-modality features, N = number of modalities
            modality_mask: (B, N, N) attention mask between modalities

        Returns:
            fused_features: (B, N, D) fused features
        """
        # Cross-modal attention
        attended = self.cross_attention(
            modality_features, modality_features, modality_features, modality_mask
        )
        # Residual connection and layer norm
        x = self.norm1(modality_features + attended)

        # Feed-forward network
        ffn_output = self.ffn(x)
        # Residual connection and layer norm
        output = self.norm2(x + ffn_output)
        return output
class MuMoModel(nn.Module):
    """
    The MuMo multimodal tumor classification model.
    """
    def __init__(self, modalities, num_classes, feature_dim=2048, num_fusion_blocks=3, num_heads=8, dropout=0.1):
        super(MuMoModel, self).__init__()
        self.modalities = modalities
        self.num_modalities = len(modalities)
        self.num_classes = num_classes
        self.feature_dim = feature_dim

        # Modality-specific encoders
        self.modality_encoders = nn.ModuleDict({
            modality: ModalitySpecificEncoder(modality) for modality in modalities
        })

        # Modality-specific projections (unify the feature dimension)
        self.modality_projectors = nn.ModuleDict({
            modality: nn.Linear(
                self.modality_encoders[modality].feature_dim,
                feature_dim
            ) for modality in modalities
        })

        # Modality embeddings (distinguish the modalities)
        self.modality_embeddings = nn.Parameter(
            torch.randn(self.num_modalities, feature_dim)
        )

        # Multimodal fusion blocks
        self.fusion_blocks = nn.ModuleList([
            MultimodalFusionBlock(feature_dim, self.num_modalities, num_heads, dropout)
            for _ in range(num_fusion_blocks)
        ])

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Dropout(dropout),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim // 2, num_classes)
        )

    def forward(self, x, modality_mask=None):
        """
        Args:
            x: dict mapping modality name to its input tensor (B, C, D, H, W)
            modality_mask: (B, N, N) attention mask between modalities

        Returns:
            logits: (B, num_classes) classification logits
            modality_features: per-modality features
            fused_features: fused features
        """
        batch_size = next(iter(x.values())).size(0)

        # Extract per-modality features
        modality_features = {}
        for modality in self.modalities:
            features = self.modality_encoders[modality](x[modality])  # (B, encoder_feature_dim)
            features = self.modality_projectors[modality](features)   # (B, feature_dim)
            modality_features[modality] = features

        # Stack the modality features
        combined_features = torch.stack(
            [modality_features[modality] for modality in self.modalities],
            dim=1
        )  # (B, N, D)

        # Add the modality embeddings
        modality_embeds = self.modality_embeddings.unsqueeze(0).expand(
            batch_size, -1, -1
        )  # (B, N, D)
        combined_features = combined_features + modality_embeds

        # Multimodal fusion
        fused_features = combined_features
        for fusion_block in self.fusion_blocks:
            fused_features = fusion_block(fused_features, modality_mask)

        # Global average pooling over the modality dimension
        global_features = torch.mean(fused_features, dim=1)  # (B, D)

        # Classification
        logits = self.classifier(global_features)  # (B, num_classes)
        return logits, modality_features, fused_features
```
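Before training on real data, a cheap shape check confirms that the components fit together. The sketch below pushes random volumes through a deliberately small MuMo instance; the 64³ input size and the reduced `feature_dim` are only there to keep the check fast, and note that instantiating the encoders downloads ImageNet weights unless `ModalitySpecificEncoder` is changed to pass `pretrained=False`.

```python
# Smoke test with random volumes (assumed configuration, not from the original)
model = MuMoModel(modalities=['CT', 'PET'], num_classes=2,
                  feature_dim=256, num_fusion_blocks=1, num_heads=8)
dummy = {
    'CT': torch.randn(2, 1, 64, 64, 64),
    'PET': torch.randn(2, 1, 64, 64, 64),
}
logits, modality_features, fused = model(dummy)
print(logits.shape)  # (2, 2)
print(fused.shape)   # (2, 2, 256): (batch, num_modalities, feature_dim)
print({m: f.shape for m, f in modality_features.items()})  # each (2, 256)
```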
5. Training Strategy and Loss Functions
Multimodal learning calls for dedicated training strategies and loss function design:
```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
class FocalLoss(nn.Module):
    """
    Focal loss for handling class imbalance.
    """
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability of the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss
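# Sanity check (not part of the original code): with gamma=0 and alpha=1 the
# focal loss reduces exactly to standard cross-entropy, which gives a cheap
# way to verify the implementation above.
_logits = torch.randn(8, 3)
_targets = torch.randint(0, 3, (8,))
assert torch.allclose(FocalLoss(alpha=1, gamma=0)(_logits, _targets),
                      F.cross_entropy(_logits, _targets))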
class MultimodalLoss(nn.Module):
    """
    Multimodal loss.
    Combines the classification loss with a modality consistency loss.
    """
    def __init__(self, alpha=1.0, beta=0.1, gamma=2.0):
        super(MultimodalLoss, self).__init__()
        self.alpha = alpha  # weight of the classification loss
        self.beta = beta    # weight of the consistency loss
        self.classification_loss = FocalLoss(gamma=gamma)

    def consistency_loss(self, modality_features, fused_features):
        """
        Consistency loss between per-modality features and the fused features.
        """
        loss = 0
        num_modalities = len(modality_features)
        for modality, features in modality_features.items():
            # Cosine similarity
            similarity = F.cosine_similarity(
                features, fused_features, dim=-1
            )
            # Maximize the similarity (minimize the negative similarity)
            loss += torch.mean(1 - similarity)
        return loss / num_modalities

    def forward(self, outputs, targets, modality_features, fused_features):
        """
        Args:
            outputs: model output logits
            targets: ground-truth labels
            modality_features: per-modality features
            fused_features: fused features

        Returns:
            total_loss: total loss
            cls_loss: classification loss
            cons_loss: consistency loss
        """
        # Classification loss
        cls_loss = self.classification_loss(outputs, targets)

        # Global fused features (mean pooling over modalities)
        global_fused = torch.mean(fused_features, dim=1)

        # Consistency loss
        cons_loss = self.consistency_loss(modality_features, global_fused)

        # Total loss
        total_loss = self.alpha * cls_loss + self.beta * cons_loss
        return total_loss, cls_loss, cons_loss
class Trainer:
    """
    Model trainer.
    """
    def __init__(self, model, train_loader, val_loader, device, num_classes,
                 learning_rate=1e-4, weight_decay=1e-5, checkpoint_dir='checkpoints'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes
        self.checkpoint_dir = checkpoint_dir

        # Create the checkpoint directory
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Loss function
        self.criterion = MultimodalLoss()

        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Learning rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )

        # Training history
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_f1': [], 'val_f1': [],
            'learning_rate': []
        }

    def train_epoch(self):
        """Train for one epoch."""
        self.model.train()
        running_loss = 0.0
        running_cls_loss = 0.0
        running_cons_loss = 0.0
        all_preds = []
        all_targets = []

        for batch_idx, (data, targets) in enumerate(self.train_loader):
            # Move to the device
            data = {modality: tensor.to(self.device) for modality, tensor in data.items()}
            targets = targets.to(self.device)

            # Zero the gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs, modality_features, fused_features = self.model(data)

            # Compute the loss
            loss, cls_loss, cons_loss = self.criterion(
                outputs, targets, modality_features, fused_features
            )

            # Backward pass
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            # Update the parameters
            self.optimizer.step()

            # Accumulate the losses
            running_loss += loss.item()
            running_cls_loss += cls_loss.item()
            running_cons_loss += cons_loss.item()

            # Record predictions and labels
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

            # Print progress
            if batch_idx % 50 == 0:
                print(f'Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {loss.item():.4f}, '
                      f'Cls Loss: {cls_loss.item():.4f}, '
                      f'Cons Loss: {cons_loss.item():.4f}')

        # Epoch metrics
        epoch_loss = running_loss / len(self.train_loader)
        epoch_cls_loss = running_cls_loss / len(self.train_loader)
        epoch_cons_loss = running_cons_loss / len(self.train_loader)
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted')

        return epoch_loss, epoch_cls_loss, epoch_cons_loss, epoch_acc, epoch_f1
    def validate(self, loader=None):
        """Evaluate the model on a data loader (defaults to the validation set)."""
        loader = loader if loader is not None else self.val_loader
        self.model.eval()
        running_loss = 0.0
        running_cls_loss = 0.0
        running_cons_loss = 0.0
        all_preds = []
        all_targets = []
        all_probs = []

        with torch.no_grad():
            for data, targets in loader:
                # Move to the device
                data = {modality: tensor.to(self.device) for modality, tensor in data.items()}
                targets = targets.to(self.device)

                # Forward pass
                outputs, modality_features, fused_features = self.model(data)

                # Compute the loss
                loss, cls_loss, cons_loss = self.criterion(
                    outputs, targets, modality_features, fused_features
                )

                # Accumulate the losses
                running_loss += loss.item()
                running_cls_loss += cls_loss.item()
                running_cons_loss += cons_loss.item()

                # Record predictions and labels
                probs = F.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        # Epoch metrics
        epoch_loss = running_loss / len(loader)
        epoch_cls_loss = running_cls_loss / len(loader)
        epoch_cons_loss = running_cons_loss / len(loader)
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted')

        # AUC (one-vs-rest for the multi-class case)
        all_probs = np.array(all_probs)
        if self.num_classes > 2:
            auc = roc_auc_score(all_targets, all_probs, multi_class='ovr')
        else:
            auc = roc_auc_score(all_targets, all_probs[:, 1])

        return epoch_loss, epoch_cls_loss, epoch_cons_loss, epoch_acc, epoch_f1, auc, all_probs, all_preds, all_targets
    def train(self, num_epochs=100, early_stopping_patience=10):
        """Full training loop."""
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 50)

            # Train
            train_loss, train_cls_loss, train_cons_loss, train_acc, train_f1 = self.train_epoch()
            # Validate
            val_loss, val_cls_loss, val_cons_loss, val_acc, val_f1, val_auc, _, _, _ = self.validate()

            # Update the learning rate
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']

            # Record the history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_acc)
            self.history['train_f1'].append(train_f1)
            self.history['val_f1'].append(val_f1)
            self.history['learning_rate'].append(current_lr)

            # Print the metrics
            print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
            print(f'Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')
            print(f'Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}')
            print(f'Val AUC: {val_auc:.4f}')
            print(f'Learning Rate: {current_lr:.6f}')

            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.save_checkpoint('best_model.pth')
                print('Saved a new best model')
            else:
                patience_counter += 1
                print(f'Early stopping counter: {patience_counter}/{early_stopping_patience}')

            # Early stopping check
            if patience_counter >= early_stopping_patience:
                print('Early stopping triggered')
                break

            print()

        # Load the best model
        self.load_checkpoint('best_model.pth')
        return self.history
    def save_checkpoint(self, filename):
        """Save a checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'history': self.history
        }
        torch.save(checkpoint, os.path.join(self.checkpoint_dir, filename))

    def load_checkpoint(self, filename):
        """Load a checkpoint."""
        checkpoint = torch.load(os.path.join(self.checkpoint_dir, filename))
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.history = checkpoint['history']
    def plot_training_history(self):
        """Plot the training history."""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Loss curves
        axes[0, 0].plot(self.history['train_loss'], label='Train Loss')
        axes[0, 0].plot(self.history['val_loss'], label='Validation Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()

        # Accuracy curves
        axes[0, 1].plot(self.history['train_acc'], label='Train Accuracy')
        axes[0, 1].plot(self.history['val_acc'], label='Validation Accuracy')
        axes[0, 1].set_title('Training and Validation Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()

        # F1 score curves
        axes[1, 0].plot(self.history['train_f1'], label='Train F1 Score')
        axes[1, 0].plot(self.history['val_f1'], label='Validation F1 Score')
        axes[1, 0].set_title('Training and Validation F1 Score')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].legend()

        # Learning rate curve
        axes[1, 1].plot(self.history['learning_rate'], label='Learning Rate')
        axes[1, 1].set_title('Learning Rate Schedule')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].legend()

        plt.tight_layout()
        plt.savefig(os.path.join(self.checkpoint_dir, 'training_history.png'))
        plt.close()
    def evaluate(self, test_loader):
        """Evaluate the model on the test set."""
        test_loss, test_cls_loss, test_cons_loss, test_acc, test_f1, test_auc, probs, preds, targets = self.validate(test_loader)

        # Confusion matrix
        cm = confusion_matrix(targets, preds)

        # Plot the confusion matrix
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(os.path.join(self.checkpoint_dir, 'confusion_matrix.png'))
        plt.close()

        # Collect the evaluation results
        results = {
            'test_loss': test_loss,
            'test_acc': test_acc,
            'test_f1': test_f1,
            'test_auc': test_auc,
            'confusion_matrix': cm,
            'predictions': preds,
            'targets': targets,
            'probabilities': probs
        }
        return results
```
6. Transfer Learning
For the client-provided dataset, we implement the following transfer learning strategy:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
def transfer_learning_setup(original_model, client_modalities, num_classes, feature_dim=2048):
    """
    Transfer learning setup.
    Adapts the model to the modalities and classes of the client dataset.

    Args:
        original_model: the pretrained source model
        client_modalities: list of modalities in the client dataset
        num_classes: number of classes in the client dataset
        feature_dim: feature dimension

    Returns:
        adapted_model: the adapted model
    """
    # Modalities of the source model
    original_modalities = original_model.modalities

    # Build the new model
    adapted_model = MuMoModel(
        modalities=client_modalities,
        num_classes=num_classes,
        feature_dim=feature_dim
    )

    # Copy the shared weights
    # 1. Copy the encoder weights of the overlapping modalities
    common_modalities = set(original_modalities) & set(client_modalities)
    for modality in common_modalities:
        # Encoder weights
        adapted_model.modality_encoders[modality].load_state_dict(
            original_model.modality_encoders[modality].state_dict()
        )
        # Projector weights
        adapted_model.modality_projectors[modality].load_state_dict(
            original_model.modality_projectors[modality].state_dict()
        )

    # 2. Copy the fusion block weights
    min_blocks = min(len(adapted_model.fusion_blocks), len(original_model.fusion_blocks))
    for i in range(min_blocks):
        adapted_model.fusion_blocks[i].load_state_dict(
            original_model.fusion_blocks[i].state_dict()
        )

    # 3. Freeze the shared weights (optional)
    for modality in common_modalities:
        for param in adapted_model.modality_encoders[modality].parameters():
            param.requires_grad = False
        for param in adapted_model.modality_projectors[modality].parameters():
            param.requires_grad = False
    for i in range(min_blocks):
        for param in adapted_model.fusion_blocks[i].parameters():
            param.requires_grad = False

    return adapted_model
class DomainAdaptationLoss(nn.Module):
    """
    Domain adaptation loss.
    Reduces the distribution gap between the source and target domains.
    """
    def __init__(self, alpha=1.0, beta=1.0):
        super(DomainAdaptationLoss, self).__init__()
        self.alpha = alpha  # weight of the classification loss
        self.beta = beta    # weight of the domain adaptation loss
        self.classification_loss = nn.CrossEntropyLoss()

    def mmd_loss(self, source_features, target_features):
        """
        Maximum Mean Discrepancy (MMD) loss.
        This is the simplest variant: the distance between the feature means
        of the two domains.
        """
        source_mean = torch.mean(source_features, dim=0)
        target_mean = torch.mean(target_features, dim=0)
        mmd = torch.norm(source_mean - target_mean, p=2)
        return mmd

    def forward(self, outputs, targets, source_features, target_features):
        """
        Args:
            outputs: model output logits
            targets: ground-truth labels
            source_features: source-domain features
            target_features: target-domain features

        Returns:
            total_loss: total loss
            cls_loss: classification loss
            domain_loss: domain adaptation loss
        """
        # Classification loss
        cls_loss = self.classification_loss(outputs, targets)

        # Domain adaptation loss
        domain_loss = self.mmd_loss(source_features, target_features)

        # Total loss
        total_loss = self.alpha * cls_loss + self.beta * domain_loss
        return total_loss, cls_loss, domain_loss
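# A more standard alternative (not in the original): the kernelized MMD^2,
# which compares whole distributions instead of just their means. This is a
# sketch of the biased estimator with a mixture of RBF kernels; it can be
# swapped into DomainAdaptationLoss.mmd_loss if the mean-difference variant
# proves too weak.
def rbf_mmd2(source, target, sigmas=(1.0, 5.0, 10.0)):
    """Biased MMD^2 estimate between two feature batches with RBF kernels."""
    def kernel(a, b):
        dist2 = torch.cdist(a, b, p=2) ** 2  # pairwise squared distances
        return sum(torch.exp(-dist2 / (2 * s ** 2)) for s in sigmas)
    return (kernel(source, source).mean()
            + kernel(target, target).mean()
            - 2 * kernel(source, target).mean())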
def gradual_unfreezing(model, unfreeze_layers, epoch, total_epochs):
    """
    Gradual unfreezing schedule.
    Unfreezes more layers as training progresses.

    Args:
        model: the model
        unfreeze_layers: list of name patterns for the layers to unfreeze
        epoch: current epoch
        total_epochs: total number of epochs
    """
    # Fraction of training completed
    unfreeze_ratio = epoch / total_epochs

    for name, param in model.named_parameters():
        # Does this parameter match one of the unfreeze patterns?
        should_unfreeze = any(pattern in name for pattern in unfreeze_layers)
        if should_unfreeze:
            # Use the submodule index (e.g. fusion_blocks.<i>) to stagger unfreezing
            parts = name.split('.')
            layer_index = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
            layer_unfreeze_threshold = layer_index / len(unfreeze_layers)
            if unfreeze_ratio >= layer_unfreeze_threshold:
                param.requires_grad = True
            else:
                param.requires_grad = False
def finetune_on_client_data(original_model, client_train_loader, client_val_loader,
                            client_modalities, num_classes, device, num_epochs=50):
    """
    Fine-tune on the client data.
    """
    # Set up the transfer learning model
    adapted_model = transfer_learning_setup(
        original_model, client_modalities, num_classes
    ).to(device)

    # Build the optimizer (train only the unfrozen parameters)
    trainable_params = filter(lambda p: p.requires_grad, adapted_model.parameters())
    optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-5)

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Training history
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }

    best_val_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Finetuning Epoch {epoch+1}/{num_epochs}')

        # Gradual unfreezing in the second half of training
        if epoch > num_epochs // 2:
            gradual_unfreezing(adapted_model, ['fusion_blocks', 'classifier'], epoch, num_epochs)
            # Rebuild the optimizer so newly unfrozen parameters are updated
            # (note: this resets the optimizer state)
            optimizer = optim.AdamW(
                filter(lambda p: p.requires_grad, adapted_model.parameters()),
                lr=1e-4, weight_decay=1e-5
            )

        # Train
        adapted_model.train()
        train_loss = 0.0
        train_preds = []
        train_targets = []
        for data, targets in client_train_loader:
            data = {modality: tensor.to(device) for modality, tensor in data.items()}
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs, _, _ = adapted_model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            train_preds.extend(preds.cpu().numpy())
            train_targets.extend(targets.cpu().numpy())

        train_loss /= len(client_train_loader)
        train_acc = accuracy_score(train_targets, train_preds)
        train_f1 = f1_score(train_targets, train_preds, average='weighted')

        # Validate
        adapted_model.eval()
        val_loss = 0.0
        val_preds = []
        val_targets = []
        with torch.no_grad():
            for data, targets in client_val_loader:
                data = {modality: tensor.to(device) for modality, tensor in data.items()}
                targets = targets.to(device)
                outputs, _, _ = adapted_model(data)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                val_preds.extend(preds.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())

        val_loss /= len(client_val_loader)
        val_acc = accuracy_score(val_targets, val_preds)
        val_f1 = f1_score(val_targets, val_preds, average='weighted')

        # Record the history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(train_f1)
        history['val_f1'].append(val_f1)

        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(adapted_model.state_dict(), 'best_finetuned_model.pth')

        print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
        print(f'Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')
        print(f'Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}')
        print()

    # Load the best model
    adapted_model.load_state_dict(torch.load('best_finetuned_model.pth'))
    return adapted_model, history
```
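Before launching a fine-tuning run it is worth confirming that the freezing logic in `transfer_learning_setup` did what was intended. The snippet below is a small sketch that counts trainable versus total parameters (`adapted_model` refers to the function's return value above):

```python
def count_parameters(model):
    """Report trainable vs. total parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100.0 * trainable / total:.1f}%)")
    return trainable, total

# Usage sketch: adapted_model = transfer_learning_setup(original_model, ['CT', 'PET'], 2)
# count_parameters(adapted_model)
```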
7. Model Interpretability and Visualization
To improve the interpretability of the model, we implement several visualization methods:
```python
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients, Occlusion, LayerGradCam
from captum.attr import visualization as viz
class ModelInterpreter:
    """
    Model interpreter.
    Provides visualizations and explanations of model predictions.
    """
    def __init__(self, model, modalities, device):
        self.model = model
        self.modalities = modalities
        self.device = device
        self.model.eval()

    def compute_feature_importance(self, input_data, target_class=None):
        """
        Compute the importance of each modality's features.
        Assumes a batch of size 1.
        """
        # Move the inputs to the device
        input_data = {modality: tensor.to(self.device) for modality, tensor in input_data.items()}

        # Forward pass (the modality features stay attached to the autograd graph)
        outputs, modality_features, fused_features = self.model(input_data)

        if target_class is None:
            target_class = torch.argmax(outputs, dim=1)[0].item()

        # Contribution of each modality to the decision
        modality_contributions = {}
        for modality, features in modality_features.items():
            # Gradient of the target logit w.r.t. the modality features
            gradients = torch.autograd.grad(
                outputs[:, target_class].sum(),
                features,
                retain_graph=True
            )[0]
            # Feature importance as the gradient norm
            importance = torch.norm(gradients, p=2, dim=1)
            modality_contributions[modality] = importance.mean().item()

        return modality_contributions, target_class
    def visualize_modality_contributions(self, input_data, target_class=None):
        """
        Visualize each modality's contribution.
        """
        contributions, pred_class = self.compute_feature_importance(input_data, target_class)

        # Bar chart
        modalities = list(contributions.keys())
        values = list(contributions.values())

        plt.figure(figsize=(10, 6))
        bars = plt.bar(modalities, values)
        plt.title(f'Modality Contributions for Class {pred_class}')
        plt.ylabel('Contribution Score')
        plt.xlabel('Modality')

        # Value labels on top of the bars
        for bar, value in zip(bars, values):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{value:.4f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.savefig('modality_contributions.png')
        plt.close()

        return contributions, pred_class

    def generate_attention_maps(self, input_data, target_class=None):
        """
        Generate attention maps showing which modalities the model focuses on.
        """
        # Move the inputs to the device
        input_data = {modality: tensor.to(self.device) for modality, tensor in input_data.items()}

        # Forward pass
        outputs, _, fused_features = self.model(input_data)
        if target_class is None:
            target_class = torch.argmax(outputs, dim=1)[0].item()

        # Gradient of the target logit w.r.t. the fused features
        gradients = torch.autograd.grad(
            outputs[:, target_class].sum(),
            fused_features,
            retain_graph=True
        )[0]

        # Attention weights over the modalities
        attention_weights = torch.mean(gradients, dim=2)  # (B, N)
        attention_weights = F.softmax(attention_weights, dim=1)

        return attention_weights.squeeze(0).cpu().detach().numpy(), target_class

    def visualize_attention_maps(self, input_data, target_class=None):
        """
        Visualize the attention maps.
        """
        attention_weights, pred_class = self.generate_attention_maps(input_data, target_class)

        # Heatmap
        plt.figure(figsize=(10, 6))
        sns.heatmap(attention_weights.reshape(1, -1),
                    annot=True, fmt='.3f', cmap='viridis',
                    xticklabels=self.modalities, yticklabels=['Attention'])
        plt.title(f'Attention Weights for Class {pred_class}')
        plt.tight_layout()
        plt.savefig('attention_weights.png')
        plt.close()

        return attention_weights, pred_class
    def occlusion_sensitivity(self, input_data, target_class=None, patch_size=10):
        """
        Occlusion-based sensitivity analysis (computationally expensive).
        """
        # Move the inputs to the device
        input_data = {modality: tensor.to(self.device) for modality, tensor in input_data.items()}

        # Forward pass
        outputs, _, _ = self.model(input_data)
        if target_class is None:
            target_class = torch.argmax(outputs, dim=1)[0].item()
        baseline_output = outputs[0, target_class].item()

        # Occlusion analysis per modality
        sensitivity_results = {}
        for modality in self.modalities:
            modality_data = input_data[modality]

            # Captum expects a tensor -> tensor function, so wrap the
            # dict-based model: vary one modality, keep the others fixed
            def forward_func(x, m=modality):
                data = dict(input_data)
                data[m] = x
                return self.model(data)[0]

            occlusion = Occlusion(forward_func)

            # Occlude patch_size x patch_size patches slice by slice;
            # the window shape covers the non-batch dims (C, D, H, W)
            strides = (1, 1, patch_size, patch_size)
            sliding_window_shapes = (1, 1, patch_size, patch_size)

            attributions = occlusion.attribute(
                modality_data,
                strides=strides,
                target=target_class,
                sliding_window_shapes=sliding_window_shapes,
                baselines=0
            )
            sensitivity_results[modality] = attributions.cpu().detach().numpy()

        return sensitivity_results, target_class, baseline_output
def visualize_occlusion_sensitivity(sensitivity_results, modality, slice_idx=0):
    """
    Visualize the occlusion sensitivity results.
    """
    if modality not in sensitivity_results:
        print(f"Modality {modality} not found in results")
        return

    # Results for the chosen modality and slice
    data = sensitivity_results[modality][0, 0, slice_idx]  # (H, W)

    # Heatmap
    plt.figure(figsize=(10, 8))
    plt.imshow(data, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title(f'Occlusion Sensitivity for {modality} (Slice {slice_idx})')
    plt.tight_layout()
    plt.savefig(f'occlusion_sensitivity_{modality}_slice{slice_idx}.png')
    plt.close()
# Example usage
def demonstrate_interpretability(model, test_loader, device, modalities):
    """
    Demonstrate the interpretability utilities.
    """
    # Take a single test sample (the interpreter assumes batch size 1)
    data_iter = iter(test_loader)
    sample_data, sample_target = next(data_iter)
    sample_data = {modality: tensor[:1] for modality, tensor in sample_data.items()}

    # Build the interpreter
    interpreter = ModelInterpreter(model, modalities, device)

    # Visualize modality contributions
    print("Visualizing modality contributions...")
    contributions, pred_class = interpreter.visualize_modality_contributions(sample_data)
    print(f"Predicted class: {pred_class}")
    print("Modality contributions:", contributions)

    # Visualize attention weights
    print("Visualizing attention weights...")
    attention_weights, _ = interpreter.visualize_attention_maps(sample_data)
    print("Attention weights:", attention_weights)

    # Occlusion sensitivity analysis (optional, computationally expensive)
    print("Performing occlusion sensitivity analysis...")
    sensitivity_results, _, _ = interpreter.occlusion_sensitivity(sample_data)

    # Visualize sensitivity for the first modality
    if modalities[0] in sensitivity_results:
        visualize_occlusion_sensitivity(sensitivity_results, modalities[0])

    return contributions, attention_weights, sensitivity_results
```
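Gradient-based scores give one view of modality importance; a complementary and very simple check is modality ablation: zero out one modality at a time and watch how the target logit moves. This is not part of the original MuMo tooling, just a sketch that reuses the model's dict-based interface:

```python
def modality_ablation(model, input_data, device):
    """Replace each modality with a zeroed volume and report the logit drop."""
    model.eval()
    input_data = {m: t.to(device) for m, t in input_data.items()}
    with torch.no_grad():
        base_logits, _, _ = model(input_data)
        target = base_logits.argmax(dim=1)[0].item()
        deltas = {}
        for modality in input_data:
            ablated = dict(input_data)
            ablated[modality] = torch.zeros_like(input_data[modality])
            logits, _, _ = model(ablated)
            # How much the target logit falls when this modality is removed
            deltas[modality] = (base_logits[0, target] - logits[0, target]).item()
    return target, deltas
```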
8. Full Training and Evaluation Pipeline
The complete training and evaluation pipeline follows:
```python
def main():
    """
    Main entry point: the full training and evaluation pipeline.
    """
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Hyperparameters
    config = {
        'data_dir': '/path/to/client/data',
        'metadata_file': '/path/to/metadata.csv',
        'modalities': ['CT', 'PET'],  # modalities in the client data
        'num_classes': 2,             # number of classes in the client data
        'batch_size': 4,
        'num_workers': 4,
        'feature_dim': 2048,
        'num_fusion_blocks': 3,
        'num_heads': 8,
        'dropout': 0.1,
        'learning_rate': 1e-4,
        'weight_decay': 1e-5,
        'num_epochs': 100,
        'checkpoint_dir': 'checkpoints'
    }

    # Build the data loaders
    print("Preparing the data loaders...")
    train_loader, val_loader, test_loader = get_data_loaders(
        config['data_dir'],
        config['metadata_file'],
        config['modalities'],
        config['batch_size'],
        config['num_workers']
    )

    # Build the model
    print("Initializing the model...")
    model = MuMoModel(
        modalities=config['modalities'],
        num_classes=config['num_classes'],
        feature_dim=config['feature_dim'],
        num_fusion_blocks=config['num_fusion_blocks'],
        num_heads=config['num_heads'],
        dropout=config['dropout']
    ).to(device)

    # Load pretrained weights if available
    pretrained_path = '/path/to/pretrained/model.pth'
    if os.path.exists(pretrained_path):
        print("Loading pretrained weights...")
        model.load_state_dict(torch.load(pretrained_path, map_location=device))

    # Build the trainer
    print("Creating the trainer...")
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_classes=config['num_classes'],
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        checkpoint_dir=config['checkpoint_dir']
    )

    # Train
    print("Starting training...")
    history = trainer.train(num_epochs=config['num_epochs'])

    # Plot the training history
    trainer.plot_training_history()

    # Evaluate on the test set
    print("Evaluating on the test set...")
    results = trainer.evaluate(test_loader)
    print(f"Test accuracy: {results['test_acc']:.4f}")
    print(f"Test F1 score: {results['test_f1']:.4f}")
    print(f"Test AUC: {results['test_auc']:.4f}")

    # Save the final model
    torch.save(model.state_dict(), os.path.join(config['checkpoint_dir'], 'final_model.pth'))

    # Interpretability analysis
    print("Running the interpretability analysis...")
    contributions, attention_weights, sensitivity_results = demonstrate_interpretability(
        model, test_loader, device, config['modalities']
    )

    # Save the evaluation results
    results_df = pd.DataFrame({
        'accuracy': [results['test_acc']],
        'f1_score': [results['test_f1']],
        'auc': [results['test_auc']]
    })
    results_df.to_csv(os.path.join(config['checkpoint_dir'], 'evaluation_results.csv'), index=False)

    # Save the modality contributions
    contributions_df = pd.DataFrame(list(contributions.items()), columns=['modality', 'contribution'])
    contributions_df.to_csv(os.path.join(config['checkpoint_dir'], 'modality_contributions.csv'), index=False)

    print("Training and evaluation complete!")

if __name__ == "__main__":
    main()
```
9. Deployment and Inference Interface
For practical use, we provide a simple inference interface:
```python
import json

import numpy as np
import torch
import nibabel as nib
from PIL import Image
class TumorClassificationAPI:
    """
    Tumor classification API.
    Provides a simple inference interface.
    """
    def __init__(self, model_path, config_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device

        # Load the configuration
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        # Load the model
        self.model = MuMoModel(
            modalities=self.config['modalities'],
            num_classes=self.config['num_classes'],
            feature_dim=self.config.get('feature_dim', 2048),
            num_fusion_blocks=self.config.get('num_fusion_blocks', 3),
            num_heads=self.config.get('num_heads', 8),
            dropout=self.config.get('dropout', 0.1)
        ).to(device)
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.eval()

        # Class label mapping
        self.class_names = self.config.get('class_names', [f'Class_{i}' for i in range(self.config['num_classes'])])

    def preprocess_image(self, image_path, modality):
        """
        Preprocess a medical image.
        """
        # Use a modality-specific loading strategy
        if modality in ['CT', 'MRI']:
            # Load a NIfTI file
            img = nib.load(image_path).get_fdata()
        else:
            # Load an ordinary image
            img = Image.open(image_path)
            img = np.array(img)

        # Normalize and resize
        img = self._normalize_image(img, modality)
        img = self._resize_image(img, self.config['image_size'])

        # Add channel and batch dimensions
        img = np.expand_dims(img, axis=0)  # channel dimension
        img = np.expand_dims(img, axis=0)  # batch dimension

        return torch.from_numpy(img).float()

    def _normalize_image(self, img, modality):
        """
        Normalize an image.
        """
        if modality == 'CT':
            # CT images are in Hounsfield units and need special handling
            img = np.clip(img, -1000, 1000)  # clip to the typical CT range
            img = (img + 1000) / 2000        # normalize to [0, 1]
        elif modality == 'PET':
            # PET normalization
            img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
        else:
            # Generic standardization
            img = (img - np.mean(img)) / (np.std(img) + 1e-8)
        return img

    def _resize_image(self, img, target_size):
        """
        Resize an image.
        """
        if len(img.shape) == 3:  # 3D volume
            # Interpolated resizing
            from scipy.ndimage import zoom
            factors = [t / o for t, o in zip(target_size, img.shape)]
            img = zoom(img, factors, order=1)  # linear interpolation
        else:  # 2D image
            from skimage.transform import resize
            img = resize(img, target_size, preserve_range=True)
        return img
    def predict(self, image_paths):
        """
        Predict the tumor class.

        Args:
            image_paths: dict mapping modality to image path

        Returns:
            result: dict with the predicted class, per-class probabilities, and confidence
        """
        # Preprocess the input images
        input_data = {}
        for modality, path in image_paths.items():
            if modality not in self.config['modalities']:
                raise ValueError(f"Unsupported modality: {modality}")
            input_data[modality] = self.preprocess_image(path, modality).to(self.device)

        # Run inference
        with torch.no_grad():
            outputs, _, _ = self.model(input_data)
            probabilities = torch.softmax(outputs, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()

        # Convert to a numpy array
        probabilities = probabilities.cpu().numpy()[0]

        # Build the result dictionary
        result = {
            'predicted_class': self.class_names[predicted_class],
            'class_probabilities': {
                self.class_names[i]: float(probabilities[i])
                for i in range(len(self.class_names))
            },
            'confidence': float(probabilities[predicted_class])
        }
        return result
# Example usage
def example_usage():
    # Initialize the API
    api = TumorClassificationAPI(
        model_path='checkpoints/best_model.pth',
        config_path='model_config.json'
    )

    # Prepare the inputs
    image_paths = {
        'CT': '/path/to/ct/image.nii.gz',
        'PET': '/path/to/pet/image.png'
    }

    # Run the prediction
    result = api.predict(image_paths)

    print(f"Predicted class: {result['predicted_class']}")
    print(f"Confidence: {result['confidence']:.4f}")
    print("Class probabilities:")
    for class_name, prob in result['class_probabilities'].items():
        print(f"  {class_name}: {prob:.4f}")

if __name__ == "__main__":
    example_usage()
```
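The API above reads a `model_config.json` next to the weights, but the text does not spell out its schema. The sketch below writes one plausible example containing the keys the class actually consumes (`modalities`, `num_classes`, `image_size`, plus the optional `class_names` and architecture fields); the values are placeholders.

```python
import json

# Example config matching what TumorClassificationAPI reads (placeholder values)
config = {
    "modalities": ["CT", "PET"],
    "num_classes": 2,
    "class_names": ["benign", "malignant"],
    "image_size": [128, 128, 128],  # used by preprocess_image when resizing
    "feature_dim": 2048,
    "num_fusion_blocks": 3,
    "num_heads": 8,
    "dropout": 0.1
}

with open('model_config.json', 'w') as f:
    json.dump(config, f, indent=2)
```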
10. Summary and Outlook
This project reproduced the MuMo multimodal tumor classification model from the Peking University team and applied transfer learning on the client-provided dataset. By implementing the complete data preprocessing pipeline, model architecture, training strategies, and evaluation methods, we built a capable multimodal tumor classification system.
Key outcomes:
- Full model reproduction: implemented the core MuMo components, including the modality-specific encoders, the cross-modal attention mechanism, and the multimodal fusion blocks.
- Transfer learning: designed a transfer strategy for the client data, covering weight transfer, gradual unfreezing, and a domain adaptation loss.
- Model interpretability: implemented several visualization methods to help understand the model's decisions and each modality's contribution.
- End-to-end pipeline: a complete solution from data preprocessing to model deployment.
Future directions:
- Better multimodal fusion: explore more advanced fusion methods, such as Transformer-based fusion mechanisms.
- Stronger domain adaptation: develop more powerful domain adaptation techniques to further reduce the gap between source and target distributions.
- Model compression: streamline the architecture to reduce compute requirements and ease deployment in clinical environments.
- Richer interpretability: build more advanced visualization tools that offer more intuitive explanations of model decisions.
- Multi-center validation: validate the model on data from multiple medical centers to ensure its generalization ability.
Through this project we not only reproduced the MuMo model but also laid the groundwork for its use in real clinical settings, contributing to the development of multimodal medical image analysis.