代码来自:https://github.com/ChuHan89/WSSS-Tissue?tab=readme-ov-file
借助了一些人工智能
代码功能总览
该代码是弱监督语义分割(WSSS)流程的 Stage1 训练与测试脚本 ,核心任务是通过 多标签分类模型 生成图像级标签,为后续生成伪掩码(Pseudo-Masks)提供基础。代码分为 train_phase
和 test_phase
两个阶段,支持 渐进式Dropout注意力(PDA) 和 Visdom可视化监控。
1. 依赖库导入
import os
import numpy as np
import argparse
import importlib
from visdom import Visdom # 可视化工具
import torch
import torch.nn.functional as F
from torch.backends import cudnn # CUDA加速
from torch.utils.data import DataLoader
from torchvision import transforms # 数据预处理
from tool import pyutils, torchutils # 自定义工具包
from tool.GenDataset import Stage1_TrainDataset # 自定义数据集类
from tool.infer_fun import infer # 测试阶段推理函数
cudnn.enabled = True # 启用CUDA加速(自动优化卷积算法)
-
关键细节:
-
cudnn.enabled=True
:启用cuDNN加速,自动选择最优卷积实现。 -
pyutils
和torchutils
:项目自定义工具模块(包含优化器、计时器等)。 -
Visdom
:用于实时可视化训练过程中的损失和准确率曲线。
-
2. 辅助函数 compute_acc
def compute_acc(pred_labels, gt_labels):
pred_correct_count = 0
for pred_label in pred_labels: # 遍历预测标签
if pred_label in gt_labels: # 判断是否在真实标签中
pred_correct_count += 1
union = len(gt_labels) + len(pred_labels) - pred_correct_count # 并集大小
acc = round(pred_correct_count/union, 4) # 交并比(IoU)式准确率
return acc
-
功能 :计算预测标签与真实标签的 交并比准确率(IoU-like Accuracy)。
-
数学公式:
Acc=预测正确的标签数预测标签数+真实标签数−预测正确的标签数Acc=预测标签数+真实标签数−预测正确的标签数预测正确的标签数
-
示例:
-
预测标签:
[0, 2]
,真实标签:[2, 3]
-
正确数:1(标签2),并集:2 + 2 - 1 = 3 → Acc = 1/3 ≈ 0.333
-
3. 训练阶段 train_phase
3.1 初始化与模型加载
def train_phase(args):
viz = Visdom(env=args.env_name) # 创建Visdom环境(用于可视化)
model = getattr(importlib.import_module(args.network), 'Net')(args.init_gama, n_class=args.n_class)
print(vars(args)) # 打印所有输入参数
-
关键细节:
-
动态模型加载 :通过
importlib
从字符串args.network
(如"network.resnet38_cls"
)动态加载模型类Net
。 -
PDA参数 :
args.init_gama
控制渐进式Dropout注意力的初始强度(值越大,注意力区域越集中)。 -
Visdom环境 :通过
env=args.env_name
隔离不同实验的可视化结果。
-
3.2 数据增强与加载
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转
transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻转
transforms.ToTensor() # 转为Tensor(范围[0,1])
])
train_dataset = Stage1_TrainDataset(
data_path=args.trainroot, # 训练集路径(如'datasets/BCSS-WSSS/train/')
transform=transform_train,
dataset=args.dataset # 数据集标识(如'bcss')
)
train_data_loader = DataLoader(
train_dataset,
batch_size=args.batch_size, # 批大小(默认20)
shuffle=True, # 打乱数据顺序
num_workers=args.num_workers, # 数据加载子进程数(默认10)
pin_memory=False, # 不锁页内存(适用于小批量数据)
drop_last=True # 丢弃最后不足一个batch的数据
)
-
关键细节:
-
数据增强策略:仅使用翻转操作,避免复杂变换干扰分类模型的学习。
-
自定义数据集类 :
Stage1_TrainDataset
需实现图像和标签的加载逻辑(如解析XML或CSV文件)。
-
3.3 优化器配置
max_step = (len(train_dataset) // args.batch_size) * args.max_epoches # 总迭代次数
param_groups = model.get_parameter_groups() # 获取模型参数分组(通常按网络层分组)
optimizer = torchutils.PolyOptimizer(
[
{'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec}, # 主干网络(低学习率)
{'params': param_groups[1], 'lr': 2*args.lr, 'weight_decay': 0}, # 中间层(较高学习率)
{'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec}, # 分类头(高学习率)
{'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0} # 特殊模块(最高学习率)
],
lr=args.lr,
weight_decay=args.wt_dec,
max_step=max_step # 控制学习率衰减
)
-
关键细节:
-
参数分组:不同网络层(如ResNet38的卷积层、全连接层)使用不同的学习率,分类头通常需要更高学习率以快速适应新任务。
-
Poly学习率衰减 :学习率按公式 lr=base_lr×(1−stepmax_step)powerlr=base_lr×(1−max_stepstep)power 衰减,默认
power=0.9
。
-
3.4 加载预训练权重
if args.weights[-7:] == '.params': # MXNet格式权重(如'init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params')
import network.resnet38d
weights_dict = network.resnet38d.convert_mxnet_to_torch(args.weights) # 转换权重格式
model.load_state_dict(weights_dict, strict=False) # 非严格加载(允许部分参数不匹配)
elif args.weights[-4:] == '.pth': # PyTorch格式权重
weights_dict = torch.load(args.weights)
model.load_state_dict(weights_dict, strict=False)
else:
print('random init') # 随机初始化(无预训练)
-
关键细节:
-
MXNet转换:项目可能基于早期MXNet实现,需将预训练权重转换为PyTorch格式。
-
strict=False
:允许模型结构与权重文件部分不匹配(如分类头维度不同)。
-
3.5 训练循环
model = model.cuda() # 将模型移至GPU
avg_meter = pyutils.AverageMeter('loss', 'avg_ep_EM', 'avg_ep_acc') # 统计训练指标
timer = pyutils.Timer("Session started: ") # 计时器(计算剩余时间)
for ep in range(args.max_epoches): # 遍历每个epoch
model.train()
args.ep_index = ep # 当前epoch索引(可能用于回调)
ep_count = 0 # 当前epoch累计样本数
ep_EM = 0 # 完全匹配(Exact Match)次数
ep_acc = 0 # 累计准确率
for iter, (filename, data, label) in enumerate(train_data_loader): # 遍历每个batch
img = data # 图像数据(未使用filename)
label = label.cuda(non_blocking=True) # 标签移至GPU(异步传输)
# 控制PDA的启用(前3个epoch禁用)
enable_PDA = 1 if ep > 2 else 0
# 前向传播(返回分类输出、特征图、概率)
x, feature, y = model(img.cuda(), enable_PDA)
# 转换为CPU numpy数组以计算指标
prob = y.cpu().data.numpy() # 预测概率(shape=[batch_size, n_class])
gt = label.cpu().data.numpy() # 真实标签(shape=[batch_size, n_class])
# 遍历batch内每个样本计算指标
for num, one in enumerate(prob):
ep_count += 1
pass_cls = np.where(one > 0.5)[0] # 预测标签(概率>0.5的类别)
true_cls = np.where(gt[num] == 1)[0] # 真实标签(one-hot编码中为1的类别)
# 统计Exact Match(完全匹配)
if np.array_equal(pass_cls, true_cls):
ep_EM += 1
# 计算交并比式准确率
acc = compute_acc(pass_cls, true_cls)
ep_acc += acc
# 计算当前batch的平均指标
avg_ep_EM = round(ep_EM / ep_count, 4)
avg_ep_acc = round(ep_acc / ep_count, 4)
# 计算多标签分类损失
loss = F.multilabel_soft_margin_loss(x, label) # x为模型原始输出(未经过sigmoid)
# 更新统计指标
avg_meter.add({
'loss': loss.item(),
'avg_ep_EM': avg_ep_EM,
'avg_ep_acc': avg_ep_acc
})
# 反向传播与优化
optimizer.zero_grad() # 清空梯度
loss.backward() # 计算梯度
optimizer.step() # 更新参数
torch.cuda.empty_cache() # 清理GPU缓存(防止内存泄漏)
# 每100步打印日志并更新Visdom
if (optimizer.global_step) % 100 == 0 and (optimizer.global_step) != 0:
timer.update_progress(optimizer.global_step / max_step) # 更新剩余时间估计
print(
'Epoch:%2d' % (ep),
'Iter:%5d/%5d' % (optimizer.global_step, max_step),
'Loss:%.4f' % (avg_meter.get('loss')),
'avg_ep_EM:%.4f' % (avg_meter.get('avg_ep_EM')),
'avg_ep_acc:%.4f' % (avg_meter.get('avg_ep_acc')),
'lr: %.4f' % (optimizer.param_groups[0]['lr']),
'Fin:%s' % (timer.str_est_finish()),
flush=True
)
# 更新Visdom图表
viz.line(
[avg_meter.pop('loss')],
[optimizer.global_step],
win='loss',
update='append',
opts=dict(title='loss')
)
# 同理更新 'Acc_exact' 和 'Acc' 图表...
# 每epoch后调整PDA的gama参数
if model.gama > 0.65:
model.gama = model.gama * 0.98 # 逐步衰减注意力强度
print('Gama of progressive dropout attention is: ', model.gama)
# 保存最终模型
torch.save(
model.state_dict(),
os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth')
)
-
关键细节:
-
渐进式Dropout注意力(PDA):
-
前3个epoch禁用(
enable_PDA=0
),让模型初步学习基础特征。 -
gama
初始值为1,逐渐衰减(gama *= 0.98
),控制注意力区域的聚焦程度。
-
-
损失函数 :
F.multilabel_soft_margin_loss
结合Sigmoid和交叉熵,适用于多标签分类。 -
指标计算:
-
Exact Match (EM):预测标签与真实标签完全一致的样本比例(严格指标)。
-
IoU式准确率:反映预测与真实标签的重合程度(宽松指标)。
-
-
Visdom集成:实时可视化损失和准确率曲线,便于监控训练状态。
-
4. 测试阶段 test_phase
def test_phase(args):
# 加载生成CAM的模型变体(Net_CAM)
model = getattr(importlib.import_module(args.network), 'Net_CAM')(n_class=args.n_class)
model = model.cuda()
# 加载训练阶段保存的权重
args.weights = os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth')
weights_dict = torch.load(args.weights)
model.load_state_dict(weights_dict, strict=False)
model.eval() # 设置为评估模式(禁用Dropout和BatchNorm的随机性)
# 调用自定义推理函数(评估模型在测试集上的性能)
score = infer(model, args.testroot, args.n_class)
print(score) # 输出评估结果(如mAP、IoU等)
# 可选:保存最终模型(可能包含CAM生成能力)
torch.save(model.state_dict(), ...)
-
关键细节:
-
模型变体 :
Net_CAM
可能修改了网络结构以输出类别激活图(Class Activation Map)。 -
评估指标 :
infer
函数内部可能计算mAP(平均精度)、像素级IoU等指标。 -
严格模式 :
strict=False
允许加载部分权重(如分类头维度不同)。
-
5. 主函数与参数解析
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# 训练参数
parser.add_argument("--batch_size", default=20, type=int)
parser.add_argument("--max_epoches", default=20, type=int)
parser.add_argument("--network", default="network.resnet38_cls", type=str)
parser.add_argument("--lr", default=0.01, type=float)
parser.add_argument("--num_workers", default=10, type=int)
parser.add_argument("--wt_dec", default=5e-4, type=float) # 权重衰减(L2正则化)
# 实验命名与可视化
parser.add_argument("--session_name", default="Stage 1", type=str) # 实验名称(日志标识)
parser.add_argument("--env_name", default="PDA", type=str) # Visdom环境名
parser.add_argument("--model_name", default='PDA', type=str) # 模型保存名称
# 数据集与模型结构
parser.add_argument("--n_class", default=4, type=int) # 类别数(如BCSS为4类)
parser.add_argument("--weights", default='init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params', type=str)
parser.add_argument("--trainroot", default='datasets/BCSS-WSSS/train/', type=str)
parser.add_argument("--testroot", default='datasets/BCSS-WSSS/test/', type=str)
parser.add_argument("--save_folder", default='checkpoints/', type=str)
# PDA参数
parser.add_argument("--init_gama", default=1, type=float) # 初始注意力强度
# 数据集标识
parser.add_argument("--dataset", default='bcss', type=str) # 数据集缩写(影响保存文件名)
args = parser.parse_args()
train_phase(args) # 执行训练
test_phase(args) # 执行测试
-
关键参数说明:
-
--network
:模型定义文件路径(如network.resnet38_cls
对应network/resnet38_cls.py
)。 -
--init_gama
:PDA的初始强度,影响注意力机制的随机丢弃率。 -
--weights
:预训练权重路径(支持MXNet和PyTorch格式)。
-