一、项目背景与设计思路
1. 为什么"端到端 CNN"在医疗中经常失败?
很多教程喜欢这样做:
CT 图像 → CNN → 预测是否患病
但在真实医疗场景中,问题很快会暴露:
-
数据量不够(几百 ~ 几千)
-
批次差异大(不同医院 / 设备)
-
医生需要解释模型结果
-
模型上线后性能漂移严重
👉 这不是 CNN 不强,而是医疗场景不适合"一把梭"。
2. 更成熟的工程方案:CNN + XGBoost
医学影像 → CNN → 高阶影像特征 ↓ XGBoost / RF / LR ↓ 疾病风险预测
这个结构的优势是:
-
CNN 专注于特征表达
-
XGBoost 专注于稳定决策
-
小样本也能工作
-
方便做可解释性
二、项目整体结构设计
medical_prediction/ ├── data/ │ ├── images/ │ ├── clinical.csv │ └── labels.csv ├── cnn/ │ ├── dataset.py │ ├── model.py │ └── train_cnn.py ├── feature/ │ └── extract_features.py ├── ml/ │ ├── train_xgb.py │ └── evaluate.py └── main_pipeline.py
这是一个"真实可维护"的结构,不是 Notebook 玩具
三、Step 1:医学影像数据准备与 Dataset 构建
1️⃣ 自定义 Dataset(PyTorch)
python
# cnn/dataset.py
import torch
from torch.utils.data import Dataset
import numpy as np
class MedicalImageDataset(Dataset):
def __init__(self, images, labels):
self.images = images
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
x = self.images[idx]
y = self.labels[idx]
return torch.tensor(x, dtype=torch.float32), torch.tensor(y)
2️⃣ 医疗影像预处理经验(非常关键)
真实项目中通常需要:
-
归一化(HU 值 / 强度)
-
Resize
-
中心裁剪
-
简单增强(翻转、噪声)
不要一上来就疯狂数据增强,医疗里很容易引入伪特征。
四、Step 2:CNN 模型设计
1️⃣ CNN 设计原则
-
不追求太深
-
不追求 ImageNet 那套
-
目标是"稳定特征"而不是极致精度
2️⃣ CNN 模型代码
python
# cnn/model.py
import torch
import torch.nn as nn
class MedicalCNN(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 16, 3, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(16, 32, 3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.classifier = nn.Linear(32 * 7 * 7, 2)
def forward(self, x, return_feature=False):
x = self.features(x)
x = x.view(x.size(0), -1)
if return_feature:
return x
return self.classifier(x)
五、Step 3:CNN 训练
1️⃣ 训练代码
python
# cnn/train_cnn.py
import torch
import torch.nn as nn
import torch.optim as optim
from cnn.model import MedicalCNN
model = MedicalCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(15):
model.train()
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 2, (64,))
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss={loss.item():.4f}")
👉 工程经验
-
CNN 不必训到极致
-
过拟合反而会让特征"失真"
-
我通常在 loss 稳定后就停
六、Step 4:CNN 特征提取
python
# feature/extract_features.py
import torch
import numpy as np
from cnn.model import MedicalCNN
model = MedicalCNN()
model.eval()
def extract_features(images):
with torch.no_grad():
feats = model(images, return_feature=True)
return feats.cpu().numpy()
python
images = torch.randn(300, 1, 28, 28)
cnn_features = extract_features(images)
print(cnn_features.shape)
七、Step 5:融合临床特征
python
clinical_features = np.random.randn(300, 6)
X = np.concatenate(
[cnn_features, clinical_features],
axis=1
)
y = np.random.randint(0, 2, 300)
👉 影像 + 临床 = 医疗 AI 的基本盘
八、Step 6:XGBoost 训练
python
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
model = XGBClassifier(
n_estimators=400,
max_depth=5,
learning_rate=0.03,
subsample=0.8,
colsample_bytree=0.8,
eval_metric="logloss"
)
model.fit(X_train, y_train)
y_prob = model.predict_proba(X_test)[:, 1]
print("AUC:", roc_auc_score(y_test, y_prob))
九、Step 7:可解释性
1️⃣ SHAP 示例
python
import shap
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
👉 你可以清楚看到:
-
哪些影像特征重要
-
哪些临床指标起决定作用
十、真实医疗项目的 5 条血泪经验
1️⃣ 不要迷信大模型
2️⃣ 稳定性 > 精度
3️⃣ 特征质量 > 网络深度
4️⃣ 医生信任比 AUC 更重要
5️⃣ CNN + XGBoost 是成熟方案,不是退而求其次
十一、总结
CNN 解决"看不懂影像"的问题
XGBoost 解决"怎么做决定"的问题
这不是妥协,而是工程智慧。