为什么学这个?
最近在做医疗影像的超分辨率(Super-Resolution, SR)任务。我们都知道,深度学习模型如果直接在一个垂直领域的小数据集上从头训练,往往很难收敛,效果也不尽如人意。标准的做法是:先在一个大规模通用数据集(如 DIV2K)上预训练,让模型学会提取直线、角点、色彩过渡等"基础高频特征",然后再迁移到我们特定的医疗数据集(我使用的是最新开源的原生 4K 内窥镜数据集 SurgiSR4K)上进行微调(Fine-tuning) 。
微调听起来很简单,无非就是加载一下 .pth 权重接着跑。但在实际落地中,为了保证实验的绝对可复现性 以及无损的特征迁移,我踩了不少坑。这篇文章就来复盘一下我是如何从零构建一个稳健的微调流水线的。
核心内容与步骤
我的整体思路是将"预训练"和"微调"的逻辑彻底解耦,单独编写了一个 finetune.py 脚本。
1. 严格锁定随机种子(保证可复现)
做算法实验,不可复现是大忌。在代码的开头,我写了一个死锁所有随机性的函数,确保只要传入相同的 seed,每次微调的 loss 曲线必须完全一致:
Python
ini
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
cudnn.deterministic = True
cudnn.benchmark = False
2. 动态数据集加载与预处理对齐
预训练时,为了追求极致的 I/O 速度,我提前将通用数据集切分成了静态的 .h5 文件。但在微调阶段,为了更好的数据扩增效果,我改为直接读取重新组织好(划分了 train/val/test)的图片文件夹,并在 DataLoader 中实时进行联合随机裁剪(Joint Random Crop) 。
关键点:通道严格对齐!
在仔细盘点预训练阶段的 prepare.py 脚本时,我惊出了一身冷汗,发现了一个极易致命的细节------预训练模型完全是仅在 Y 通道(亮度通道)上进行训练的!
很多初次接触超分算法(如 ESPCN、SRCNN)的开发者可能会疑惑:为什么放着好好的 RGB 三通道不用,非要大费周章地只提 Y 通道来训练?
这其实是基于**人类视觉系统(HVS)**的生理特性做出的极致优化:
-
人眼对"结构"比对"色彩"更敏感:人眼对图像的亮度(Y通道,几乎包含了所有的边缘、轮廓、纹理等高频结构信息)极其敏感,而对色度/饱和度(Cb、Cr通道)的微小模糊则相对迟钝。
-
算力与效果的完美平衡 :如果同时对 RGB 三通道做超分辨率重建,计算量和显存开销会直接翻三倍。因此,业内极其经典且高效的做法是:只把最难的高频亮度特征(Y通道)交给神经网络去精雕细琢,而对于 Cb、Cr 通道,只需使用极低成本的双三次插值(Bicubic)直接放大即可。 最后将它们合并转回 RGB,肉眼几乎看不出色彩瑕疵,但速度却快得多。
明白了这一层底层算法逻辑,再回到我们的微调代码上:预训练网络的第一层卷积是为 1通道 输入量身定制的,如果微调时我没注意,直接把原图的 RGB 三通道 喂进去,由于张量维度不匹配,模型会瞬间抛出 Shape Mismatch 错误并当场崩溃。
为了规避这个大坑,我在新的 Dataset 类中严谨地重写了转换逻辑。我强制将每张图片先转为 YCbCr 色彩空间,并剥离出 Y 通道,确保喂给网络的数据与预训练时的特征空间达到绝对的、像素级的对齐:
Python
python
def rgb_to_y(self, img):
"""提取 Y 通道,严格对齐预训练模型的输入特征空间"""
ycbcr = img.convert('YCbCr')
y, cb, cr = ycbcr.split()
return y
3. 制定微调超参数策略
基于预训练的参数(Scale=2, LR=1e-5, BatchSize=16, Epochs=200),我为微调阶段制定了如下策略:
- 放大倍率 (Scale) = 2:绝对不能变!网络末端的上采样层权重尺寸是和 Scale 绑定的,改了直接报 Shape Mismatch。
- 学习率 (LR) = 1e-6:将预训练的最终学习率砍半(甚至可以缩小到 1/10)。只做微小调整,适应医疗图像中组织黏膜、血管的纹理。
- 训练轮数 (Epochs) = 80 :微调起点极高,通常 50-100 轮即可在特定数据集上收敛,跑多了反而会在小数据集上过拟合。配合 Early Stopping 逻辑保存
best_finetuned.pth。
遇到的坑点与排雷指南
在调试过程中,我总结了微调阶段最容易踩的 4 个大坑:
坑点一:灾难性遗忘(Catastrophic Forgetting)
- 现象:加载权重后,头几个 Epoch 的 Loss 突然爆炸,PSNR 断崖式下跌。
- 原因:学习率设置过大,巨大的梯度更新瞬间摧毁了模型在预训练时好不容易学到的底层通用特征。
- 解法:微调学习率必须远小于预训练学习率。
坑点二:无脑加载 Optimizer 状态
- 现象:Loss 降不下去,收敛方向极其诡异。
- 原因:像 Adam 这样的优化器内部会保存历史动量(Momentum)和方差信息。如果换了全新的数据集还加载旧的优化器状态,历史动量会把模型往错误的方向"带偏"。
- 解法 :迁移数据集微调时,只加载
model.state_dict(),不要加载optimizer.state_dict(),让优化器以新的小学习率重新初始化。
坑点三:Patch Size(裁剪大小)缩水导致感受野丢失
- 疑惑:微调时,输入图片的裁剪大小可以随便改吗?
- 正解 :强烈建议保持一致或适度调大,绝不能变小。 模型在预训练时已经习惯了在固定大小的窗口(如 32x32)内寻找纹理关联。如果微调时切成了 16x16,视野变小,全局特征聚合能力就会失效。
坑点四:不敢调整 Batch Size
- 疑惑:换了 4K 数据集后显存吃紧,微调可以减小 Batch Size 吗?
- 正解 :完全可以。 传统的分类网络(如 ResNet)因为有 Batch Normalization (BN) 层,强行缩小 Batch 会导致统计量崩塌。但 SR 网络(如我用的 ESPCN)通常没有 BN 层,Batch Size 缩小到 8 甚至 4 并不影响内部特征分布,只需把学习率同步调小一点防止梯度震荡即可。
收获与总结
这次实战让我深刻体会到,深度学习的微调绝不是简单的 torch.load。它要求我们不仅要洞悉底层网络结构(如 BN 层的有无、感受野的大小),还要对数据流水线(如预处理通道、归一化方式)有近乎苛刻的像素级把控。
SurgiSR4K 这样的高质量医疗 4K 数据集非常难得,希望这套"无损微调"的方法论也能帮到正在做医疗影像算法的同行们。
以下是为你准备的文章附录部分。你可以直接将这部分追加到刚才那篇博客的末尾。
我已经将我们在讨论中得出的最佳实践(如 scale=2、学习率 5e-6、提取 Y 通道等)全部更新到了这份最终版本的代码中。
附录:finetune.py 完整源码
这里放出我用于微调的完整脚本,代码中包含了严格的随机种子控制、Y 通道提取、联合数据增强以及最佳模型保存逻辑。大家可以基于此代码直接跑在自己的 SurgiSR4K 数据集上。
Python
python
"""
SurgiSR4K 医疗内窥镜 4K 数据集微调脚本 (Fine-tuning Script)
Author: 你的名字/昵称
Description: 专为保留预训练高频特征、防止灾难性遗忘设计的超分辨率微调流水线。
"""
import argparse
import os
import random
import re # <--- 新增正则模块,用于解析文件名
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
# 导入你原有的模型和工具函数
from models import ESPCN_RDB
from utils import AverageMeter, calc_psnr
# ==========================================
# 1. 保证可复现性的种子设置
# ==========================================
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
cudnn.deterministic = True
cudnn.benchmark = False
# ==========================================
# 2. 适配 organized 目录和 Y 通道的 Dataset
# ==========================================
class SurgiSR4KDataset(Dataset):
def __init__(self, data_root, split='train', lr_res="1920x1080p", hr_res="3840x2160p",
scale_factor=2, lr_patch_size=32, is_train=True):
self.data_root = Path(data_root)
self.split = split
self.lr_res = lr_res
self.hr_res = hr_res
self.scale_factor = scale_factor
self.lr_patch_size = lr_patch_size
self.is_train = is_train
self.lr_dir = self.data_root / self.split / self.lr_res
self.hr_dir = self.data_root / self.split / self.hr_res
if not self.lr_dir.exists() or not self.hr_dir.exists():
raise FileNotFoundError(f"找不到数据目录: {self.lr_dir} 或 {self.hr_dir}")
self.lr_image_paths = sorted(list(self.lr_dir.rglob("*.png")))
if len(self.lr_image_paths) == 0:
raise ValueError(f"在 {self.lr_dir} 中未找到任何图像!")
def __len__(self):
return len(self.lr_image_paths)
def rgb_to_y(self, img):
ycbcr = img.convert('YCbCr')
y, cb, cr = ycbcr.split()
return y
def __getitem__(self, idx):
lr_path = self.lr_image_paths[idx]
rel_path = lr_path.relative_to(self.lr_dir)
hr_rel_path_str = str(rel_path).replace(self.lr_res, self.hr_res)
hr_path = self.hr_dir / hr_rel_path_str
lr_img = Image.open(lr_path).convert("RGB")
hr_img = Image.open(hr_path).convert("RGB")
if self.is_train:
lr_w, lr_h = lr_img.size
lr_x = random.randint(0, lr_w - self.lr_patch_size)
lr_y = random.randint(0, lr_h - self.lr_patch_size)
hr_x = lr_x * self.scale_factor
hr_y = lr_y * self.scale_factor
hr_patch_size = self.lr_patch_size * self.scale_factor
lr_img = lr_img.crop((lr_x, lr_y, lr_x + self.lr_patch_size, lr_y + self.lr_patch_size))
hr_img = hr_img.crop((hr_x, hr_y, hr_x + hr_patch_size, hr_y + hr_patch_size))
if random.random() < 0.5:
lr_img = TF.hflip(lr_img)
hr_img = TF.hflip(hr_img)
if random.random() < 0.5:
lr_img = TF.vflip(lr_img)
hr_img = TF.vflip(hr_img)
angle = random.choice([0, 90, 180, 270])
if angle != 0:
lr_img = TF.rotate(lr_img, angle)
hr_img = TF.rotate(hr_img, angle)
lr_y = self.rgb_to_y(lr_img)
hr_y = self.rgb_to_y(hr_img)
lr_tensor = TF.to_tensor(lr_y)
hr_tensor = TF.to_tensor(hr_y)
return lr_tensor, hr_tensor
# ==========================================
# 3. 微调主循环
# ==========================================
def main():
parser = argparse.ArgumentParser(description="SurgiSR4K Fine-tuning Script")
parser.add_argument('--data_root', type=str, required=True, help='Organized 数据集的根目录路径')
parser.add_argument('--pretrained_weights', type=str, required=True, help='预训练模型的 .pth 文件路径')
parser.add_argument('--scale', type=int, default=2, help='超分辨率放大倍数 (SurgiSR4K 从 1920 到 3840 是 2 倍)')
parser.add_argument('--lr_patch_size', type=int, default=32, help='输入给网络的 LR 裁剪大小 (HR对应为 32*2=64)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch Size')
parser.add_argument('--num_epochs', type=int, default=100, help='微调轮数')
parser.add_argument('--lr', type=float, default=1e-6, help='微调学习率')
parser.add_argument('--num_workers', type=int, default=4, help='DataLoader 线程数 (Windows 建议设为 0)')
parser.add_argument('--seed', type=int, default=42, help='随机种子')
args = parser.parse_args()
pretrained_abspath = os.path.abspath(args.pretrained_weights)
outputs_dir = os.path.dirname(pretrained_abspath)
pretrained_basename = os.path.basename(args.pretrained_weights)
pretrained_name, ext = os.path.splitext(pretrained_basename)
best_save_filename = f"{pretrained_name}_finetuned{ext}"
latest_save_filename = f"{pretrained_name}_latest{ext}"
# ==========================================
# === 新增模块:利用正则表达式解析模型参数 ===
# ==========================================
parsed_kwargs = {}
# 解析 growth_channels (例如: growth_channels_16)
match_gc = re.search(r'growth_channels_(\d+)', pretrained_name)
if match_gc:
parsed_kwargs['growth_channels'] = int(match_gc.group(1))
# 解析 rdb_layers (例如: RDB_3)
match_rdb = re.search(r'RDB_(\d+)', pretrained_name)
if match_rdb:
parsed_kwargs['rdb_layers'] = int(match_rdb.group(1))
# 解析 attention_type (例如: Attn_pixel, Attn_weakened_pixel, Attn_none)
match_attn = re.search(r'Attn_(pixel|weakened_pixel|none)', pretrained_name)
if match_attn:
parsed_kwargs['attention_type'] = match_attn.group(1)
# 解析 activation (通过枚举常见的激活函数名)
activation_types = ['ReLU', 'LeakyReLU', 'PReLU', 'Tanh', 'Sigmoid', 'GELU']
for act in activation_types:
# 匹配下划线包裹的激活函数,例如 _LeakyReLU_
if f"_{act}_" in pretrained_name or pretrained_name.endswith(f"_{act}"):
parsed_kwargs['activation'] = act
break
print(f"==> 从文件名 [{pretrained_basename}] 中解析到以下超参数:")
for k, v in parsed_kwargs.items():
print(f" - {k}: {v}")
# ==========================================
set_random_seed(args.seed)
os.makedirs(outputs_dir, exist_ok=True)
writer = SummaryWriter(log_dir=os.path.join(outputs_dir, 'logs_finetune'))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# === 初始化模型 (利用字典解包传入解析出的参数) ===
print(f"==> 初始化模型 (Scale: x{args.scale}) ...")
model = ESPCN_RDB(scale_factor=args.scale, **parsed_kwargs).to(device)
print(f"==> 正在加载预训练权重: {args.pretrained_weights}")
checkpoint = torch.load(args.pretrained_weights, map_location=device, weights_only=False)
# 提取 state_dict
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
else:
state_dict = checkpoint
# === 关键步骤:过滤 thop 生成的脏数据 ===
clean_state_dict = {}
for k, v in state_dict.items():
if "total_ops" in k or "total_params" in k:
continue
clean_state_dict[k] = v
# 加载干净的权重
model.load_state_dict(clean_state_dict, strict=True)
print(f"==> 预训练权重加载成功!微调后的模型将保存在: {outputs_dir}")
print("==> 构建数据集...")
train_dataset = SurgiSR4KDataset(
data_root=args.data_root, split='train',
lr_res="1920x1080p", hr_res="3840x2160p",
scale_factor=args.scale, lr_patch_size=args.lr_patch_size, is_train=True
)
val_dataset = SurgiSR4KDataset(
data_root=args.data_root, split='val',
lr_res="1920x1080p", hr_res="3840x2160p",
scale_factor=args.scale, is_train=False
)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
num_workers=args.num_workers, pin_memory=True)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion_mse = nn.MSELoss().to(device)
best_psnr = 0.0
best_epoch = 0
print("==> 开始微调...")
for epoch in range(args.num_epochs):
model.train()
epoch_losses = AverageMeter()
with tqdm(total=len(train_loader.dataset)) as t:
t.set_description(f'Epoch: {epoch}/{args.num_epochs - 1}')
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
preds = model(inputs)
loss = criterion_mse(preds, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_losses.update(loss.item(), len(inputs))
t.set_postfix(loss=f'{epoch_losses.avg:.6f}')
t.update(len(inputs))
model.eval()
eval_psnr = AverageMeter()
with torch.no_grad():
for inputs, labels in val_loader:
inputs, labels = inputs.to(device), labels.to(device)
preds = model(inputs).clamp(0.0, 1.0)
psnr = calc_psnr(preds, labels)
eval_psnr.update(psnr.item(), len(inputs))
print(f'Eval PSNR: {eval_psnr.avg:.2f}dB')
writer.add_scalar('Loss/train', epoch_losses.avg, epoch)
writer.add_scalar('PSNR/val', eval_psnr.avg, epoch)
save_state = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'best_psnr': max(best_psnr, eval_psnr.avg)
}
torch.save(save_state, os.path.join(outputs_dir, latest_save_filename))
if eval_psnr.avg > best_psnr:
best_psnr = eval_psnr.avg
best_epoch = epoch
torch.save(save_state, os.path.join(outputs_dir, best_save_filename))
print(f"!!! 找到更好的模型,已保存 {best_save_filename} (PSNR: {best_psnr:.2f}dB)")
print(f'==> 微调完成!Best Epoch: {best_epoch}, 最佳 PSNR: {best_psnr:.2f}dB')
if __name__ == '__main__':
main()