Reproducing Peking University's MuMo Multimodal Tumor Classification Model with Transfer Learning

1. Background and Motivation

Tumor classification is a core task in medical image analysis, with direct bearing on early cancer diagnosis, treatment planning, and prognostic assessment. Single-modality analysis is constrained by its single information source and struggles to capture the full complexity of tumors. Multimodal learning integrates complementary information from different imaging techniques (CT, MRI, PET, etc.) and can substantially improve classification accuracy.

MuMo (Multimodal tumor classification model), proposed by a Peking University team, is a multimodal tumor classification framework that integrates imaging data from multiple modalities and performs well across a range of tumor classification tasks. This project reproduces the model and applies transfer learning on a client-provided dataset to adapt it for practical clinical use.

2. Environment Setup and Dependencies

First, set up a suitable Python environment and install the required libraries:

# Requirements
# Python 3.8+
# PyTorch 1.9+
# CUDA 11.1+ (for GPU acceleration)

# Create and activate a conda environment
# conda create -n mumo python=3.8
# conda activate mumo

# Install core dependencies
# pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# pip install torchaudio

# Install other required libraries
# pip install numpy pandas scikit-learn matplotlib seaborn opencv-python pillow
# pip install scipy tqdm tensorboard monai nibabel SimpleITK

# Pin specific versions for compatibility
# pip install einops==0.3.2 timm==0.4.12

# Verify the installation
import torch
import torchvision
import numpy as np
import pandas as pd

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none'}")

3. Data Preprocessing and Loading

Multimodal medical images typically require a dedicated preprocessing pipeline to ensure consistency and comparability across modalities.
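Before diving into the loader code, here is a minimal, self-contained sketch of typical per-modality intensity handling (the window bounds are illustrative assumptions, not values from the MuMo paper): CT volumes are commonly clipped to a Hounsfield-unit window and rescaled, while PET or MRI volumes are often z-score normalized.

```python
import numpy as np

def normalize_ct(volume, hu_min=-1000.0, hu_max=400.0):
    """Clip CT intensities to an illustrative HU window, then rescale to [0, 1]."""
    volume = np.clip(volume, hu_min, hu_max)
    return (volume - hu_min) / (hu_max - hu_min)

def normalize_zscore(volume, eps=1e-8):
    """Zero-mean / unit-variance normalization, e.g. for PET or MRI volumes."""
    return (volume - volume.mean()) / (volume.std() + eps)
```

In the actual pipeline below, MONAI's `ScaleIntensity` plays a similar role; per-modality windows like the one above would be a natural refinement.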

import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import nibabel as nib
import SimpleITK as sitk
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
# Note: AddChannel was removed in MONAI 1.0; on newer MONAI use EnsureChannelFirst instead
from monai.transforms import Compose, LoadImage, AddChannel, ScaleIntensity, Resize, ToTensor

class MultiModalTumorDataset(Dataset):
    """
    Multimodal tumor dataset.
    Handles medical images from several modalities per patient.
    """
    def __init__(self, data_dir, metadata_file, modalities=['CT', 'PET'], transform=None, phase='train'):
        """
        Args:
            data_dir: root directory of the image data
            metadata_file: CSV file with sample information and labels
            modalities: list of modalities to use
            transform: augmentation / preprocessing transforms
            phase: 'train', 'val' or 'test'
        """
        self.data_dir = data_dir
        self.metadata = pd.read_csv(metadata_file)
        self.modalities = modalities
        self.transform = transform
        self.phase = phase
        
        # Keep only samples for which every requested modality file exists
        self.samples = []
        for idx, row in self.metadata.iterrows():
            paths_exist = True
            for modality in modalities:
                img_path = os.path.join(data_dir, row['patient_id'], f"{modality}.nii.gz")
                if not os.path.exists(img_path):
                    paths_exist = False
                    break
            
            if paths_exist:
                self.samples.append({
                    'patient_id': row['patient_id'],
                    'label': row['label'],
                    'paths': {modality: os.path.join(data_dir, row['patient_id'], f"{modality}.nii.gz") 
                             for modality in modalities}
                })
        
        print(f"{phase} dataset size: {len(self.samples)}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        label = sample['label']
        
        # Load each modality
        images = {}
        for modality, path in sample['paths'].items():
            # Load the medical image with MONAI (image_only=True returns the array without the metadata dict)
            img = LoadImage(image_only=True)(path)
            
            # Add a channel dimension: (H, W, D) -> (C, H, W, D)
            img = AddChannel()(img)
            
            # Intensity normalization
            img = ScaleIntensity()(img)
            
            # Resize to the model's expected input size
            img = Resize((128, 128, 128))(img)
            
            images[modality] = img
        
        # Augmentation (training phase only)
        if self.transform and self.phase == 'train':
            for modality in images:
                images[modality] = self.transform(images[modality])
        
        # Convert to tensors
        to_tensor = ToTensor()
        for modality in images:
            images[modality] = to_tensor(images[modality])
        
        return images, torch.tensor(label, dtype=torch.long)

def get_data_loaders(data_dir, metadata_file, modalities, batch_size=4, num_workers=4):
    """
    Build the train / validation / test data loaders.
    """
    # Read the metadata
    metadata = pd.read_csv(metadata_file)
    
    # One label per patient; splitting at patient level prevents leakage across splits.
    # Using the grouped Series keeps the IDs and the stratification labels aligned.
    patient_labels = metadata.groupby('patient_id')['label'].first()
    
    # Split into train+val and test sets
    train_val_ids, test_ids = train_test_split(
        patient_labels.index, 
        test_size=0.2, 
        random_state=42,
        stratify=patient_labels.values
    )
    
    train_ids, val_ids = train_test_split(
        train_val_ids, 
        test_size=0.25, 
        random_state=42,
        stratify=patient_labels.loc[train_val_ids].values
    )
    
    # Build per-split metadata
    train_metadata = metadata[metadata['patient_id'].isin(train_ids)]
    val_metadata = metadata[metadata['patient_id'].isin(val_ids)]
    test_metadata = metadata[metadata['patient_id'].isin(test_ids)]
    
    # Write temporary metadata files
    train_metadata.to_csv('train_metadata.csv', index=False)
    val_metadata.to_csv('val_metadata.csv', index=False)
    test_metadata.to_csv('test_metadata.csv', index=False)
    
    # Augmentation (training only)
    train_transform = Compose([
        # Add medical-imaging-specific augmentations here,
        # e.g. random rotations and flips
    ])
    
    # Datasets
    train_dataset = MultiModalTumorDataset(
        data_dir, 'train_metadata.csv', modalities, train_transform, 'train'
    )
    val_dataset = MultiModalTumorDataset(
        data_dir, 'val_metadata.csv', modalities, None, 'val'
    )
    test_dataset = MultiModalTumorDataset(
        data_dir, 'test_metadata.csv', modalities, None, 'test'
    )
    
    # Data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    
    return train_loader, val_loader, test_loader
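One detail worth emphasizing: the split above operates on unique patient IDs rather than on individual scans, so no patient's images can appear on both sides of a split. A standalone sketch of that idea (the patient IDs are hypothetical):

```python
import random

def split_patients(patient_ids, test_fraction=0.2, seed=42):
    """Split unique patient IDs (not individual scans) into train/test groups."""
    ids = sorted(set(patient_ids))
    rng = random.Random(seed)
    rng.shuffle(ids)
    n_test = max(1, int(len(ids) * test_fraction))
    return set(ids[n_test:]), set(ids[:n_test])

# Hypothetical records: several scans per patient
records = [("P001", 0), ("P001", 0), ("P002", 1), ("P003", 0), ("P004", 1), ("P005", 1)]
train_ids, test_ids = split_patients([pid for pid, _ in records])
# No patient ends up on both sides of the split
assert train_ids.isdisjoint(test_ids)
```

Splitting per scan instead would let two scans of the same patient land in train and test, inflating the measured accuracy.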

4. MuMo Model Architecture

The core idea of MuMo is to integrate feature information from different modalities through a multimodal fusion mechanism. The full implementation follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import timm

class ModalitySpecificEncoder(nn.Module):
    """
    Modality-specific encoder.
    Each modality gets its own encoder for feature extraction:
    the 3D volume is treated as a set of 2D slices fed through a shared 2D backbone.
    """
    def __init__(self, modality, backbone_name='resnet50', pretrained=True):
        super(ModalitySpecificEncoder, self).__init__()
        self.modality = modality
        
        # Pretrained 2D CNN backbone
        self.backbone = timm.create_model(
            backbone_name, 
            pretrained=pretrained, 
            in_chans=1,  # medical images are typically single-channel
            num_classes=0  # no classification head
        )
        
        # Feature dimension of the backbone
        self.feature_dim = self.backbone.num_features
        
        # Number of axial slices sampled from each volume
        self.num_slices = 16
        
    def forward(self, x):
        """
        Input:  (B, C, D, H, W)
        Output: (B, feature_dim)
        """
        batch_size, _, depth, height, width = x.shape
        
        # Sample a fixed number of axial slices so compute stays bounded
        indices = torch.linspace(0, depth - 1, self.num_slices).long()
        
        # Encode each 2D slice with the shared backbone
        slice_features = []
        for i in indices:
            slice_img = x[:, :, i, :, :]  # (B, 1, H, W)
            features = self.backbone(slice_img)  # (B, feature_dim)
            slice_features.append(features)
        
        # Aggregate slice features by mean pooling
        slice_features = torch.stack(slice_features, dim=1)  # (B, num_slices, feature_dim)
        aggregated_features = torch.mean(slice_features, dim=1)  # (B, feature_dim)
        
        return aggregated_features

class CrossModalAttention(nn.Module):
    """
    Cross-modal attention.
    Lets modalities exchange and fuse feature information.
    """
    def __init__(self, feature_dim, num_heads=8, dropout=0.1):
        super(CrossModalAttention, self).__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        
        assert self.head_dim * num_heads == feature_dim, "feature_dim must be divisible by num_heads"
        
        # Linear projections for queries, keys and values
        self.q_linear = nn.Linear(feature_dim, feature_dim)
        self.k_linear = nn.Linear(feature_dim, feature_dim)
        self.v_linear = nn.Linear(feature_dim, feature_dim)
        
        self.dropout = nn.Dropout(dropout)
        self.out_linear = nn.Linear(feature_dim, feature_dim)
        
    def forward(self, queries, keys, values, modality_mask=None):
        """
        Args:
            queries: (B, N, D) query vectors, N = number of modalities
            keys:    (B, M, D) key vectors, M = number of modalities
            values:  (B, M, D) value vectors
            modality_mask: (B, N, M) attention mask between modalities
            
        Returns:
            attended_features: (B, N, D) attention-weighted features
        """
        batch_size, num_queries, _ = queries.shape
        _, num_keys, _ = keys.shape
        
        # Project and split into heads
        Q = self.q_linear(queries).view(batch_size, num_queries, self.num_heads, self.head_dim)
        K = self.k_linear(keys).view(batch_size, num_keys, self.num_heads, self.head_dim)
        V = self.v_linear(values).view(batch_size, num_keys, self.num_heads, self.head_dim)
        
        # Transpose to (B, num_heads, num_queries, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, num_heads, num_queries, num_keys)
        
        # Apply the attention mask, if given
        if modality_mask is not None:
            modality_mask = modality_mask.unsqueeze(1)  # (B, 1, num_queries, num_keys)
            scores = scores.masked_fill(modality_mask == 0, -1e9)
        
        # Attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Weighted sum of values
        attended = torch.matmul(attention_weights, V)  # (B, num_heads, num_queries, head_dim)
        
        # Merge heads
        attended = attended.transpose(1, 2).contiguous().view(
            batch_size, num_queries, self.feature_dim
        )  # (B, num_queries, D)
        
        # Output projection
        output = self.out_linear(attended)
        
        return output

class MultimodalFusionBlock(nn.Module):
    """
    Multimodal fusion block.
    Integrates features across modalities.
    """
    def __init__(self, feature_dim, num_modalities, num_heads=8, dropout=0.1):
        super(MultimodalFusionBlock, self).__init__()
        self.feature_dim = feature_dim
        self.num_modalities = num_modalities
        
        # Cross-modal attention
        self.cross_attention = CrossModalAttention(feature_dim, num_heads, dropout)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(feature_dim)
        self.norm2 = nn.LayerNorm(feature_dim)
        
        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(feature_dim, feature_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim * 4, feature_dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, modality_features, modality_mask=None):
        """
        Args:
            modality_features: (B, N, D) per-modality features, N = number of modalities
            modality_mask: (B, N, N) attention mask between modalities
            
        Returns:
            fused_features: (B, N, D) fused features
        """
        # Cross-modal attention
        attended = self.cross_attention(
            modality_features, modality_features, modality_features, modality_mask
        )
        
        # Residual connection and layer norm
        x = self.norm1(modality_features + attended)
        
        # Feed-forward network
        ffn_output = self.ffn(x)
        
        # Residual connection and layer norm
        output = self.norm2(x + ffn_output)
        
        return output

class MuMoModel(nn.Module):
    """
    MuMo multimodal tumor classification model.
    """
    def __init__(self, modalities, num_classes, feature_dim=2048, num_fusion_blocks=3, num_heads=8, dropout=0.1):
        super(MuMoModel, self).__init__()
        self.modalities = modalities
        self.num_modalities = len(modalities)
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        
        # Modality-specific encoders
        self.modality_encoders = nn.ModuleDict({
            modality: ModalitySpecificEncoder(modality) for modality in modalities
        })
        
        # Modality-specific projection layers (unify the feature dimension)
        self.modality_projectors = nn.ModuleDict({
            modality: nn.Linear(
                self.modality_encoders[modality].feature_dim, 
                feature_dim
            ) for modality in modalities
        })
        
        # Modality embeddings (distinguish the modalities)
        self.modality_embeddings = nn.Parameter(
            torch.randn(self.num_modalities, feature_dim)
        )
        
        # Multimodal fusion blocks
        self.fusion_blocks = nn.ModuleList([
            MultimodalFusionBlock(feature_dim, self.num_modalities, num_heads, dropout)
            for _ in range(num_fusion_blocks)
        ])
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Dropout(dropout),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim // 2, num_classes)
        )
        
    def forward(self, x, modality_mask=None):
        """
        Args:
            x: dict mapping modality name -> input tensor (B, C, D, H, W)
            modality_mask: (B, N, N) attention mask between modalities
            
        Returns:
            logits: (B, num_classes) classification logits
            modality_features: per-modality features
            fused_features: fused features
        """
        batch_size = next(iter(x.values())).size(0)
        
        # Extract per-modality features
        modality_features = {}
        for modality in self.modalities:
            features = self.modality_encoders[modality](x[modality])  # (B, encoder_feature_dim)
            features = self.modality_projectors[modality](features)  # (B, feature_dim)
            modality_features[modality] = features
        
        # Stack the modality features
        combined_features = torch.stack(
            [modality_features[modality] for modality in self.modalities], 
            dim=1
        )  # (B, N, D)
        
        # Add the modality embeddings
        modality_embeds = self.modality_embeddings.unsqueeze(0).expand(
            batch_size, -1, -1
        )  # (B, N, D)
        combined_features = combined_features + modality_embeds
        
        # Multimodal fusion
        fused_features = combined_features
        for fusion_block in self.fusion_blocks:
            fused_features = fusion_block(fused_features, modality_mask)
        
        # Global average pooling over the modality axis
        global_features = torch.mean(fused_features, dim=1)  # (B, D)
        
        # Classification
        logits = self.classifier(global_features)  # (B, num_classes)
        
        return logits, modality_features, fused_features
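Stripped of batching and multi-head bookkeeping, the cross-modal attention above is standard scaled dot-product attention, softmax(QK^T / sqrt(d)) V, applied over the modality axis. A single-head numpy sketch of the core computation:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(QK^T / sqrt(d)) V -- the core of CrossModalAttention, single head."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                    # (N, M) similarity scores
    scores -= scores.max(axis=-1, keepdims=True)     # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # each row is a distribution over keys
    return weights @ V, weights
```

With N modalities as queries and keys, each row of the weight matrix says how much one modality attends to every other.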

5. Training Strategy and Loss Functions

Multimodal learning calls for dedicated training strategies and loss design:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

class FocalLoss(nn.Module):
    """
    Focal loss for handling class imbalance.
    """
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss
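The effect of the focusing term is easiest to see numerically: the modulating factor (1 - p_t)^gamma shrinks the loss of well-classified examples far more than that of hard ones. A tiny illustration (the probabilities are made up):

```python
def focal_weight(p_correct, gamma=2.0):
    """Modulating factor (1 - p_t)^gamma applied on top of cross-entropy."""
    return (1.0 - p_correct) ** gamma

easy = focal_weight(0.95)  # confident and correct -> weight 0.0025
hard = focal_weight(0.30)  # uncertain -> weight 0.49, nearly 200x larger
```

With gamma=0 the factor is 1 everywhere and focal loss reduces to ordinary cross-entropy.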

class MultimodalLoss(nn.Module):
    """
    Multimodal loss.
    Combines the classification loss with a modality-consistency loss.
    """
    def __init__(self, alpha=1.0, beta=0.1, gamma=2.0):
        super(MultimodalLoss, self).__init__()
        self.alpha = alpha  # weight of the classification loss
        self.beta = beta    # weight of the consistency loss
        self.classification_loss = FocalLoss(gamma=gamma)
        
    def consistency_loss(self, modality_features, fused_features):
        """
        Consistency loss between per-modality features and the fused features.
        """
        loss = 0
        num_modalities = len(modality_features)
        
        for modality, features in modality_features.items():
            # Cosine similarity
            similarity = F.cosine_similarity(
                features, fused_features, dim=-1
            )
            # Maximize similarity (minimize 1 - similarity)
            loss += torch.mean(1 - similarity)
        
        return loss / num_modalities
    
    def forward(self, outputs, targets, modality_features, fused_features):
        """
        Args:
            outputs: model logits
            targets: ground-truth labels
            modality_features: per-modality features
            fused_features: fused features
        
        Returns:
            total_loss: combined loss
            cls_loss: classification loss
            cons_loss: consistency loss
        """
        # Classification loss
        cls_loss = self.classification_loss(outputs, targets)
        
        # Global fused features (mean pooling over modalities)
        global_fused = torch.mean(fused_features, dim=1)
        
        # Consistency loss
        cons_loss = self.consistency_loss(modality_features, global_fused)
        
        # Total loss
        total_loss = self.alpha * cls_loss + self.beta * cons_loss
        
        return total_loss, cls_loss, cons_loss

class Trainer:
    """
    Model trainer.
    """
    def __init__(self, model, train_loader, val_loader, device, num_classes, 
                 learning_rate=1e-4, weight_decay=1e-5, checkpoint_dir='checkpoints'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes
        self.checkpoint_dir = checkpoint_dir
        
        # Create the checkpoint directory
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        # Loss function
        self.criterion = MultimodalLoss()
        
        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(), 
            lr=learning_rate, 
            weight_decay=weight_decay
        )
        
        # Learning-rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )
        
        # Training history
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_f1': [], 'val_f1': [],
            'learning_rate': []
        }
    
    def train_epoch(self):
        """Train for one epoch."""
        self.model.train()
        running_loss = 0.0
        running_cls_loss = 0.0
        running_cons_loss = 0.0
        all_preds = []
        all_targets = []
        
        for batch_idx, (data, targets) in enumerate(self.train_loader):
            # Move data to the device
            data = {modality: tensor.to(self.device) for modality, tensor in data.items()}
            targets = targets.to(self.device)
            
            # Zero the gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs, modality_features, fused_features = self.model(data)
            
            # Compute the loss
            loss, cls_loss, cons_loss = self.criterion(
                outputs, targets, modality_features, fused_features
            )
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            # Update parameters
            self.optimizer.step()
            
            # Accumulate losses
            running_loss += loss.item()
            running_cls_loss += cls_loss.item()
            running_cons_loss += cons_loss.item()
            
            # Record predictions and labels
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            
            # Progress logging
            if batch_idx % 50 == 0:
                print(f'Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {loss.item():.4f}, '
                      f'Cls Loss: {cls_loss.item():.4f}, '
                      f'Cons Loss: {cons_loss.item():.4f}')
        
        # Epoch-level metrics
        epoch_loss = running_loss / len(self.train_loader)
        epoch_cls_loss = running_cls_loss / len(self.train_loader)
        epoch_cons_loss = running_cons_loss / len(self.train_loader)
        
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted')
        
        return epoch_loss, epoch_cls_loss, epoch_cons_loss, epoch_acc, epoch_f1
    
    def validate(self):
        """Validate the model."""
        self.model.eval()
        running_loss = 0.0
        running_cls_loss = 0.0
        running_cons_loss = 0.0
        all_preds = []
        all_targets = []
        all_probs = []
        
        with torch.no_grad():
            for data, targets in self.val_loader:
                # Move data to the device
                data = {modality: tensor.to(self.device) for modality, tensor in data.items()}
                targets = targets.to(self.device)
                
                # Forward pass
                outputs, modality_features, fused_features = self.model(data)
                
                # Compute the loss
                loss, cls_loss, cons_loss = self.criterion(
                    outputs, targets, modality_features, fused_features
                )
                
                # Accumulate losses
                running_loss += loss.item()
                running_cls_loss += cls_loss.item()
                running_cons_loss += cons_loss.item()
                
                # Record predictions and labels
                probs = F.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, dim=1)
                
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
        
        # Epoch-level metrics
        epoch_loss = running_loss / len(self.val_loader)
        epoch_cls_loss = running_cls_loss / len(self.val_loader)
        epoch_cons_loss = running_cons_loss / len(self.val_loader)
        
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted')
        
        # AUC (convert the list of per-batch arrays to one array before scoring)
        all_probs = np.array(all_probs)
        if self.num_classes > 2:
            auc = roc_auc_score(all_targets, all_probs, multi_class='ovr')
        else:
            auc = roc_auc_score(all_targets, all_probs[:, 1])
        
        return epoch_loss, epoch_cls_loss, epoch_cons_loss, epoch_acc, epoch_f1, auc, all_probs, all_preds, all_targets
    
    def train(self, num_epochs=100, early_stopping_patience=10):
        """Full training loop."""
        best_val_loss = float('inf')
        patience_counter = 0
        
        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 50)
            
            # Train
            train_loss, train_cls_loss, train_cons_loss, train_acc, train_f1 = self.train_epoch()
            
            # Validate
            val_loss, val_cls_loss, val_cons_loss, val_acc, val_f1, val_auc, _, _, _ = self.validate()
            
            # Update the learning rate
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            
            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_acc)
            self.history['train_f1'].append(train_f1)
            self.history['val_f1'].append(val_f1)
            self.history['learning_rate'].append(current_lr)
            
            # Print metrics
            print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
            print(f'Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')
            print(f'Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}')
            print(f'Val AUC: {val_auc:.4f}')
            print(f'Learning Rate: {current_lr:.6f}')
            
            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.save_checkpoint('best_model.pth')
                print('Saved new best model')
            else:
                patience_counter += 1
                print(f'Early-stopping counter: {patience_counter}/{early_stopping_patience}')
            
            # Early-stopping check
            if patience_counter >= early_stopping_patience:
                print('Early stopping triggered')
                break
            
            print()
        
        # Load the best model
        self.load_checkpoint('best_model.pth')
        
        return self.history
    
    def save_checkpoint(self, filename):
        """Save a checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'history': self.history
        }
        torch.save(checkpoint, os.path.join(self.checkpoint_dir, filename))
    
    def load_checkpoint(self, filename):
        """Load a checkpoint."""
        checkpoint = torch.load(os.path.join(self.checkpoint_dir, filename))
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.history = checkpoint['history']
    
    def plot_training_history(self):
        """Plot the training history."""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Loss curves
        axes[0, 0].plot(self.history['train_loss'], label='Train Loss')
        axes[0, 0].plot(self.history['val_loss'], label='Validation Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        
        # Accuracy curves
        axes[0, 1].plot(self.history['train_acc'], label='Train Accuracy')
        axes[0, 1].plot(self.history['val_acc'], label='Validation Accuracy')
        axes[0, 1].set_title('Training and Validation Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        
        # F1-score curves
        axes[1, 0].plot(self.history['train_f1'], label='Train F1 Score')
        axes[1, 0].plot(self.history['val_f1'], label='Validation F1 Score')
        axes[1, 0].set_title('Training and Validation F1 Score')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].legend()
        
        # Learning-rate curve
        axes[1, 1].plot(self.history['learning_rate'], label='Learning Rate')
        axes[1, 1].set_title('Learning Rate Schedule')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].legend()
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.checkpoint_dir, 'training_history.png'))
        plt.close()
    
    def evaluate(self, test_loader):
        """Evaluate the model on the test set."""
        # validate() reads self.val_loader, so point it at the test loader temporarily
        original_loader = self.val_loader
        self.val_loader = test_loader
        test_loss, test_cls_loss, test_cons_loss, test_acc, test_f1, test_auc, probs, preds, targets = self.validate()
        self.val_loader = original_loader
        
        # Confusion matrix
        cm = confusion_matrix(targets, preds)
        
        # Plot the confusion matrix
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(os.path.join(self.checkpoint_dir, 'confusion_matrix.png'))
        plt.close()
        
        # Collect the evaluation results
        results = {
            'test_loss': test_loss,
            'test_acc': test_acc,
            'test_f1': test_f1,
            'test_auc': test_auc,
            'confusion_matrix': cm,
            'predictions': preds,
            'targets': targets,
            'probabilities': probs
        }
        
        return results
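The early-stopping bookkeeping inside train() can be factored into a small reusable helper; the sketch below mirrors that logic (this EarlyStopping class is an illustration, not part of the original Trainer):

```python
class EarlyStopping:
    """Stop when validation loss has not improved for `patience` consecutive epochs."""
    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```

Keeping the stop condition in one place makes it easy to unit-test the patience behavior separately from the training loop.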

6. Transfer Learning

To adapt the model to the client-provided dataset, we implement the following transfer-learning strategy:

def transfer_learning_setup(original_model, client_modalities, num_classes, feature_dim=2048):
    """
    Transfer-learning setup.
    Adapts the model to the modalities and classes of the client dataset.
    
    Args:
        original_model: the pretrained source model
        client_modalities: list of modalities in the client dataset
        num_classes: number of classes in the client dataset
        feature_dim: feature dimension
    
    Returns:
        adapted_model: the adapted model
    """
    # Modalities of the source model
    original_modalities = original_model.modalities
    
    # Build the new model
    adapted_model = MuMoModel(
        modalities=client_modalities,
        num_classes=num_classes,
        feature_dim=feature_dim
    )
    
    # Copy the shareable weights
    # 1. Copy encoder weights for overlapping modalities
    common_modalities = set(original_modalities) & set(client_modalities)
    for modality in common_modalities:
        # Encoder weights
        adapted_model.modality_encoders[modality].load_state_dict(
            original_model.modality_encoders[modality].state_dict()
        )
        # Projector weights
        adapted_model.modality_projectors[modality].load_state_dict(
            original_model.modality_projectors[modality].state_dict()
        )
    
    # 2. Copy the fusion-block weights
    min_blocks = min(len(adapted_model.fusion_blocks), len(original_model.fusion_blocks))
    for i in range(min_blocks):
        adapted_model.fusion_blocks[i].load_state_dict(
            original_model.fusion_blocks[i].state_dict()
        )
    
    # 3. Freeze the shared weights (optional)
    for modality in common_modalities:
        for param in adapted_model.modality_encoders[modality].parameters():
            param.requires_grad = False
        for param in adapted_model.modality_projectors[modality].parameters():
            param.requires_grad = False
    
    for i in range(min_blocks):
        for param in adapted_model.fusion_blocks[i].parameters():
            param.requires_grad = False
    
    return adapted_model

class DomainAdaptationLoss(nn.Module):
    """
    Domain-adaptation loss.
    Reduces the distribution gap between source and target domains.
    """
    def __init__(self, alpha=1.0, beta=1.0):
        super(DomainAdaptationLoss, self).__init__()
        self.alpha = alpha  # weight of the classification loss
        self.beta = beta    # weight of the domain-adaptation loss
        self.classification_loss = nn.CrossEntropyLoss()
        
    def mmd_loss(self, source_features, target_features):
        """
        Maximum Mean Discrepancy (MMD) loss, used to reduce the
        distribution gap between source and target domains.
        Note: this is a simplified first-moment variant that only compares
        feature means; a kernel MMD would also match higher-order statistics.
        """
        source_mean = torch.mean(source_features, dim=0)
        target_mean = torch.mean(target_features, dim=0)
        
        mmd = torch.norm(source_mean - target_mean, p=2)
        return mmd
    
    def forward(self, outputs, targets, source_features, target_features):
        """
        Args:
            outputs: model logits
            targets: ground-truth labels
            source_features: source-domain features
            target_features: target-domain features
        
        Returns:
            total_loss: combined loss
            cls_loss: classification loss
            domain_loss: domain-adaptation loss
        """
        # Classification loss
        cls_loss = self.classification_loss(outputs, targets)
        
        # Domain-adaptation loss
        domain_loss = self.mmd_loss(source_features, target_features)
        
        # Total loss
        total_loss = self.alpha * cls_loss + self.beta * domain_loss
        
        return total_loss, cls_loss, domain_loss
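As noted in the docstring, mmd_loss compares only the feature means, which is a linear, first-moment approximation of MMD. A numpy check of its basic behavior: it is exactly zero for identical feature sets and grows with a distribution shift:

```python
import numpy as np

def linear_mmd(source, target):
    """L2 distance between feature means -- the simplified MMD used above."""
    return np.linalg.norm(source.mean(axis=0) - target.mean(axis=0))

rng = np.random.default_rng(0)
same = rng.standard_normal((5000, 16))     # stand-in for source-domain features
shifted = same + 2.0                       # a domain with shifted feature statistics
# identical samples -> zero; shifted domain -> clearly positive
```

Because only means are matched, two domains with equal means but different variances would score zero here; that is the trade-off versus a kernel MMD.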

def gradual_unfreezing(model, unfreeze_layers, epoch, total_epochs):
    """
    Gradual unfreezing.
    Unfreezes progressively more layers as training proceeds.
    
    Args:
        model: the model
        unfreeze_layers: list of name patterns for layers to unfreeze
        epoch: current epoch
        total_epochs: total number of epochs
    """
    # Fraction of training elapsed
    unfreeze_ratio = epoch / total_epochs
    
    for name, param in model.named_parameters():
        # Is this parameter covered by one of the patterns?
        should_unfreeze = any(pattern in name for pattern in unfreeze_layers)
        
        if should_unfreeze:
            # Layers with a higher index unfreeze later in training
            parts = name.split('.')
            layer_index = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
            layer_unfreeze_threshold = layer_index / len(unfreeze_layers)
            
            if unfreeze_ratio >= layer_unfreeze_threshold:
                param.requires_grad = True
            else:
                param.requires_grad = False

def finetune_on_client_data(original_model, client_train_loader, client_val_loader, 
                           client_modalities, num_classes, device, num_epochs=50):
    """
    在客户数据上进行微调
    """
    # 设置迁移学习模型
    adapted_model = transfer_learning_setup(
        original_model, client_modalities, num_classes
    ).to(device)
    
    # 创建优化器 (只训练未冻结的参数)
    trainable_params = filter(lambda p: p.requires_grad, adapted_model.parameters())
    optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-5)
    
    # 损失函数
    criterion = nn.CrossEntropyLoss()
    
    # 训练记录
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }
    
    best_val_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Finetuning Epoch {epoch+1}/{num_epochs}')
        
        # Gradual unfreezing in the second half of training
        if epoch > num_epochs // 2:
            gradual_unfreezing(adapted_model, ['fusion_blocks', 'classifier'], epoch, num_epochs)
        
        # Training phase
        adapted_model.train()
        train_loss = 0.0
        train_preds = []
        train_targets = []
        
        for data, targets in client_train_loader:
            data = {modality: tensor.to(device) for modality, tensor in data.items()}
            targets = targets.to(device)
            
            optimizer.zero_grad()
            outputs, _, _ = adapted_model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            train_preds.extend(preds.cpu().numpy())
            train_targets.extend(targets.cpu().numpy())
        
        train_loss /= len(client_train_loader)
        train_acc = accuracy_score(train_targets, train_preds)
        train_f1 = f1_score(train_targets, train_preds, average='weighted')
        
        # Validation phase
        adapted_model.eval()
        val_loss = 0.0
        val_preds = []
        val_targets = []
        
        with torch.no_grad():
            for data, targets in client_val_loader:
                data = {modality: tensor.to(device) for modality, tensor in data.items()}
                targets = targets.to(device)
                
                outputs, _, _ = adapted_model(data)
                loss = criterion(outputs, targets)
                
                val_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                val_preds.extend(preds.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())
        
        val_loss /= len(client_val_loader)
        val_acc = accuracy_score(val_targets, val_preds)
        val_f1 = f1_score(val_targets, val_preds, average='weighted')
        
        # Record history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(train_f1)
        history['val_f1'].append(val_f1)
        
        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(adapted_model.state_dict(), 'best_finetuned_model.pth')
        
        print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
        print(f'Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')
        print(f'Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}')
        print()
    
    # Load the best checkpoint
    adapted_model.load_state_dict(torch.load('best_finetuned_model.pth'))
    
    return adapted_model, history
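The `mmd_loss` in `DomainAdaptationLoss` above compares only the means of the two feature distributions. A common stronger variant is the kernel MMD, which also captures higher-order statistics. Below is a minimal NumPy sketch (the torch version is analogous; the function name `rbf_mmd2` and the fixed bandwidth `sigma` are our own illustrative choices, not part of the original model):

```python
import numpy as np

def rbf_mmd2(X, Y, sigma=1.0):
    """Squared MMD between samples X (n, d) and Y (m, d) with an RBF kernel."""
    def kernel(A, B):
        # Pairwise squared Euclidean distances
        d2 = (np.sum(A**2, axis=1)[:, None]
              + np.sum(B**2, axis=1)[None, :]
              - 2.0 * A @ B.T)
        return np.exp(-d2 / (2.0 * sigma**2))
    # Biased MMD^2 estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]
    return kernel(X, X).mean() + kernel(Y, Y).mean() - 2.0 * kernel(X, Y).mean()
```

For identical sample sets the estimate is exactly zero, and it grows as the two distributions drift apart, which is what makes it usable as a domain-alignment penalty.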

7. Model Interpretability and Visualization

To improve interpretability, we implement several visualization methods:

python 复制代码
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients, Occlusion, LayerGradCam
from captum.attr import visualization as viz

class ModelInterpreter:
    """
    模型解释器
    提供模型预测的可视化和解释
    """
    def __init__(self, model, modalities, device):
        self.model = model
        self.modalities = modalities
        self.device = device
        self.model.eval()
    
    def compute_feature_importance(self, input_data, target_class=None):
        """
        计算各模态特征的重要性
        """
        # 确保输入数据在设备上
        input_data = {modality: tensor.to(self.device) for modality, tensor in input_data.items()}
        
        # 获取模型输出
        outputs, modality_features, fused_features = self.model(input_data)
        
        if target_class is None:
            target_class = torch.argmax(outputs, dim=1).item()
        
        # 计算每个模态对最终决策的贡献
        modality_contributions = {}
        for modality, features in modality_features.items():
            # 计算梯度
            features.requires_grad_(True)
            gradients = torch.autograd.grad(
                outputs[:, target_class].sum(), 
                features, 
                retain_graph=True
            )[0]
            
            # 计算特征重要性
            importance = torch.norm(gradients, p=2, dim=1)
            modality_contributions[modality] = importance.mean().item()
        
        return modality_contributions, target_class
    
    def visualize_modality_contributions(self, input_data, target_class=None):
        """
        可视化各模态的贡献度
        """
        contributions, pred_class = self.compute_feature_importance(input_data, target_class)
        
        # 创建条形图
        modalities = list(contributions.keys())
        values = list(contributions.values())
        
        plt.figure(figsize=(10, 6))
        bars = plt.bar(modalities, values)
        plt.title(f'Modality Contributions for Class {pred_class}')
        plt.ylabel('Contribution Score')
        plt.xlabel('Modality')
        
        # Annotate each bar with its value
        for bar, value in zip(bars, values):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                    f'{value:.4f}', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.savefig('modality_contributions.png')
        plt.close()
        
        return contributions, pred_class
    
    def generate_attention_maps(self, input_data, target_class=None):
        """
        生成注意力图,显示模型关注区域
        """
        # 确保输入数据在设备上
        input_data = {modality: tensor.to(self.device) for modality, tensor in input_data.items()}
        
        # 获取模型输出
        outputs, _, fused_features = self.model(input_data)
        
        if target_class is None:
            target_class = torch.argmax(outputs, dim=1).item()
        
        # 计算梯度
        gradients = torch.autograd.grad(
            outputs[:, target_class].sum(), 
            fused_features, 
            retain_graph=True
        )[0]
        
        # 计算注意力权重
        attention_weights = torch.mean(gradients, dim=2)  # (B, N)
        attention_weights = F.softmax(attention_weights, dim=1)
        
        return attention_weights.squeeze(0).cpu().detach().numpy(), target_class
    
    def visualize_attention_maps(self, input_data, target_class=None):
        """
        可视化注意力图
        """
        attention_weights, pred_class = self.generate_attention_maps(input_data, target_class)
        
        # 创建热力图
        plt.figure(figsize=(10, 6))
        sns.heatmap(attention_weights.reshape(1, -1), 
                   annot=True, fmt='.3f', cmap='viridis',
                   xticklabels=self.modalities, yticklabels=['Attention'])
        plt.title(f'Attention Weights for Class {pred_class}')
        plt.tight_layout()
        plt.savefig('attention_weights.png')
        plt.close()
        
        return attention_weights, pred_class
    
    def occlusion_sensitivity(self, input_data, target_class=None, patch_size=10):
        """
        使用遮挡法分析模型敏感性
        """
        # 确保输入数据在设备上
        input_data = {modality: tensor.to(self.device) for modality, tensor in input_data.items()}
        
        # 获取模型输出
        outputs, _, _ = self.model(input_data)
        
        if target_class is None:
            target_class = torch.argmax(outputs, dim=1).item()
        
        baseline_output = outputs[:, target_class].item()
        
        # Occlusion analysis for each modality
        sensitivity_results = {}
        
        for modality in self.modalities:
            modality_data = input_data[modality]
            # Expecting 5D volumes: (B, C, D, H, W)
            _, channels, depth, height, width = modality_data.shape
            
            # Captum's Occlusion needs a forward function that takes a single
            # tensor and returns logits, so wrap the multimodal model: the
            # occluded tensor replaces this modality while the others stay fixed
            def forward_func(x, modality=modality):
                data = dict(input_data)
                data[modality] = x
                logits, _, _ = self.model(data)
                return logits
            
            occlusion = Occlusion(forward_func)
            
            # Occlude one slice-aligned patch at a time; shapes exclude
            # the batch dimension, i.e. (C, D, H, W)
            strides = (channels, 1, patch_size, patch_size)
            sliding_window_shapes = (channels, 1, patch_size, patch_size)
            
            # Compute occlusion attributions
            attributions = occlusion.attribute(
                modality_data,
                strides=strides,
                target=target_class,
                sliding_window_shapes=sliding_window_shapes,
                baselines=0
            )
            
            sensitivity_results[modality] = attributions.cpu().detach().numpy()
        
        return sensitivity_results, target_class, baseline_output

def visualize_occlusion_sensitivity(sensitivity_results, modality, slice_idx=0):
    """
    可视化遮挡敏感性分析结果
    """
    if modality not in sensitivity_results:
        print(f"Modality {modality} not found in results")
        return
    
    # 获取指定模态和切片的结果
    data = sensitivity_results[modality][0, 0, slice_idx]  # (H, W)
    
    # 创建热力图
    plt.figure(figsize=(10, 8))
    plt.imshow(data, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title(f'Occlusion Sensitivity for {modality} (Slice {slice_idx})')
    plt.tight_layout()
    plt.savefig(f'occlusion_sensitivity_{modality}_slice{slice_idx}.png')
    plt.close()

# Example usage
def demonstrate_interpretability(model, test_loader, device, modalities):
    """
    Demonstrate the interpretability utilities on a single test sample.
    """
    # Grab one test sample
    data_iter = iter(test_loader)
    sample_data, sample_target = next(data_iter)
    
    # Build the interpreter
    interpreter = ModelInterpreter(model, modalities, device)
    
    # Modality contributions
    print("Visualizing modality contributions...")
    contributions, pred_class = interpreter.visualize_modality_contributions(sample_data)
    print(f"Predicted class: {pred_class}")
    print("Modality contributions:", contributions)
    
    # Attention weights
    print("Visualizing attention weights...")
    attention_weights, _ = interpreter.visualize_attention_maps(sample_data)
    print("Attention weights:", attention_weights)
    
    # Occlusion sensitivity (optional: computationally expensive)
    print("Performing occlusion sensitivity analysis...")
    sensitivity_results, _, _ = interpreter.occlusion_sensitivity(sample_data)
    
    # Visualize sensitivity for the first modality
    if modalities[0] in sensitivity_results:
        visualize_occlusion_sensitivity(sensitivity_results, modalities[0])
    
    return contributions, attention_weights, sensitivity_results
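The reduction inside `generate_attention_maps` (a mean over the feature dimension followed by a softmax over tokens) can be checked in isolation. A small NumPy sketch, with the function name `attention_from_gradients` chosen here purely for illustration:

```python
import numpy as np

def attention_from_gradients(grads):
    """grads: (B, N, D) token gradients -> (B, N) attention weights."""
    scores = grads.mean(axis=2)                          # average over features
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)  # softmax over tokens
```

The resulting weights sum to one per sample, and the token with the largest mean gradient receives the largest weight, which is what the heatmap in `visualize_attention_maps` displays.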

8. Complete Training and Evaluation Pipeline

The complete training and evaluation pipeline is given below:

python 复制代码
def main():
    """
    Main entry point: the full training and evaluation pipeline.
    """
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Hyperparameters
    config = {
        'data_dir': '/path/to/client/data',
        'metadata_file': '/path/to/metadata.csv',
        'modalities': ['CT', 'PET'],  # modalities available in the client data
        'num_classes': 2,  # number of classes in the client data
        'batch_size': 4,
        'num_workers': 4,
        'feature_dim': 2048,
        'num_fusion_blocks': 3,
        'num_heads': 8,
        'dropout': 0.1,
        'learning_rate': 1e-4,
        'weight_decay': 1e-5,
        'num_epochs': 100,
        'checkpoint_dir': 'checkpoints'
    }
    
    # Build the data loaders
    print("Preparing data loaders...")
    train_loader, val_loader, test_loader = get_data_loaders(
        config['data_dir'], 
        config['metadata_file'], 
        config['modalities'],
        config['batch_size'],
        config['num_workers']
    )
    
    # Build the model
    print("Initializing the model...")
    model = MuMoModel(
        modalities=config['modalities'],
        num_classes=config['num_classes'],
        feature_dim=config['feature_dim'],
        num_fusion_blocks=config['num_fusion_blocks'],
        num_heads=config['num_heads'],
        dropout=config['dropout']
    ).to(device)
    
    # Load pretrained weights if available
    pretrained_path = '/path/to/pretrained/model.pth'
    if os.path.exists(pretrained_path):
        print("Loading pretrained weights...")
        model.load_state_dict(torch.load(pretrained_path, map_location=device))
    
    # Build the trainer
    print("Creating the trainer...")
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_classes=config['num_classes'],
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        checkpoint_dir=config['checkpoint_dir']
    )
    
    # Train the model
    print("Starting training...")
    history = trainer.train(num_epochs=config['num_epochs'])
    
    # Plot the training history
    trainer.plot_training_history()
    
    # Evaluate on the test set
    print("Evaluating on the test set...")
    results = trainer.evaluate(test_loader)
    
    print(f"Test accuracy: {results['test_acc']:.4f}")
    print(f"Test F1 score: {results['test_f1']:.4f}")
    print(f"Test AUC: {results['test_auc']:.4f}")
    
    # Save the final model
    torch.save(model.state_dict(), os.path.join(config['checkpoint_dir'], 'final_model.pth'))
    
    # Interpretability analysis
    print("Running interpretability analysis...")
    contributions, attention_weights, sensitivity_results = demonstrate_interpretability(
        model, test_loader, device, config['modalities']
    )
    
    # Save evaluation metrics
    results_df = pd.DataFrame({
        'accuracy': [results['test_acc']],
        'f1_score': [results['test_f1']],
        'auc': [results['test_auc']]
    })
    results_df.to_csv(os.path.join(config['checkpoint_dir'], 'evaluation_results.csv'), index=False)
    
    # Save the modality contributions
    contributions_df = pd.DataFrame(list(contributions.items()), columns=['modality', 'contribution'])
    contributions_df.to_csv(os.path.join(config['checkpoint_dir'], 'modality_contributions.csv'), index=False)
    
    print("Training and evaluation complete!")

if __name__ == "__main__":
    main()
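The inference interface in the next section loads its settings from a JSON file. A hedged sketch of writing such a file from the training configuration (the key set mirrors what `TumorClassificationAPI` reads; the `image_size` and `class_names` values are assumptions about the deployment setup, not part of the original project):

```python
import json

# Assumed deployment settings; adjust to the actual client configuration
deploy_config = {
    'modalities': ['CT', 'PET'],
    'num_classes': 2,
    'feature_dim': 2048,
    'num_fusion_blocks': 3,
    'num_heads': 8,
    'dropout': 0.1,
    'image_size': [64, 128, 128],          # assumed (D, H, W) after resampling
    'class_names': ['benign', 'malignant']  # assumed label names
}

# Write the config the inference API will load
with open('model_config.json', 'w') as f:
    json.dump(deploy_config, f, indent=2)
```

Keeping the architecture hyperparameters in this file guarantees the deployed model is rebuilt with exactly the shapes used in training.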

9. Deployment and Inference Interface

For practical use, we provide a simple inference interface:

python 复制代码
import torch
import numpy as np
from PIL import Image
import json

class TumorClassificationAPI:
    """
    肿瘤分类API
    提供简单的推理接口
    """
    def __init__(self, model_path, config_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        
        # Load the configuration
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        
        # Build the model
        self.model = MuMoModel(
            modalities=self.config['modalities'],
            num_classes=self.config['num_classes'],
            feature_dim=self.config.get('feature_dim', 2048),
            num_fusion_blocks=self.config.get('num_fusion_blocks', 3),
            num_heads=self.config.get('num_heads', 8),
            dropout=self.config.get('dropout', 0.1)
        ).to(device)
        
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.eval()
        
        # Class-name mapping
        self.class_names = self.config.get('class_names', [f'Class_{i}' for i in range(self.config['num_classes'])])
    
    def preprocess_image(self, image_path, modality):
        """
        预处理医学影像
        """
        # 根据模态使用不同的预处理方法
        if modality in ['CT', 'MRI']:
            # 加载NIfTI文件
            img = nib.load(image_path).get_fdata()
        else:
            # 加载普通图像
            img = Image.open(image_path)
            img = np.array(img)
        
        # Normalize and resize
        img = self._normalize_image(img, modality)
        img = self._resize_image(img, self.config['image_size'])
        
        # Add channel and batch dimensions, e.g. (D, H, W) -> (1, 1, D, H, W)
        img = np.expand_dims(img, axis=0)  # channel dimension
        img = np.expand_dims(img, axis=0)  # batch dimension
        
        return torch.from_numpy(img).float()
    
    def _normalize_image(self, img, modality):
        """
        标准化图像
        """
        if modality == 'CT':
            # CT图像通常使用Hounsfield单位,需要特殊处理
            img = np.clip(img, -1000, 1000)  # 裁剪到典型CT值范围
            img = (img + 1000) / 2000  # 归一化到[0, 1]
        elif modality == 'PET':
            # PET图像标准化
            img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
        else:
            # 通用标准化
            img = (img - np.mean(img)) / (np.std(img) + 1e-8)
        
        return img
    
    def _resize_image(self, img, target_size):
        """
        调整图像大小
        """
        if len(img.shape) == 3:  # 3D体积
            # 使用插值调整大小
            from scipy.ndimage import zoom
            factors = [t/o for t, o in zip(target_size, img.shape)]
            img = zoom(img, factors, order=1)  # 线性插值
        else:  # 2D图像
            from skimage.transform import resize
            img = resize(img, target_size, preserve_range=True)
        
        return img
    
    def predict(self, image_paths):
        """
        预测肿瘤类别
        
        参数:
            image_paths: 字典,键为模态,值为图像路径
        
        返回:
            predictions: 预测结果
            probabilities: 各类别概率
        """
        # 预处理输入图像
        input_data = {}
        for modality, path in image_paths.items():
            if modality not in self.config['modalities']:
                raise ValueError(f"不支持的模态: {modality}")
            
            input_data[modality] = self.preprocess_image(path, modality).to(self.device)
        
        # Run inference
        with torch.no_grad():
            outputs, _, _ = self.model(input_data)
            probabilities = torch.softmax(outputs, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
        
        # Convert to a numpy array
        probabilities = probabilities.cpu().numpy()[0]
        
        # Assemble the result dict
        result = {
            'predicted_class': self.class_names[predicted_class],
            'class_probabilities': {
                self.class_names[i]: float(probabilities[i]) 
                for i in range(len(self.class_names))
            },
            'confidence': float(probabilities[predicted_class])
        }
        
        return result

# Example usage
def example_usage():
    # Initialize the API
    api = TumorClassificationAPI(
        model_path='checkpoints/best_model.pth',
        config_path='model_config.json'
    )
    
    # Input images, one path per modality
    image_paths = {
        'CT': '/path/to/ct/image.nii.gz',
        'PET': '/path/to/pet/image.png'
    }
    
    # Predict
    result = api.predict(image_paths)
    
    print(f"Predicted class: {result['predicted_class']}")
    print(f"Confidence: {result['confidence']:.4f}")
    print("Class probabilities:")
    for class_name, prob in result['class_probabilities'].items():
        print(f"  {class_name}: {prob:.4f}")

if __name__ == "__main__":
    example_usage()
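The CT branch of `_normalize_image` maps Hounsfield units to [0, 1] by clipping to [-1000, 1000] and rescaling. That mapping can be verified on its own with a few known values; a standalone NumPy sketch of the same two lines:

```python
import numpy as np

def normalize_ct(img):
    """Clip to the typical HU window and rescale to [0, 1], as in _normalize_image."""
    img = np.clip(img, -1000, 1000)
    return (img + 1000) / 2000

# Air (~ -1000 HU) maps to 0, water (0 HU) to 0.5,
# and values outside the window saturate at 0 or 1
hu = np.array([-2000.0, -1000.0, 0.0, 1000.0, 3000.0])
print(normalize_ct(hu))
```

Fixing the window this way keeps CT inputs on a stable scale across scanners, at the cost of discarding contrast outside the chosen HU range.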

10. Summary and Outlook

This project reproduced the MuMo multimodal tumor classification model from the Peking University team and applied transfer learning to the client's dataset. By implementing a complete data preprocessing pipeline, the model architecture, training strategies, and evaluation methods, we built a robust multimodal tumor classification system.

Key results:

  1. Full model reproduction: implemented the core components of MuMo, including modality-specific encoders, cross-modal attention, and multimodal fusion blocks.

  2. Transfer learning: designed a transfer strategy for the client data, including weight transfer, gradual unfreezing, and a domain-adaptation loss.

  3. Interpretability: implemented several visualization methods that expose the model's decision process and each modality's contribution.

  4. End-to-end pipeline: provided a complete solution from data preprocessing through deployment.

Future work:

  1. Multimodal fusion: explore more advanced fusion methods, such as Transformer-based fusion mechanisms.

  2. Domain adaptation: develop stronger domain-adaptation techniques to further reduce the distribution gap between source and target domains.

  3. Model compression: streamline the architecture to reduce compute requirements and ease deployment in clinical environments.

  4. Interpretability: build more advanced visualization tools that give clinicians more intuitive explanations of model decisions.

  5. Multi-center validation: validate performance on data from multiple medical centers to confirm generalization.

Through this project we not only reproduced the MuMo model but also laid the groundwork for its use in real clinical settings, contributing to the development of multimodal medical image analysis.
