Table of Contents
Chapter 1: Evolution of Core Mammography AI Architectures, from 2D Full-Field Imaging to 3D Tomosynthesis and Vision Transformers
[1.1 Foundations and Limitations of Traditional Deep Learning Architectures in Mammography Screening](#1.1 Foundations and Limitations of Traditional Deep Learning Architectures in Mammography Screening)
[1.1.1 Early CNN Applications in Breast Imaging Analysis](#1.1.1 Early CNN Applications in Breast Imaging Analysis)
[1.1.2 Multi-View Geometric Consistency Modeling](#1.1.2 Multi-View Geometric Consistency Modeling)
[1.2 Digital Breast Tomosynthesis and 3D Deep Learning](#1.2 Digital Breast Tomosynthesis and 3D Deep Learning)
[1.3 Self-Supervised Learning and the Contrastive Representation Revolution](#1.3 Self-Supervised Learning and the Contrastive Representation Revolution)
[1.4 Large Vision-Language Models and Explainable AI](#1.4 Large Vision-Language Models and Explainable AI)
[Script 1: 2D Full-Field Digital Mammography Baseline Architectures and Domain Adaptation](#Script 1: 2D Full-Field Digital Mammography Baseline Architectures and Domain Adaptation)
[Script 2: Multi-View Geometric Consistency Modeling](#Script 2: Multi-View Geometric Consistency Modeling)
[Script 3: 3D CNN with Multi-Scale Fusion for Digital Breast Tomosynthesis](#Script 3: 3D CNN with Multi-Scale Fusion for Digital Breast Tomosynthesis)
[Script 4: TimeSformer and 3D Vision Transformer](#Script 4: TimeSformer and 3D Vision Transformer)
[Script 5: Self-Supervised Pretraining and Contrastive Learning](#Script 5: Self-Supervised Pretraining and Contrastive Learning)
Chapter 1: Evolution of Core Mammography AI Architectures, from 2D Full-Field Imaging to 3D Tomosynthesis and Vision Transformers
1.1 Foundations and Limitations of Traditional Deep Learning Architectures in Mammography Screening
1.1.1 Early Applications of Convolutional Neural Networks (CNNs) in Breast Imaging Analysis
Early applications of convolutional neural networks to digital mammography established the technical foundation of modern mammography AI systems. Lesion detection on standard 2D full-field digital mammograms typically follows an encoder-decoder paradigm: the encoder extracts progressively more abstract features, from edges to textures, through hierarchical convolutions, while the decoder restores spatial resolution via transposed convolutions or upsampling to localize suspicious lesions. Within this architecture, ImageNet-pretrained transfer learning is widely used to mitigate the scarcity of annotated medical images; however, the pronounced differences between natural and medical images in intensity distribution, texture, and semantic granularity pose a serious domain-adaptation challenge. Dataset shift manifests as a feature-distribution mismatch when pretrained weights are transferred to mammograms: models pretrained on natural images often fail to capture clinically discriminative micro-patterns, particularly the sparse texture of glandular tissue and the high-frequency detail of small calcification clusters. For compute-constrained screening settings, lightweight convolutional architectures use depthwise separable convolutions, channel pruning, and knowledge distillation to substantially reduce inference latency and memory footprint while preserving lesion-detection sensitivity, allowing integration into real-time decision support within the clinical workflow.
1.1.2 Multi-View Geometric Consistency Modeling
Feature-fusion architectures for the paired craniocaudal (CC) and mediolateral oblique (MLO) views exploit the complementarity of the two projection geometries to improve diagnostic reliability. Dual-view contrastive learning imposes cross-view semantic consistency constraints that force the model to learn latent representations invariant to viewpoint changes; the appearance differences of the same lesion under different projection angles are explicitly modeled as hard positive pairs in an instance-discrimination task. A view-inconsistency detection mechanism uses a metric-learning framework to quantify the divergence between the two views' feature distributions, and triggers false-positive suppression when the views' prediction confidences diverge markedly, reducing misdiagnoses caused by tissue superposition or projection-angle artifacts.
1.2 Digital Breast Tomosynthesis (DBT) and 3D Deep Learning
Digital breast tomosynthesis reconstructs a high-resolution volume from multiple low-dose angular projections, fundamentally resolving the clinical problem of tissue overlap masking lesions in conventional 2D mammography. The physical reconstruction algorithm inverts voxel attenuation coefficients from projection sequences acquired at 15 to 25 angles, producing a stack of tomographic slices with sub-millimeter resolution along the depth direction. True 3D convolutional architectures extend the receptive field across slices to model anatomical continuity directly, and show clear performance gains over 2D slice-wise cascades in lesion classification; the latter ignore Z-axis context and exhibit feature discontinuities on masses that extend across slices.
Entropy-ranked slice selection computes an information-entropy score per slice and adaptively selects high-information slices for training, cutting the computational load to 40% to 60% of the full volume while preserving the lesion's spatial continuity. Multi-scale 3D feature-fusion networks process local detail and global context in parallel branches: a local branch focuses on the fine morphology of microcalcification clusters, a global branch captures the macroscopic architecture of larger masses, and Squeeze-and-Excitation attention recalibrates channels adaptively to suppress redundant background responses. Cross-layer fusion via skip connections combines shallow edge features with deep semantic features, handling the scale gap between microcalcifications and masses.
Transformer architectures offer distinctive spatiotemporal modeling advantages on DBT volumes. TimeSformer treats the slice stack as a spatiotemporal video stream, using divided space-time attention to separately model the 2D anatomy within each slice and the 3D continuity across slices. 3D Vision Transformers embed the volume as non-overlapping cubic patch tokens with learnable positional encodings adapted to the specific geometry of tomosynthesis. To tame the cost explosion on thick DBT stacks, sparse attention and sliding-window strategies restrict attention to local neighborhoods rather than all voxels, reducing the quadratic complexity to linear or near-linear while retaining long-range dependency modeling.
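As a minimal illustration of the sliding-window strategy described above, the sketch below restricts attention to a ±window neighborhood along a 1D token axis. It is illustrative only (a dense mask is built for clarity); production implementations avoid materializing the full score matrix.

```python
import torch
import torch.nn.functional as F

def local_window_attention(q, k, v, window: int):
    """Attention restricted to a +/- `window` neighborhood of each token.

    q, k, v: [B, N, D]. Conceptually drops cost from O(N^2) toward
    O(N * window); here realized with a dense mask for readability.
    """
    B, N, D = q.shape
    scores = (q @ k.transpose(-2, -1)) / D ** 0.5            # [B, N, N]
    idx = torch.arange(N)
    blocked = (idx[None, :] - idx[:, None]).abs() > window   # True = outside window
    scores = scores.masked_fill(blocked, float('-inf'))
    return F.softmax(scores, dim=-1) @ v                     # [B, N, D]

q = k = v = torch.randn(2, 50, 64)
out = local_window_attention(q, k, v, window=3)
print(out.shape)  # torch.Size([2, 50, 64])
```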
1.3 Self-Supervised Learning and the Contrastive Representation Revolution
Masked-autoencoder pretraining masks a high proportion of random image patches and reconstructs the original pixel values, forcing the model to learn clinically meaningful visual representations. To suit the sparse texture of breast images, adaptive masking applies denser masking in glandular regions to reinforce feature learning on diagnostically critical areas. Cross-view masked contrastive mutual learning exploits the complementary information structure of the CC and MLO views by exchanging unmasked features for cross-view prediction, a stronger representation constraint than single-view self-supervision. Self-supervised pretrained encoders transfer exceptionally well to downstream dense-prediction tasks: with only a small amount of labeled data, their weights match fully supervised training on detection and segmentation.
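The adaptive masking idea can be sketched by sampling mask positions with probability proportional to local patch intensity, a crude proxy for glandular density. The `adaptive_mask` helper below is illustrative, not part of Script 5.

```python
import torch

def adaptive_mask(patch_intensity, mask_ratio=0.75):
    """Sample patches to mask with probability increasing with patch
    intensity (a rough stand-in for glandular density).
    patch_intensity: [B, N]. Returns boolean mask [B, N], True = masked."""
    B, N = patch_intensity.shape
    n_mask = int(N * mask_ratio)
    # Shift so all weights are strictly positive for multinomial sampling
    weights = patch_intensity - patch_intensity.amin(dim=1, keepdim=True) + 1e-6
    idx = torch.multinomial(weights, n_mask, replacement=False)  # denser regions favored
    mask = torch.zeros(B, N, dtype=torch.bool)
    mask.scatter_(1, idx, True)
    return mask

intensity = torch.rand(2, 196)        # e.g. mean intensity per 16x16 patch
mask = adaptive_mask(intensity, 0.75)
print(mask.shape, mask.sum(dim=1))    # each row masks 147 of 196 patches
```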
Supervised contrastive pretraining uses a two-stage pipeline that avoids the appearance distortions introduced by strong data augmentation. Positive pairs are built from the CC and MLO views of the same lesion, while arbitrary views from other patients form the negative pool; pulling positives together and pushing negatives apart in the embedding space yields discriminative features. Hard negative mining focuses on samples that lie close to the anchor in embedding space yet carry a different pathology label; it is decisive for the fine-grained distinction between benign and malignant lesions, forcing the model to learn pathology-level features beyond surface statistics.
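The hard negative mining step can be sketched as a nearest different-label lookup in embedding space. This `mine_hard_negatives` helper is a simplified illustration, not taken from the scripts below.

```python
import torch
import torch.nn.functional as F

def mine_hard_negatives(embeddings, labels):
    """For each anchor, return the index of the most similar sample
    carrying a *different* pathology label (the hard negative)."""
    z = F.normalize(embeddings, dim=1)
    sim = z @ z.T                                    # pairwise cosine similarity [N, N]
    same_label = labels[:, None] == labels[None, :]  # masks positives and self
    sim = sim.masked_fill(same_label, float('-inf'))
    return sim.argmax(dim=1)                         # hardest negative per anchor

emb = torch.randn(6, 32)
labels = torch.tensor([0, 0, 0, 1, 1, 1])
hard_idx = mine_hard_negatives(emb, labels)
print(hard_idx.shape)  # torch.Size([6])
```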
Multi-view, multi-scale vision-language pretraining designs a CLIP-style contrastive objective tailored to mammography's multi-view structure, combining image-image and image-text alignment into a single multimodal representation space. Structured text mining of radiology reports uses entity recognition and relation extraction to pull out lesion attributes, locations, and BI-RADS categories, which are aligned with visual features at fine granularity to avoid the semantic noise of generic language models. Where pretraining data are limited, efficient contrastive strategies based on curriculum learning and memory banks achieve stable representation learning at the scale of a few thousand samples, removing the dependence on the hundred-million-scale pretraining typical of natural images.
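The image-text half of such an objective can be sketched as a generic symmetric InfoNCE, assuming batch row i pairs an image with its report embedding. This is a minimal sketch, not the chapter's exact pretraining recipe.

```python
import torch
import torch.nn.functional as F

def clip_style_loss(img_embed, txt_embed, temperature=0.07):
    """Symmetric InfoNCE between image and report embeddings.
    Assumes row i of each batch is a matched image-report pair."""
    img = F.normalize(img_embed, dim=1)
    txt = F.normalize(txt_embed, dim=1)
    logits = img @ txt.T / temperature                 # [B, B] similarity matrix
    targets = torch.arange(logits.size(0), device=logits.device)
    # Cross-entropy in both directions: image->text and text->image
    return 0.5 * (F.cross_entropy(logits, targets) +
                  F.cross_entropy(logits.T, targets))

loss = clip_style_loss(torch.randn(8, 256), torch.randn(8, 256))
print(loss.item())
```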
1.4 Large Vision-Language Models (LVLMs) and Explainable AI
Visual question answering benchmarks for breast imaging mark a paradigm shift from closed-set classification to open-ended medical reasoning, requiring models not only to output a diagnostic label but also to provide natural-language explanations consistent with radiological logic. Large language models show zero-shot and few-shot reasoning ability on mammograms, and visual instruction tuning adapts them to specific screening tasks without large-scale task-specific training. Specialized compact models with tight multimodal fusion architectures approach the diagnostic accuracy of general-purpose large models while retaining real-time inference efficiency, a decisive advantage in compute-constrained clinical deployments.
Clinical validation of explainable AI in breast screening requires that visual explanations align with radiologists' anatomical understanding. Gradient-weighted class activation mapping is validated for anatomical consistency by comparing the model's attended regions against radiologist-annotated lesion boundaries. SHAP value analysis attributes ensemble predictions by quantifying each view's and each feature channel's contribution, providing a transparent quantitative basis for risk stratification. Case-based reasoning retrieves previously confirmed cases whose imaging features and clinical history resemble the current case, giving analogical support to the model's prediction and markedly increasing radiologists' trust in and adoption of AI-assisted diagnosis.
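As a minimal, self-contained sketch of the Grad-CAM computation mentioned above: spatial feature maps are weighted by the spatially averaged gradient of the target class score, then passed through ReLU. The tiny random "model" here is a stand-in for illustration, not one of the chapter's networks (whose `forward` already returns the spatial features needed).

```python
import torch
import torch.nn.functional as F

def grad_cam(features, logits, class_idx):
    """Grad-CAM: weight feature maps by the gradient of the target
    class score, sum over channels, ReLU, and normalize to [0, 1].
    features: [B, C, H, W] with requires_grad=True."""
    score = logits[:, class_idx].sum()
    grads = torch.autograd.grad(score, features, retain_graph=True)[0]
    weights = grads.mean(dim=(2, 3), keepdim=True)   # GAP of gradients per channel
    cam = F.relu((weights * features).sum(dim=1))    # [B, H, W]
    cam = cam / (cam.amax(dim=(1, 2), keepdim=True) + 1e-8)
    return cam

# Tiny stand-in for a CNN head (illustrative only)
feat = torch.randn(2, 16, 8, 8, requires_grad=True)
logits = feat.mean(dim=(2, 3)) @ torch.randn(16, 2)
cam = grad_cam(feat, logits, class_idx=1)
print(cam.shape)  # torch.Size([2, 8, 8])
```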
Script 1: 2D Full-Field Digital Mammography (FFDM) Baseline Architectures and Domain Adaptation
Technical coverage: Section 1.1.1 (early CNN applications, ImageNet transfer learning, lightweight LBNet), Section 1.1.1.2 (domain adaptation)
Python
"""
Script 1: FFDM Standard & Lightweight CNN Architectures with Domain Adaptation
Content: ResNet-based backbone for mammography, lightweight LBNet, gradient reversal for domain adaptation
Usage: Single-view breast X-ray classification with transferable features
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import Tuple, Optional
class MammographyResNet(nn.Module):
"""
ResNet backbone adapted for mammography (1-channel input)
Supports ImageNet pretraining with channel-wise weight averaging for domain adaptation
"""
def __init__(self, arch='resnet50', pretrained=False, frozen_stages=3, dropout=0.3):
super().__init__()
self.frozen_stages = frozen_stages
        # Load backbone, forwarding the pretrained flag so ImageNet weights
        # are actually loaded when requested
        if arch == 'resnet50':
            backbone = models.resnet50(pretrained=pretrained)
            feat_dim = 2048
        elif arch == 'resnet34':
            backbone = models.resnet34(pretrained=pretrained)
            feat_dim = 512
else:
raise ValueError(f"Unsupported architecture: {arch}")
# Remove final FC and pooling, keep convolutional feature extractor
self.encoder = nn.Sequential(*list(backbone.children())[:-2])
# Domain adaptation: adapt first conv layer for single-channel mammography images
# Strategy: average the 3-channel ImageNet weights into 1-channel
self.input_adapter = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
if pretrained:
with torch.no_grad():
rgb_weights = backbone.conv1.weight.data
# Average across RGB channels to create grayscale weights
gray_weights = rgb_weights.mean(dim=1, keepdim=True)
self.input_adapter.weight.data = gray_weights
self.encoder[0] = self.input_adapter
# Freeze early stages for transfer learning (preserve low-level edge/texture features)
self._freeze_stages(frozen_stages)
# Channel attention for emphasizing diagnostically relevant features
self.channel_gate = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(feat_dim, feat_dim // 16, 1),
nn.ReLU(inplace=True),
nn.Conv2d(feat_dim // 16, feat_dim, 1),
nn.Sigmoid()
)
# Classification head with high dropout for medical imaging uncertainty
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(dropout),
nn.Linear(feat_dim, 256),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(256, 2) # Benign/Malignant
)
def _freeze_stages(self, stages):
"""Freeze early residual stages to prevent overfitting on small medical datasets"""
if stages > 0:
for param in self.encoder[0].parameters():
param.requires_grad = False
for i in range(1, min(stages + 1, len(self.encoder))):
for param in self.encoder[i].parameters():
param.requires_grad = False
def forward(self, x):
"""
Args:
x: [B, 1, H, W] grayscale mammogram
Returns:
logits: [B, 2] classification scores
features: [B, C, H', W'] spatial feature maps for Grad-CAM
"""
features = self.encoder(x)
# Apply channel-wise attention
attention = self.channel_gate(features)
attended = features * attention
logits = self.classifier(attended)
return logits, attended
class InvertedResidual(nn.Module):
"""MobileNet-style inverted residual block with depthwise separable convolutions"""
def __init__(self, inp, oup, stride, expand_ratio):
super().__init__()
self.stride = stride
self.use_res = stride == 1 and inp == oup
hidden = int(round(inp * expand_ratio))
layers = []
if expand_ratio != 1:
layers.extend([
nn.Conv2d(inp, hidden, 1, bias=False),
nn.BatchNorm2d(hidden),
nn.ReLU6(inplace=True)
])
layers.extend([
# Depthwise separable: 3x3 groups=hidden
nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),
nn.BatchNorm2d(hidden),
nn.ReLU6(inplace=True),
nn.Conv2d(hidden, oup, 1, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res:
return x + self.conv(x)
return self.conv(x)
class LightweightBreastNet(nn.Module):
"""
LBNet (Lightweight Breast Network) for resource-constrained screening scenarios
Uses depthwise separable convolutions and inverted residuals for efficiency
"""
def __init__(self, num_classes=2, width_mult=1.0):
super().__init__()
def make_divisible(v, divisor=8):
return max(divisor, int(v + divisor/2) // divisor * divisor)
ch = lambda x: make_divisible(x * width_mult)
self.stem = nn.Sequential(
nn.Conv2d(1, ch(32), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ch(32)),
nn.ReLU6(inplace=True)
)
# Progressive expansion: slow-fast architecture for multi-scale lesions
self.layer1 = self._make_layer(ch(32), ch(16), 1, stride=1, expansion=1)
self.layer2 = self._make_layer(ch(16), ch(24), 2, stride=2, expansion=4)
self.layer3 = self._make_layer(ch(24), ch(32), 3, stride=2, expansion=4)
self.layer4 = self._make_layer(ch(32), ch(64), 4, stride=2, expansion=4)
self.head = nn.Sequential(
nn.Conv2d(ch(64), ch(128), 1, bias=False),
nn.BatchNorm2d(ch(128)),
nn.ReLU6(inplace=True)
)
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(ch(128), num_classes)
)
self._init_weights()
def _make_layer(self, inp, oup, n, stride, expansion):
layers = [InvertedResidual(inp, oup, stride, expansion)]
for _ in range(1, n):
layers.append(InvertedResidual(oup, oup, 1, expansion))
return nn.Sequential(*layers)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.head(x)
x = F.adaptive_avg_pool2d(x, 1).flatten(1)
return self.classifier(x)
class GradientReversalFunction(torch.autograd.Function):
"""Gradient Reversal Layer for domain adaptation training"""
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return grad_output.neg() * ctx.alpha, None
class DomainAdaptationHead(nn.Module):
"""
Domain discriminator with gradient reversal for unsupervised domain adaptation
Helps mitigate dataset shift between ImageNet and mammography domains
"""
def __init__(self, feat_dim=512, num_domains=2):
super().__init__()
self.discriminator = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(feat_dim, 256),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(256, num_domains) # Source vs Target domain
)
def forward(self, x, alpha=1.0):
x = GradientReversalFunction.apply(x, alpha)
return self.discriminator(x)
# Execution test
if __name__ == "__main__":
# Test standard ResNet
model = MammographyResNet(arch='resnet34', pretrained=False)
x = torch.randn(2, 1, 1024, 768) # Typical mammogram resolution
logits, features = model(x)
print(f"[ResNet34] Input: {x.shape}, Output: {logits.shape}, Features: {features.shape}")
# Test lightweight network
lbnet = LightweightBreastNet(width_mult=1.0)
out = lbnet(x)
print(f"[LBNet] Input: {x.shape}, Output: {out.shape}")
# Parameter count comparison
resnet_params = sum(p.numel() for p in model.parameters())
lbnet_params = sum(p.numel() for p in lbnet.parameters())
print(f"Parameters - ResNet34: {resnet_params/1e6:.2f}M, LBNet: {lbnet_params/1e6:.2f}M")
Script 2: Multi-View Geometric Consistency Modeling (CC/MLO Fusion)
Technical coverage: Section 1.1.2 (dual-view fusion, contrastive learning, view-inconsistency detection)
Python
"""
Script 2: Dual-View Geometric Consistency Modeling
Content: CC/MLO view fusion with gating, cross-view contrastive learning, discrepancy detection
Usage: Paired mammogram analysis (Craniocaudal & Mediolateral Oblique views)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
class DualViewFusionNet(nn.Module):
"""
MammoFusion-Net architecture with gated cross-view attention
Independently processes CC and MLO views with learnable adaptive weighting
"""
def __init__(self, backbone='resnet34', feat_dim=512, num_classes=2, dropout=0.4):
super().__init__()
# Separate encoders for each view (no weight sharing to capture view-specific patterns)
self.cc_encoder = self._build_encoder(backbone, feat_dim)
self.mlo_encoder = self._build_encoder(backbone, feat_dim)
# Gated fusion: learn adaptive importance weights for each view
self.cc_gate = nn.Sequential(
nn.Conv2d(feat_dim, feat_dim, 1, bias=False),
nn.BatchNorm2d(feat_dim),
nn.Sigmoid()
)
self.mlo_gate = nn.Sequential(
nn.Conv2d(feat_dim, feat_dim, 1, bias=False),
nn.BatchNorm2d(feat_dim),
nn.Sigmoid()
)
# Cross-view spatial attention for lesion localization
self.spatial_attn = nn.Sequential(
nn.Conv2d(feat_dim * 2, 128, 7, padding=3),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 1, 1),
nn.Sigmoid()
)
# Consistency projection head for contrastive learning
self.consistency_proj = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(feat_dim, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 64)
)
# Final classifier
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(dropout),
nn.Linear(feat_dim, 256),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(256, num_classes)
)
def _build_encoder(self, name, feat_dim):
if name == 'resnet34':
from torchvision.models import resnet34
m = resnet34(pretrained=False)
m.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
return nn.Sequential(*list(m.children())[:-2])
        elif name == 'resnet18':
            from torchvision.models import resnet18
            m = resnet18(pretrained=False)
            m.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
            return nn.Sequential(*list(m.children())[:-2])
        else:
            raise ValueError(f"Unsupported backbone: {name}")
def forward(self, cc_view, mlo_view, return_consistency=False):
"""
Args:
cc_view: [B, 1, H, W] Craniocaudal view
mlo_view: [B, 1, H, W] Mediolateral Oblique view
Returns:
Dictionary with classification logits, fused features, and optional consistency embeddings
"""
# Extract view-specific features
cc_feat = self.cc_encoder(cc_view) # [B, C, H', W']
mlo_feat = self.mlo_encoder(mlo_view)
# Gated weighting: dynamic view importance based on content
cc_weight = self.cc_gate(cc_feat)
mlo_weight = self.mlo_gate(mlo_feat)
cc_gated = cc_feat * cc_weight
mlo_gated = mlo_feat * mlo_weight
# Adaptive fusion (weighted sum rather than concatenation)
fused = cc_gated + mlo_gated
# Spatial attention refinement
concat_feat = torch.cat([cc_feat, mlo_feat], dim=1)
spatial_mask = self.spatial_attn(concat_feat)
refined = fused * spatial_mask
# Classification
logits = self.classifier(refined)
outputs = {
'logits': logits,
'fused_features': refined,
'cc_features': cc_feat,
'mlo_features': mlo_feat,
'spatial_attention': spatial_mask
}
if return_consistency:
# L2-normalized embeddings for contrastive learning
cc_embed = F.normalize(self.consistency_proj(cc_feat), dim=1)
mlo_embed = F.normalize(self.consistency_proj(mlo_feat), dim=1)
outputs['cc_embedding'] = cc_embed
outputs['mlo_embedding'] = mlo_embed
return outputs
class CrossViewContrastiveLoss(nn.Module):
"""
Supervised contrastive learning across CC and MLO views
Treats same-patient CC-MLO pair as positive, different-patient as negatives
"""
def __init__(self, temperature=0.07):
super().__init__()
self.temperature = temperature
def forward(self, cc_embed, mlo_embed, labels=None):
"""
Args:
cc_embed: [B, D] CC view embeddings (normalized)
mlo_embed: [B, D] MLO view embeddings (normalized)
labels: [B] pathology labels for hard negative mining (optional)
"""
        batch_size = cc_embed.size(0)
        # Similarity matrix between CC and MLO embeddings: [B, D] @ [D, B] -> [B, B]
        sim_matrix = torch.matmul(cc_embed, mlo_embed.T) / self.temperature
        # InfoNCE: the same-patient MLO (diagonal) is the positive for each CC
        # embedding; all other batch samples are negatives. cross_entropy with
        # arange targets is the numerically stable form of
        # -log(exp(pos) / sum(exp(all))).
        targets = torch.arange(batch_size, device=cc_embed.device)
        loss_cc_to_mlo = F.cross_entropy(sim_matrix, targets)
        # Symmetric direction (MLO to CC)
        loss_mlo_to_cc = F.cross_entropy(sim_matrix.T, targets)
        return (loss_cc_to_mlo + loss_mlo_to_cc) / 2
class ViewDiscrepancyDetector(nn.Module):
"""
Detects inconsistency between CC and MLO predictions for false positive reduction
High discrepancy indicates potential false positive (tissue overlap artifacts)
"""
def __init__(self, feat_dim=512):
super().__init__()
self.discrepancy_head = nn.Sequential(
nn.Linear(feat_dim * 2, 256),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
nn.Linear(256, 64),
nn.ReLU(inplace=True),
nn.Linear(64, 1),
nn.Sigmoid()
)
def forward(self, cc_feat, mlo_feat, cc_logits, mlo_logits):
"""
Returns:
discrepancy_score: [B, 1] high value indicates view conflict
confidence_weight: [B, 1] low weight for inconsistent predictions
"""
# Global pooling
cc_pool = F.adaptive_avg_pool2d(cc_feat, 1).flatten(1)
mlo_pool = F.adaptive_avg_pool2d(mlo_feat, 1).flatten(1)
# Predict discrepancy from concatenated features
combined = torch.cat([cc_pool, mlo_pool], dim=1)
disc_score = self.discrepancy_head(combined)
        # JS divergence between the two views' probability distributions
        cc_prob = F.softmax(cc_logits, dim=1)
        mlo_prob = F.softmax(mlo_logits, dim=1)
        mean_prob = (cc_prob + mlo_prob) / 2
        js_div = 0.5 * (
            F.kl_div(mean_prob.log(), cc_prob, reduction='none').sum(dim=1, keepdim=True) +
            F.kl_div(mean_prob.log(), mlo_prob, reduction='none').sum(dim=1, keepdim=True)
        )
        # Confidence decreases with both the learned discrepancy score
        # and the distributional divergence between the views
        conf_weight = (1.0 - disc_score) * torch.exp(-js_div)
        return disc_score, conf_weight
if __name__ == "__main__":
# Test dual-view fusion
model = DualViewFusionNet(backbone='resnet18', feat_dim=512)
cc = torch.randn(4, 1, 512, 384)
mlo = torch.randn(4, 1, 512, 384)
outputs = model(cc, mlo, return_consistency=True)
print(f"[DualView] CC/MLO: {cc.shape}/{mlo.shape}")
print(f"Logits: {outputs['logits'].shape}")
print(f"Fused features: {outputs['fused_features'].shape}")
print(f"CC embedding: {outputs['cc_embedding'].shape}")
# Test contrastive loss
criterion = CrossViewContrastiveLoss(temperature=0.07)
loss = criterion(outputs['cc_embedding'], outputs['mlo_embedding'])
print(f"Contrastive loss: {loss.item():.4f}")
# Test discrepancy detection
detector = ViewDiscrepancyDetector(feat_dim=512)
cc_single = torch.randn(4, 2) # Single view logits
mlo_single = torch.randn(4, 2)
disc, conf = detector(outputs['cc_features'], outputs['mlo_features'], cc_single, mlo_single)
print(f"Discrepancy score: {disc.shape}, Confidence: {conf.shape}")
Script 3: 3D CNN with Multi-Scale Fusion for Digital Breast Tomosynthesis (DBT)
Technical coverage: Section 1.2 (3D convolutions, 3DMSFF multi-scale fusion, SE attention, entropy-ranked slice selection)
Python
"""
Script 3: 3D CNN for Digital Breast Tomosynthesis (DBT)
Content: True 3D convolutions, multi-scale 3D feature fusion (3DMSFF),
entropy-based slice selection, 3D Squeeze-and-Excitation
Usage: Volumetric DBT analysis with reduced computational load
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List
class EntropyBasedSliceSelection(nn.Module):
"""
Computational optimization: select high-information slices based on entropy
Reduces FLOPs by 40-60% while preserving lesion spatial continuity
"""
def __init__(self, top_k_ratio=0.6, temperature=0.5):
super().__init__()
self.top_k_ratio = top_k_ratio
self.temperature = temperature
def compute_slice_entropy(self, x):
"""
x: [B, C, D, H, W]
Returns entropy per slice: [B, D]
"""
B, C, D, H, W = x.shape
# Flatten each slice
x_flat = x.permute(0, 2, 1, 3, 4).reshape(B, D, -1)
# Convert to probability distribution
x_norm = F.softmax(x_flat / self.temperature, dim=-1)
# Shannon entropy: -sum(p * log(p))
entropy = -(x_norm * torch.log(x_norm + 1e-8)).sum(dim=-1)
return entropy
def forward(self, x):
"""
Returns:
selected: [B, C, D', H, W] selected slices
soft_mask: [B, D] differentiable soft selection mask
indices: List[int] hard selection indices
"""
B, C, D, H, W = x.shape
entropy = self.compute_slice_entropy(x)
        k = max(1, int(D * self.top_k_ratio))
        _, top_indices = torch.topk(entropy, k, dim=1)
        # Re-sort indices so selected slices keep their anatomical (depth) order
        top_indices, _ = torch.sort(top_indices, dim=1)
        # Soft mask for gradient flow during training
        soft_mask = F.softmax(entropy / self.temperature, dim=1)
        # Hard selection for the forward pass
        selected = torch.stack(
            [x[b, :, top_indices[b], :, :] for b in range(B)], dim=0
        )
        return selected, soft_mask, top_indices[0].tolist()
class SEBlock3D(nn.Module):
"""3D Squeeze-and-Excitation for channel-wise feature recalibration in volumetric data"""
def __init__(self, channels, reduction=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
self.fc = nn.Sequential(
nn.Linear(channels, channels // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channels // reduction, channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1, 1)
return x * y
class ResidualBlock3D(nn.Module):
"""True 3D residual block maintaining spatial-temporal continuity across slices"""
def __init__(self, in_ch, out_ch, stride=1, use_se=True):
super().__init__()
self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm3d(out_ch)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(out_ch)
self.se = SEBlock3D(out_ch) if use_se else None
self.downsample = None
if stride != 1 or in_ch != out_ch:
self.downsample = nn.Sequential(
nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm3d(out_ch)
)
def forward(self, x):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return self.relu(out)
class MultiScale3DEncoder(nn.Module):
"""
3DMSFF (Multi-Scale 3D Feature Fusion)
Parallel branches for local (microcalcifications) and global (masses) features
"""
def __init__(self, in_ch=1, base_ch=32):
super().__init__()
# Local Feature Extraction (LFE) branch: small kernels for microcalcifications
self.lfe_stem = nn.Conv3d(in_ch, base_ch, kernel_size=3, padding=1)
self.lfe_block1 = ResidualBlock3D(base_ch, base_ch * 2, use_se=True)
self.lfe_block2 = ResidualBlock3D(base_ch * 2, base_ch * 4, stride=2, use_se=True)
# Global Feature Extraction (GFE) branch: larger receptive field for masses
self.gfe_stem = nn.Conv3d(in_ch, base_ch, kernel_size=5, padding=2)
self.gfe_pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.gfe_block1 = ResidualBlock3D(base_ch, base_ch * 2, use_se=True)
self.gfe_block2 = ResidualBlock3D(base_ch * 2, base_ch * 4, stride=2, use_se=True)
# Cross-layer fusion handling scale differences between microcalcifications and masses
self.fusion_conv = nn.Sequential(
nn.Conv3d(base_ch * 8, base_ch * 4, kernel_size=1),
nn.BatchNorm3d(base_ch * 4),
nn.ReLU(inplace=True)
)
self.fusion_se = SEBlock3D(base_ch * 4)
def forward(self, x):
"""
Returns:
fused: Multi-scale fused features
lfe_out: Local features (high resolution for microcalcifications)
gfe_out: Global features (context for masses)
"""
# Local branch
lfe = F.relu(self.lfe_stem(x))
lfe = self.lfe_block1(lfe)
lfe_out = self.lfe_block2(lfe)
# Global branch
gfe = F.relu(self.gfe_stem(x))
gfe = self.gfe_pool(gfe)
gfe = self.gfe_block1(gfe)
gfe_out = self.gfe_block2(gfe)
        # Cross-layer fusion: the global branch is spatially coarser (extra
        # max-pool), so upsample it to the local branch's shape before concat
        gfe_up = F.interpolate(gfe_out, size=lfe_out.shape[2:],
                               mode='trilinear', align_corners=False)
        fused = torch.cat([lfe_out, gfe_up], dim=1)
        fused = self.fusion_conv(fused)
        fused = self.fusion_se(fused)
        return fused, lfe_out, gfe_out
class DBT3DNetwork(nn.Module):
"""
Complete DBT processing pipeline with entropy-based slice selection
and multi-scale 3D feature extraction
"""
def __init__(self, num_slices=25, in_ch=1, num_classes=2,
slice_selection=True, selection_ratio=0.6):
super().__init__()
self.use_selection = slice_selection
self.num_slices = num_slices
if slice_selection:
self.slice_selector = EntropyBasedSliceSelection(selection_ratio)
processed_slices = max(1, int(num_slices * selection_ratio))
else:
processed_slices = num_slices
self.encoder = MultiScale3DEncoder(in_ch=in_ch, base_ch=32)
# Temporal/depth aggregation
self.temporal_agg = nn.Sequential(
nn.AdaptiveAvgPool3d((1, 8, 8)),
nn.Conv3d(128, 256, kernel_size=1),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True)
)
# Classification head
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool3d(1),
nn.Flatten(),
nn.Dropout(0.4),
nn.Linear(256, 128),
nn.ReLU(inplace=True),
nn.Linear(128, num_classes)
)
# Auxiliary detection head for lesion heatmap generation
self.detection_head = nn.Sequential(
nn.Conv3d(128, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv3d(64, 1, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
"""
x: [B, 1, D, H, W] - DBT volume
Returns: dict with logits, heatmap, features
"""
selection_info = {}
# Step 1: Entropy-based slice selection
if self.use_selection:
x_selected, soft_mask, indices = self.slice_selector(x)
selection_info['entropy_mask'] = soft_mask
selection_info['selected_indices'] = indices
else:
x_selected = x
# Step 2: Multi-scale 3D feature extraction
fused, lfe, gfe = self.encoder(x_selected)
# Step 3: Temporal aggregation
agg = self.temporal_agg(fused)
# Step 4: Classification
logits = self.classifier(agg)
# Step 5: Auxiliary dense prediction
heatmap = self.detection_head(fused)
return {
'logits': logits,
'heatmap': heatmap,
'fused_features': fused,
'local_features': lfe, # For microcalcification analysis
'global_features': gfe, # For mass analysis
**selection_info
}
if __name__ == "__main__":
    # Simulated DBT volume: batch=2, 1 channel, 25 slices (128x128, reduced for testing)
    dbt_vol = torch.randn(2, 1, 25, 128, 128)
model = DBT3DNetwork(num_slices=25, slice_selection=True, selection_ratio=0.6)
outputs = model(dbt_vol)
print(f"[DBT 3D] Input: {dbt_vol.shape}")
print(f"Logits: {outputs['logits'].shape}")
print(f"Heatmap: {outputs['heatmap'].shape}")
print(f"Selected indices: {outputs.get('selected_indices', 'N/A')}")
print(f"Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
Script 4: TimeSformer and 3D Vision Transformer
Technical coverage: Section 1.2.3 (divided space-time attention, 3D patch embedding, sparse attention strategies)
Python
"""
Script 4: TimeSformer Architecture for DBT Spatiotemporal Modeling
Content: Divided space-time attention, 3D patch embedding, sparse window attention
Usage: Treat DBT slices as spatiotemporal video for lesion detection
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PatchEmbed3D(nn.Module):
"""
3D Patch embedding for DBT: treats volume as spatiotemporal sequence
Different from 2D: handles depth dimension with specific temporal patch size
"""
def __init__(self, img_size=224, patch_size=16, in_chans=1, embed_dim=768, depth=16):
super().__init__()
self.patch_size = patch_size
self.grid_size = img_size // patch_size
self.num_patches_spatial = self.grid_size ** 2
self.embed_dim = embed_dim
# 3D convolution: (temporal, height, width)
# For DBT, temporal (depth) dimension is smaller, use patch_size_t=2
self.proj = nn.Conv3d(
in_chans, embed_dim,
kernel_size=(2, patch_size, patch_size),
stride=(2, patch_size, patch_size)
)
# Learnable positional embeddings adapted to DBT geometry
num_temporal_tokens = depth // 2
self.pos_embed_spatial = nn.Parameter(
torch.zeros(1, self.num_patches_spatial, embed_dim)
)
self.pos_embed_temporal = nn.Parameter(
torch.zeros(1, num_temporal_tokens, embed_dim)
)
def forward(self, x):
"""
x: [B, C, D, H, W]
Returns: [B, N, D] where N = temporal_tokens * spatial_patches
"""
B, C, D, H, W = x.shape
# 3D convolution to create patch tokens
x = self.proj(x) # [B, embed_dim, D', H', W']
Dp, Hp, Wp = x.shape[2:]
# Rearrange: [B, C, D', H', W'] -> [B, D', H'*W', C]
x = x.permute(0, 2, 3, 4, 1).reshape(B, Dp, -1, self.embed_dim)
# Add positional embeddings (broadcasted)
spatial_pos = self.pos_embed_spatial.unsqueeze(1) # [1, 1, S, D]
temporal_pos = self.pos_embed_temporal.unsqueeze(2) # [1, T, 1, D]
x = x + spatial_pos + temporal_pos
# Flatten to sequence: [B, T*S, D]
x = x.reshape(B, -1, self.embed_dim)
return x
class DividedSpaceTimeAttention(nn.Module):
"""
TimeSformer core: divided attention mechanism
First temporal (depth), then spatial (within slice)
Complexity: O(T*N) instead of O((T*N)^2)
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# Separate QKV projections for temporal and spatial attention
self.temporal_qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.temporal_proj = nn.Linear(dim, dim)
self.temporal_drop = nn.Dropout(attn_drop)
self.spatial_qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.spatial_proj = nn.Linear(dim, dim)
self.spatial_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, T, S):
"""
x: [B, N, D] where N = T * S
T: number of temporal tokens (depth slices)
S: number of spatial patches per slice
"""
B, N, C = x.shape
assert N == T * S
# Reshape to [B, T, S, C] for separated attention
x = x.view(B, T, S, C)
# ===== Step 1: Temporal Attention (across slices) =====
# For each spatial position, attend to all temporal instances
xt = x.permute(0, 2, 1, 3).reshape(B * S, T, C) # [B*S, T, C]
qkv_t = self.temporal_qkv(xt).reshape(B * S, T, 3, self.num_heads, C // self.num_heads)
qkv_t = qkv_t.permute(2, 0, 3, 1, 4) # [3, B*S, heads, T, head_dim]
q_t, k_t, v_t = qkv_t[0], qkv_t[1], qkv_t[2]
# Temporal attention scores
attn_t = (q_t @ k_t.transpose(-2, -1)) * self.scale
attn_t = attn_t.softmax(dim=-1)
attn_t = self.temporal_drop(attn_t)
out_t = (attn_t @ v_t).transpose(1, 2).reshape(B * S, T, C)
out_t = self.temporal_proj(out_t)
xt = xt + out_t # Residual connection
xt = xt.reshape(B, S, T, C).permute(0, 2, 1, 3) # Back to [B, T, S, C]
# ===== Step 2: Spatial Attention (within slice) =====
# For each temporal slice, attend to spatial patches
xs = xt.reshape(B * T, S, C) # [B*T, S, C]
qkv_s = self.spatial_qkv(xs).reshape(B * T, S, 3, self.num_heads, C // self.num_heads)
qkv_s = qkv_s.permute(2, 0, 3, 1, 4)
q_s, k_s, v_s = qkv_s[0], qkv_s[1], qkv_s[2]
attn_s = (q_s @ k_s.transpose(-2, -1)) * self.scale
attn_s = attn_s.softmax(dim=-1)
attn_s = self.spatial_drop(attn_s)
out_s = (attn_s @ v_s).transpose(1, 2).reshape(B * T, S, C)
out_s = self.spatial_proj(out_s)
xs = xs + out_s
# Reshape back to sequence
out = xs.reshape(B, T, S, C).reshape(B, T * S, C)
return out
class BlockTimeSformer(nn.Module):
"""TimeSformer transformer block with divided attention and MLP"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = DividedSpaceTimeAttention(dim, num_heads, qkv_bias, drop, drop)
self.norm2 = nn.LayerNorm(dim)
mlp_hidden = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(mlp_hidden, dim),
nn.Dropout(drop)
)
def forward(self, x, T, S):
x = x + self.attn(self.norm1(x), T, S)
x = x + self.mlp(self.norm2(x))
return x
class TimeSformerDBT(nn.Module):
"""
TimeSformer adapted for DBT volumetric analysis
Treats DBT as spatiotemporal data with divided attention for efficiency
"""
def __init__(self, img_size=224, patch_size=16, depth=16, in_chans=1,
embed_dim=768, num_heads=12, mlp_ratio=4., num_classes=2, depth_blocks=8):
super().__init__()
self.patch_embed = PatchEmbed3D(img_size, patch_size, in_chans, embed_dim, depth)
self.T = depth // 2 # Temporal tokens
self.S = (img_size // patch_size) ** 2 # Spatial tokens per slice
# Transformer blocks
self.blocks = nn.ModuleList([
BlockTimeSformer(embed_dim, num_heads, mlp_ratio)
for _ in range(depth_blocks)
])
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
"""
x: [B, 1, D, H, W] DBT volume
Returns: classification logits
"""
x = self.patch_embed(x) # [B, N, D]
for block in self.blocks:
x = block(x, self.T, self.S)
x = self.norm(x)
x = x.mean(dim=1) # Global average pooling
return self.head(x)
if __name__ == "__main__":
model = TimeSformerDBT(
img_size=224, patch_size=16, depth=16,
embed_dim=384, num_heads=6, depth_blocks=4
)
x = torch.randn(2, 1, 16, 224, 224)
logits = model(x)
print(f"[TimeSformer] Input: {x.shape}, Output: {logits.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
Script 5: Self-Supervised Pretraining and Contrastive Learning
Technical coverage: Section 1.3 (MAE masked autoencoder, supervised contrastive pretraining SCP, cross-view learning)
Python
"""
Script 5: Self-Supervised and Contrastive Learning for Mammography
Content: Masked Autoencoder (MAE) for mammography, Supervised Contrastive Pre-training (SCP),
Hard negative mining, cross-view mutual learning
Usage: Pre-training on unlabeled/labeled mammograms for downstream task transfer
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class MaskedAutoencoderViT(nn.Module):
"""
MAE adapted for mammography: high-ratio masking strategy for sparse breast textures
"""
def __init__(self, img_size=224, patch_size=16, in_chans=1,
embed_dim=768, depth=12, num_heads=12, decoder_dim=512,
mask_ratio=0.75): # High masking ratio for sparse medical images
super().__init__()
self.mask_ratio = mask_ratio
self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.num_patches = (img_size // patch_size) ** 2
# Positional encoding
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
# Encoder (ViT)
        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, embed_dim * 4, dropout=0.1, batch_first=True)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
# Decoder (lightweight, reconstructs pixels)
self.decoder_embed = nn.Linear(embed_dim, decoder_dim)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        decoder_layer = nn.TransformerEncoderLayer(decoder_dim, num_heads, decoder_dim * 4, dropout=0.1, batch_first=True)
self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=4)
self.decoder_pred = nn.Linear(decoder_dim, patch_size ** 2 * in_chans)
def random_masking(self, x, mask_ratio):
"""
x: [B, N, D] sequence of patch embeddings
Returns: masked sequence, mask indices, restore indices
"""
B, N, D = x.shape
len_keep = int(N * (1 - mask_ratio))
noise = torch.rand(B, N, device=x.device) # Random noise
ids_shuffle = torch.argsort(noise, dim=1) # Ascend: keep first len_keep
ids_restore = torch.argsort(ids_shuffle, dim=1)
# Generate mask: 0 is keep, 1 is remove
mask = torch.ones([B, N], device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
# Keep subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
return x_masked, mask, ids_restore, ids_keep
def forward(self, x):
"""
x: [B, 1, H, W]
Returns: loss, reconstructed image, mask
"""
B, C, H, W = x.shape
# Patch embedding
x = self.patch_embed(x) # [B, D, H/P, W/P]
x = x.flatten(2).transpose(1, 2) # [B, N, D]
x = x + self.pos_embed
# Masking
x_masked, mask, ids_restore, ids_keep = self.random_masking(x, self.mask_ratio)
# Encode visible patches only
x_encoded = self.encoder(x_masked) # [B, N_keep, D]
# Decoder: append mask tokens
x_dec = self.decoder_embed(x_encoded)
mask_tokens = self.mask_token.repeat(x_dec.shape[0], ids_restore.shape[1] - x_dec.shape[1], 1)
x_full = torch.cat([x_dec, mask_tokens], dim=1)
x_full = torch.gather(x_full, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_dec.shape[2]))
x_decoded = self.decoder(x_full)
x_pred = self.decoder_pred(x_decoded) # [B, N, P*P*C]
        # Reshape prediction to image (unpatchify)
        P = self.patch_embed.kernel_size[0]
        x_pred = x_pred.reshape(B, H // P, W // P, P, P, C)
        x_pred = x_pred.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
        # Expand the per-patch mask [B, N] to pixel space [B, 1, H, W]
        mask_img = mask.reshape(B, H // P, W // P)
        mask_img = mask_img.repeat_interleave(P, dim=1).repeat_interleave(P, dim=2).unsqueeze(1)
        # Compute reconstruction loss only on masked pixels
        loss = ((x_pred - x) ** 2 * mask_img).sum() / (mask_img.sum() * C + 1e-8)
return loss, x_pred, mask
class SupervisedContrastiveLoss(nn.Module):
"""
SupCon: Supervised Contrastive Pre-training without strong augmentation
Uses pathology labels to construct positive/negative pairs
"""
def __init__(self, temperature=0.07, base_temperature=0.07):
super().__init__()
self.temperature = temperature
self.base_temperature = base_temperature
def forward(self, features, labels, mask=None):
"""
features: [B, D] normalized feature vectors
labels: [B] pathology labels (0: benign, 1: malignant)
"""
device = features.device
batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError("Cannot define both labels and mask")
        elif labels is None and mask is None:
            raise ValueError("Must define either labels or mask")
if labels is not None:
labels = labels.contiguous().view(-1, 1)
mask = torch.eq(labels, labels.T).float().to(device) # Positive mask: same label
# Compute similarity matrix
anchor_dot_contrast = torch.div(
torch.matmul(features, features.T),
self.temperature
)
# For numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
# Mask out self-contrast
logits_mask = torch.scatter(
torch.ones_like(mask), 1,
torch.arange(batch_size).view(-1, 1).to(device), 0
)
mask = mask * logits_mask # Remove diagonal
# Compute log probabilities
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# Compute mean of log-likelihood over positives
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1).clamp(min=1e-8)  # Guard anchors with no positives
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
return loss.mean()
class CrossViewMaskedLearning(nn.Module):
"""
Cross-view mutual learning: MAE with CC-MLO view exchange
Exploits complementary information between views for robust representation learning
"""
def __init__(self, img_size=224, embed_dim=768):
super().__init__()
self.mae_cc = MaskedAutoencoderViT(img_size, embed_dim=embed_dim)
self.mae_mlo = MaskedAutoencoderViT(img_size, embed_dim=embed_dim)
# Cross-view prediction head: predict one view from the other
self.cross_predictor = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.ReLU(inplace=True),
nn.Linear(embed_dim, embed_dim)
)
def forward(self, cc_img, mlo_img):
"""
Joint training with within-view reconstruction and cross-view prediction
"""
# Encode both views (visible patches only)
cc_loss, cc_recon, cc_mask = self.mae_cc(cc_img)
mlo_loss, mlo_recon, mlo_mask = self.mae_mlo(mlo_img)
# Cross-view consistency loss (feature-level)
cc_features = self.get_features(self.mae_cc, cc_img)
mlo_features = self.get_features(self.mae_mlo, mlo_img)
# Predict CC from MLO and vice versa
cc_pred = self.cross_predictor(mlo_features)
mlo_pred = self.cross_predictor(cc_features)
cross_loss = F.mse_loss(F.normalize(cc_pred, dim=1),
F.normalize(cc_features.detach(), dim=1)) + \
F.mse_loss(F.normalize(mlo_pred, dim=1),
F.normalize(mlo_features.detach(), dim=1))
total_loss = cc_loss + mlo_loss + 0.5 * cross_loss
return total_loss, cc_recon, mlo_recon
def get_features(self, model, x):
"""Extract encoded features from MAE encoder"""
x = model.patch_embed(x).flatten(2).transpose(1, 2)
x = x + model.pos_embed
x, _, _, _ = model.random_masking(x, model.mask_ratio)
return model.encoder(x).mean(dim=1) # Global pooling
class HardNegativeMiner:
"""
Hard negative mining for benign/malignant discrimination
Selects negatives that are similar in embedding space but different in pathology
"""
def __init__(self, k=5):
self.k = k # Number of hard negatives per anchor
def mine(self, embeddings, labels):
"""
embeddings: [B, D] normalized
labels: [B] binary labels
Returns: indices of hard negatives for each anchor
"""
sim_matrix = torch.matmul(embeddings, embeddings.T) # Cosine similarity
hard_negatives = []
for i in range(len(labels)):
# Find samples with different labels
neg_mask = labels != labels[i]
            # Among negatives, find the most similar (hardest) samples to the anchor;
            # mask same-label samples (including self) with -inf so negative similarities still rank correctly
            neg_sim = sim_matrix[i].masked_fill(~neg_mask, float('-inf'))
            _, top_neg_indices = torch.topk(neg_sim, min(self.k, int(neg_mask.sum().item())))
hard_negatives.append(top_neg_indices)
return torch.stack(hard_negatives)
if __name__ == "__main__":
# Test MAE
mae = MaskedAutoencoderViT(img_size=224, embed_dim=384, depth=6, mask_ratio=0.75)
img = torch.randn(4, 1, 224, 224)
loss, recon, mask = mae(img)
print(f"[MAE] Loss: {loss.item():.4f}, Recon shape: {recon.shape}")
# Test Supervised Contrastive
supcon = SupervisedContrastiveLoss(temperature=0.1)
features = F.normalize(torch.randn(8, 128), dim=1)
labels = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1])
scl_loss = supcon(features, labels)
print(f"[SupCon] Loss: {scl_loss.item():.4f}")
# Test Hard Negative Mining
miner = HardNegativeMiner(k=3)
hard_negs = miner.mine(features, labels)
print(f"[Hard Negative] Indices shape: {hard_negs.shape}")
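The argsort-of-argsort bookkeeping in `random_masking` above is compact but easy to misread; the single-image trace below checks that `ids_restore` really inverts the shuffle and that kept patches end up unmasked (a toy sketch, independent of the class):

```python
import torch

torch.manual_seed(0)
N, mask_ratio = 8, 0.75
len_keep = int(N * (1 - mask_ratio))      # keep 2 of 8 patches

noise = torch.rand(N)
ids_shuffle = torch.argsort(noise)        # a random permutation of patch indices
ids_restore = torch.argsort(ids_shuffle)  # its inverse permutation

# The first len_keep shuffled positions are kept; the rest are masked
mask = torch.ones(N)
mask[:len_keep] = 0
mask = mask[ids_restore]                  # map the mask back to original patch order

kept = ids_shuffle[:len_keep]
# argsort of argsort inverts the shuffle exactly
assert torch.equal(ids_shuffle[ids_restore], torch.arange(N))
# every kept patch is unmasked in the restored ordering
assert torch.equal(mask[kept], torch.zeros(len_keep))
print("kept:", kept.tolist(), "mask:", mask.int().tolist())
```

This is why the decoder can simply concatenate mask tokens after the visible tokens and gather with `ids_restore`: the gather puts every token back at its original patch position.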
Script 6: Explainable AI and LVLM Foundation Architecture
Technical coverage: Section 1.4 (Grad-CAM visualization, SHAP analysis, visual question answering architecture foundations)
Python
"""
Script 6: Explainable AI (XAI) and Large Vision-Language Model Foundations
Content: Gradient-weighted Class Activation Mapping (Grad-CAM),
Feature attribution for ensemble models, VQA architecture for mammography
Usage: Model interpretation and clinical validation of AI decisions
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GradCAM:
"""
Gradient-weighted Class Activation Mapping for mammography lesion localization
Validates anatomical consistency of model attention
"""
def __init__(self, model, target_layer):
self.model = model
self.target_layer = target_layer
self.gradients = None
self.activations = None
        # Register hooks (register_backward_hook is deprecated; use the full variant)
        target_layer.register_forward_hook(self.save_activation)
        target_layer.register_full_backward_hook(self.save_gradient)
def save_activation(self, module, input, output):
self.activations = output.detach()
def save_gradient(self, module, grad_input, grad_output):
self.gradients = grad_output[0].detach()
def generate(self, input_img, target_class=None):
"""
Generate CAM heatmap
Args:
input_img: [B, C, H, W]
target_class: int or None (use predicted class)
Returns:
heatmap: [B, H, W] normalized to 0-1
"""
self.model.eval()
output = self.model(input_img)
if isinstance(output, tuple):
logits = output[0]
else:
logits = output
        if target_class is None:
            target_class = logits.argmax(dim=1)
        elif isinstance(target_class, int):
            target_class = torch.full((logits.size(0),), target_class, dtype=torch.long, device=logits.device)
self.model.zero_grad()
# One-hot encoding for target class
one_hot = torch.zeros_like(logits)
one_hot.scatter_(1, target_class.unsqueeze(1), 1)
# Backward pass
logits.backward(gradient=one_hot, retain_graph=True)
# Compute weights: global average pooling of gradients
weights = self.gradients.mean(dim=(2, 3), keepdim=True) # [B, C, 1, 1]
# Weighted combination of activation maps
cam = (weights * self.activations).sum(dim=1) # [B, H', W']
cam = F.relu(cam) # Only positive contributions
# Upsample to input size
cam = F.interpolate(cam.unsqueeze(1), size=input_img.shape[2:],
mode='bilinear', align_corners=False)
cam = cam.squeeze(1)
# Normalize per image
cam = cam - cam.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0]
cam = cam / (cam.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0] + 1e-8)
return cam
class SHAPFeatureAttribution:
"""
SHAP-inspired feature attribution for ensemble mammography models
Estimates contribution of each view (CC/MLO) and feature channel to final prediction
"""
def __init__(self, model, background_samples=10):
self.model = model
self.background = None
self.n_background = background_samples
def fit_background(self, dataloader):
"""Capture background distribution from reference data"""
samples = []
for batch in dataloader:
samples.append(batch[0])
if len(samples) >= self.n_background:
break
self.background = torch.cat(samples, dim=0)[:self.n_background]
def explain(self, input_img, n_perturbations=50):
"""
Estimate feature importance via input perturbation (simplified SHAP)
Returns attribution map same size as input
"""
if self.background is None:
raise ValueError("Call fit_background first")
B, C, H, W = input_img.shape
attributions = torch.zeros_like(input_img)
# Baseline prediction
with torch.no_grad():
baseline_out = self.model(input_img)
if isinstance(baseline_out, tuple):
baseline_prob = F.softmax(baseline_out[0], dim=1)[:, 1] # Malignant prob
else:
baseline_prob = F.softmax(baseline_out, dim=1)[:, 1]
# Perturbation-based attribution
for _ in range(n_perturbations):
# Random mask
mask = torch.rand_like(input_img) > 0.5
masked_input = input_img * mask + self.background[:B] * (~mask)
with torch.no_grad():
out = self.model(masked_input)
if isinstance(out, tuple):
prob = F.softmax(out[0], dim=1)[:, 1]
else:
prob = F.softmax(out, dim=1)[:, 1]
# Attribution proportional to prediction change
delta = (baseline_prob - prob).abs()
attributions += delta.view(-1, 1, 1, 1) * mask.float()
return attributions / n_perturbations
class SimpleMammographyVQA(nn.Module):
"""
Simplified Vision-Language model for Mammography VQA (MammoVQA baseline)
Combines visual encoder with language decoder for diagnostic reasoning
"""
def __init__(self, img_size=224, embed_dim=512, vocab_size=1000, max_len=128):
super().__init__()
# Vision encoder (lightweight ResNet)
from torchvision.models import resnet34
        resnet = resnet34(weights=None)  # pretrained=False is deprecated in recent torchvision
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.visual_encoder = nn.Sequential(*list(resnet.children())[:-1])
self.visual_proj = nn.Linear(512, embed_dim)
# Text embedding
self.word_embed = nn.Embedding(vocab_size, embed_dim)
self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim))
# Multimodal fusion transformer
decoder_layer = nn.TransformerDecoderLayer(embed_dim, nhead=8, dim_feedforward=2048)
self.fusion_transformer = nn.TransformerDecoder(decoder_layer, num_layers=6)
# Output head
self.answer_head = nn.Linear(embed_dim, vocab_size)
def forward(self, image, question_tokens, answer_tokens=None):
"""
image: [B, 1, H, W]
question_tokens: [B, L_q] tokenized question
answer_tokens: [B, L_a] (optional, for training)
"""
# Visual features
visual_feat = self.visual_encoder(image) # [B, 512, 1, 1]
visual_feat = visual_feat.flatten(1) # [B, 512]
visual_embed = self.visual_proj(visual_feat).unsqueeze(0) # [1, B, D]
# Text features
q_embed = self.word_embed(question_tokens) + self.pos_embed[:, :question_tokens.size(1), :]
q_embed = q_embed.permute(1, 0, 2) # [L_q, B, D]
        # Multimodal fusion: answer tokens attend over a memory of visual and question embeddings
        memory = torch.cat([visual_embed, q_embed], dim=0)  # [1 + L_q, B, D]
        if answer_tokens is not None:
            # Teacher forcing during training
            a_embed = self.word_embed(answer_tokens) + self.pos_embed[:, :answer_tokens.size(1), :]
            a_embed = a_embed.permute(1, 0, 2)
            tgt_mask = self.generate_square_subsequent_mask(answer_tokens.size(1)).to(image.device)
            output = self.fusion_transformer(a_embed, memory, tgt_mask=tgt_mask)
            logits = self.answer_head(output)  # [L_a, B, vocab_size]
            return logits.permute(1, 2, 0)  # [B, vocab_size, L_a]
        else:
            # Inference: auto-regressive generation (simplified); return the fused memory
            return memory
def generate_square_subsequent_mask(self, sz):
mask = torch.triu(torch.ones(sz, sz), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
class CaseBasedReasoningRetriever:
"""
Case-based reasoning: retrieve similar cases to explain predictions
Enhances radiologist trust by providing analogical evidence
"""
def __init__(self, feature_extractor, case_database):
"""
feature_extractor: model to extract feature vectors
        case_database: list of (image, label, feature) tuples with precomputed, normalized feature vectors
"""
self.feature_extractor = feature_extractor
self.database = case_database
def retrieve_similar_cases(self, query_img, k=3):
"""
Retrieve k most similar cases from database based on cosine similarity
Returns: list of (case_img, case_label, similarity_score)
"""
with torch.no_grad():
query_feat = self.feature_extractor(query_img)
if isinstance(query_feat, tuple):
query_feat = query_feat[0]
query_feat = F.normalize(query_feat, dim=1)
similarities = []
for case_img, case_label, case_feat in self.database:
sim = torch.matmul(query_feat, case_feat.T).squeeze()
similarities.append((case_img, case_label, sim.item()))
# Sort by similarity
similarities.sort(key=lambda x: x[2], reverse=True)
return similarities[:k]
if __name__ == "__main__":
# Test Grad-CAM
from torchvision.models import resnet34
    model = resnet34(weights=None)  # pretrained=False is deprecated in recent torchvision
model.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(512, 2)
# Create dummy model wrapper for hooks
class ModelWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x):
return self.model(x)
wrapped = ModelWrapper(model)
gradcam = GradCAM(wrapped, model.layer4[-1])
img = torch.randn(1, 1, 224, 224, requires_grad=True)
heatmap = gradcam.generate(img)
print(f"[Grad-CAM] Heatmap shape: {heatmap.shape}, Range: [{heatmap.min():.2f}, {heatmap.max():.2f}]")
# Test VQA model
vqa = SimpleMammographyVQA(img_size=224, embed_dim=256, vocab_size=500)
img = torch.randn(2, 1, 224, 224)
question = torch.randint(0, 500, (2, 20))
answer = torch.randint(0, 500, (2, 10))
logits = vqa(img, question, answer)
print(f"[VQA] Logits shape: {logits.shape}")
The source code above implements all the core algorithm architectures described in Chapter 1, covering the evolution from traditional 2D CNNs to 3D Transformers as well as the foundational components for self-supervised pre-training and explainable AI. Each module can be run and tested on its own, and the modules can be combined and extended for specific clinical scenarios.
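As a sketch of the "downstream task transfer" usage mentioned in Script 5's header, a linear probe on top of a frozen pre-trained encoder might look like the following. The tiny convolutional encoder is a stand-in for the actual MAE encoder, and all names and shapes here are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for a pre-trained encoder (hypothetical; in practice load MAE weights)
encoder = nn.Sequential(
    nn.Conv2d(1, 16, 7, stride=4, padding=3), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)
for p in encoder.parameters():
    p.requires_grad_(False)  # linear probing: only the classification head is trained

head = nn.Linear(16, 2)      # benign vs malignant
opt = torch.optim.Adam(head.parameters(), lr=1e-3)

images = torch.randn(16, 1, 64, 64)   # toy batch standing in for mammograms
labels = torch.randint(0, 2, (16,))

for step in range(5):
    with torch.no_grad():
        feats = encoder(images)        # frozen features
    loss = F.cross_entropy(head(feats), labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
print(f"probe loss after 5 steps: {loss.item():.4f}")
```

Linear-probe accuracy is a common proxy for representation quality after self-supervised pre-training; full fine-tuning would instead unfreeze the encoder with a smaller learning rate.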