深度学习代码解读——自用

代码来自:GitHub - ChuHan89/WSSS-Tissue

借助了一些人工智能

2_generate_PM.py

功能总结

该代码用于 生成弱监督语义分割(WSSS)所需的伪掩码(Pseudo-Masks),是 Stage2 训练的前置步骤。其核心流程为:

  1. 加载 Stage1 训练好的分类模型(支持 CAM 生成)。

  2. 为不同层次的特征图生成伪掩码 (如 b4_5, b5_2, bn7 对应的不同网络层)。

  3. 保存伪掩码图像,使用调色板将类别标签映射为彩色图像。

代码解析

1. 导入依赖库
复制代码
import os
import torch
import argparse
import importlib
from torch.backends import cudnn
cudnn.enabled = True  # 启用CUDA加速
from tool.infer_fun import create_pseudo_mask  # 自定义函数:生成伪掩码
  • 关键依赖

    • cudnn.enabled = True:启用 cuDNN 加速,优化 GPU 计算性能。

    • create_pseudo_mask:核心函数(用户需参考其实现),负责生成并保存伪掩码。

2. 主函数与参数解析
复制代码
if __name__ == '__main__':
    # 定义命令行参数
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", default='checkpoints/stage1_checkpoint_trained_on_bcss.pth', type=str)
    parser.add_argument("--network", default="network.resnet38_cls", type=str)
    parser.add_argument("--dataroot", default="datasets/BCSS-WSSS/", type=str)
    parser.add_argument("--dataset", default="bcss", type=str)
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--n_class", default=4, type=int)
    args = parser.parse_args()
    print(args)  # 打印参数列表
  • 参数说明

    • --weights:Stage1 训练好的模型权重文件路径(默认指向 BCSS 数据集)。

    • --network:网络结构定义文件(如 network.resnet38_cls)。

    • --dataroot:数据集根目录(包含训练/测试数据)。

    • --dataset:数据集标识(bcssluad)。

    • --n_class:类别数量(BCSS 为 4 类,LUAD 可能不同)。

3. 定义调色板(颜色映射)
复制代码
    if args.dataset == 'luad':
        palette = [0]*15  # 初始化长度为15的列表(每类3个RGB通道)
        palette[0:3] = [205,51,51]    # 类别1:红色
        palette[3:6] = [0,255,0]      # 类别2:绿色
        palette[6:9] = [65,105,225]   # 类别3:蓝色
        palette[9:12] = [255,165,0]   # 类别4:橙色
        palette[12:15] = [255, 255, 255]  # 背景或未标注区域:白色
    elif args.dataset == 'bcss':
        palette = [0]*15
        palette[0:3] = [255, 0, 0]    # 类别1:红色
        palette[3:6] = [0,255,0]      # 类别2:绿色
        palette[6:9] = [0,0,255]      # 类别3:蓝色
        palette[9:12] = [153, 0, 255] # 类别4:紫色
        palette[12:15] = [255, 255, 255]  # 背景:白色
  • 作用:将类别标签映射为 RGB 颜色,用于伪掩码的可视化。

  • 细节

    • 每个类别占 3 个连续位置(RGB 通道)。

    • palette[12:15] 可能表示背景或未标注区域。

    • 不同数据集使用不同的颜色方案(如 BCSS 用紫色表示第4类)。

4. 创建伪掩码保存路径
复制代码
    PMpath = os.path.join(args.dataroot, 'train_PM')  # 路径示例:datasets/BCSS-WSSS/train_PM
    if not os.path.exists(PMpath):
        os.mkdir(PMpath)  # 若目录不存在则创建
  • 目的 :在数据集根目录下创建 train_PM 文件夹,用于保存生成的伪掩码。
5. 加载模型
复制代码
    model = getattr(importlib.import_module("network.resnet38_cls"), 'Net_CAM')(n_class=args.n_class)
    model.load_state_dict(torch.load(args.weights), strict=False)
    model.eval()  # 设置为评估模式(禁用Dropout等随机操作)
    model.cuda()  # 将模型移至GPU
  • 关键步骤

    • 动态加载模型 :从 network.resnet38_cls 模块加载 Net_CAM 类(支持 CAM 生成的变体)。

    • 加载权重 :使用 Stage1 训练好的模型参数(strict=False 允许部分参数不匹配)。

    • 评估模式:关闭 BatchNorm 和 Dropout 的随机性,确保结果一致性。

6. 生成多级伪掩码
复制代码
    ##
    fm = 'b4_5'  # 特征模块名称(可能对应网络中的某个中间层)
    savepath = os.path.join(PMpath, 'PM_' + fm)  # 保存路径:train_PM/PM_b4_5
    if not os.path.exists(savepath):
        os.mkdir(savepath)
    create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset)

    ## 重复相同流程生成其他层级的伪掩码
    fm = 'b5_2'
    savepath = os.path.join(PMpath, 'PM_' + fm)
    if not os.path.exists(savepath):
        os.mkdir(savepath)
    create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset)

    ##
    fm = 'bn7'
    savepath = os.path.join(PMpath, 'PM_' + fm)
    if not os.path.exists(savepath):
        os.mkdir(savepath)
    create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset)
  • 功能 :针对不同特征模块(fm)生成伪掩码,保存到对应子目录。

  • 关键参数

    • fm:特征模块标识,可能对应网络中的不同层(如 ResNet 的 block4block5bottleneck)。

    • create_pseudo_mask:核心函数,推测其功能为:

      1. 加载训练集图像。

      2. 使用模型提取指定层的特征图。

      3. 生成类别激活图(CAM)。

      4. 根据阈值将 CAM 转换为二值伪掩码。

      5. 应用调色板将掩码保存为彩色 PNG 图像。

代码执行示例

复制代码
python generate_pseudo_masks.py \
    --dataset bcss \
    --dataroot datasets/BCSS-WSSS/ \
    --weights checkpoints/stage1_checkpoint_trained_on_bcss.pth
  • 输出 :在 datasets/BCSS-WSSS/train_PM/ 下生成三个子目录:

    • PM_b4_5:基于 b4_5 层特征的伪掩码。

    • PM_b5_2:基于 b5_2 层特征的伪掩码。

    • PM_bn7:基于 bn7 层特征的伪掩码。

总结

该代码是弱监督语义分割流程中 生成多级伪掩码的关键步骤,利用 Stage1 训练的分类模型提取不同层级的特征,生成伪标签供 Stage2 的分割模型训练。通过多级伪掩码的融合,可以提升最终分割结果的精度和鲁棒性。

3_train_stage2.py

功能总结

该代码是弱监督语义分割(WSSS)的 Stage2 训练与测试脚本,核心功能为:

  1. 训练分割模型:基于 DeepLab v3+ 架构,使用 Stage1 生成的伪掩码(Pseudo-Masks)进行监督训练。

  2. 验证与测试:评估模型在验证集和测试集上的性能(如 mIoU、像素准确率等)。

  3. 门控机制(Gate Mechanism):在测试阶段结合 Stage1 的分类结果过滤分割预测,提升精度。

  4. 多任务损失:融合不同层次伪掩码的损失(主伪掩码 + 两种增强版本)。

代码结构

复制代码
# 1. 依赖库导入
import argparse, os, numpy as np
from tqdm import tqdm
import torch
from tool.GenDataset import make_data_loader
from network.sync_batchnorm.replicate import patch_replication_callback
from network.deeplab import *
from tool.loss import SegmentationLosses
from tool.lr_scheduler import LR_Scheduler
from tool.saver import Saver
from tool.summaries import TensorboardSummary
from tool.metrics import Evaluator

# 2. 定义训练器类
class Trainer(object):
    def __init__(self, args): ...  # 初始化模型、数据、优化器等
    def training(self, epoch): ...  # 训练一个epoch
    def validation(self, epoch): ...  # 验证集评估
    def test(self, epoch, Is_GM): ...  # 测试集评估(支持门控机制)
    def load_the_best_checkpoint(self): ...  # 加载最佳模型

# 3. 主函数
def main(): ...  # 解析参数、启动训练

if __name__ == "__main__":
    main()

关键代码解析

1. Trainer 类初始化
复制代码
class Trainer(object):
    def __init__(self, args):
        self.args = args
        # 初始化日志记录与模型保存工具
        self.saver = Saver(args)  # 保存模型检查点
        self.summary = TensorboardSummary('logs')  # TensorBoard日志
        self.writer = self.summary.create_summary()
        # 数据加载
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        self.train_loader, self.val_loader, self.test_loader = make_data_loader(args, **kwargs)
        # 模型定义(DeepLab v3+)
        self.nclass = args.n_class
        model = DeepLab(
            num_classes=self.nclass,
            backbone=args.backbone,  # 骨干网络(如ResNet)
            output_stride=args.out_stride,  # 输出步长(控制特征图分辨率)
            sync_bn=args.sync_bn,  # 多GPU同步BatchNorm
            freeze_bn=args.freeze_bn  # 冻结BN层参数
        )
        # 优化器配置(分层学习率)
        train_params = [
            {'params': model.get_1x_lr_params(), 'lr': args.lr},  # 骨干网络低学习率
            {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}  # 分类头高学习率
        ]
        optimizer = torch.optim.SGD(
            train_params, 
            momentum=args.momentum,
            weight_decay=args.weight_decay, 
            nesterov=args.nesterov
        )
        # 损失函数(交叉熵或Focal Loss)
        self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        # 评估工具(计算mIoU等指标)
        self.evaluator = Evaluator(self.nclass)
        # 学习率调度(Poly策略)
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, 
            args.lr, 
            args.epochs, 
            len(self.train_loader)
        )
        # 加载Stage1的分类模型(用于门控机制)
        model_stage1 = getattr(importlib.import_module('network.resnet38_cls'), 'Net_CAM')(n_class=4)
        resume_stage1 = 'checkpoints/stage1_checkpoint_trained_on_'+str(args.dataset)+'.pth'
        weights_dict = torch.load(resume_stage1)
        model_stage1.load_state_dict(weights_dict)
        self.model_stage1 = model_stage1.cuda()
        self.model_stage1.eval()  # 固定Stage1模型参数
        # GPU并行化
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)  # 修复多GPU BatchNorm同步问题
            self.model = self.model.cuda()
        # 加载预训练权重(如DeepLab预训练模型)
        if args.resume is not None:
            checkpoint = torch.load(args.resume)
            # 处理分类头权重(微调时保留,否则删除)
            if args.ft:
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                del checkpoint['state_dict']['decoder.last_conv.8.weight']
                del checkpoint['state_dict']['decoder.last_conv.8.bias']
                self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        # 初始化最佳mIoU
        self.best_pred = 0.0
2. 训练阶段 training
复制代码
    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)  # 进度条
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            # 加载数据(图像 + 三个伪掩码)
            image, target, target_a, target_b = sample['image'], sample['label'], sample['label_a'], sample['label_b']
            if self.args.cuda:
                image, target, target_a, target_b = image.cuda(), target.cuda(), target_a.cuda(), target_b.cuda()
            # 调整学习率
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            # 前向传播
            output = self.model(image)
            # 添加额外通道处理类别4(背景或忽略类)
            one = torch.ones((output.shape[0],1,224,224)).cuda()
            output = torch.cat([output, (100 * one * (target==4).unsqueeze(dim=1)], dim=1)
            # 计算多任务损失(主伪掩码 + 两种增强版本)
            loss_o = self.criterion(output, target)
            loss_a = self.criterion(output, target_a)
            loss_b = self.criterion(output, target_b)
            loss = 0.6*loss_o + 0.2*loss_a + 0.2*loss_b
            # 反向传播
            loss.backward()
            self.optimizer.step()
            # 统计损失
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            # 记录TensorBoard日志
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
        # 输出epoch总结
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)
3. 验证阶段 validation
复制代码
    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample[0]['image'], sample[0]['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            # 转换为CPU numpy数组
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # 处理类别4(设为忽略类)
            pred[target==4] = 4
            # 更新评估指标
            self.evaluator.add_batch(target, pred)
        # 计算并记录指标
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        ious = self.evaluator.Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        # 输出结果
        print('Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        # 保存最佳模型
        if mIoU > self.best_pred:
            self.best_pred = mIoU
            self.saver.save_checkpoint({
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }, 'stage2_checkpoint_trained_on_'+self.args.dataset+'.pth')
4. 测试阶段 test(含门控机制)
复制代码
    def test(self, epoch, Is_GM):
        self.load_the_best_checkpoint()  # 加载最佳模型
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.test_loader, desc='\r')
        for i, sample in enumerate(tbar):
            image, target = sample[0]['image'], sample[0]['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                # 门控机制:利用Stage1的分类结果过滤分割预测
                if Is_GM:
                    _, y_cls = self.model_stage1.forward_cam(image)  # Stage1的分类输出
                    y_cls = y_cls.cpu().data
                    pred_cls = (y_cls > 0.1)  # 类别存在性判断(阈值0.1)
            # 应用门控机制
            pred = output.data.cpu().numpy()
            if Is_GM:
                pred = pred * pred_cls.unsqueeze(dim=2).unsqueeze(dim=3).numpy()
            # 处理类别4
            pred = np.argmax(pred, axis=1)
            pred[target==4] = 4
            self.evaluator.add_batch(target, pred)
        # 计算并输出指标
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        print('Test:')
        print("Acc:{}, Acc_class:{}, mIoU:{}".format(Acc, Acc_class, mIoU))
5. 主函数 main
复制代码
def main():
    # 解析命令行参数
    parser = argparse.ArgumentParser(description="WSSS Stage2")
    # 模型结构参数
    parser.add_argument('--backbone', default='resnet', choices=['resnet', 'xception', 'drn', 'mobilenet'])
    parser.add_argument('--out-stride', type=int, default=16)  # 输出步长(控制特征图下采样率)
    parser.add_argument('--Is_GM', type=bool, default=True)  # 是否启用门控机制
    # 数据集参数
    parser.add_argument('--dataroot', default='datasets/BCSS-WSSS/')
    parser.add_argument('--dataset', default='bcss')
    parser.add_argument('--n_class', type=int, default=4)
    # 训练超参数
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=20)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lr-scheduler', default='poly', choices=['poly', 'step', 'cos'])
    # 其他配置
    parser.add_argument('--gpu-ids', default='0')  # 指定使用的GPU
    parser.add_argument('--resume', default='init_weights/deeplab-resnet.pth.tar')  # 预训练权重
    args = parser.parse_args()
    
    # 配置CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
    # 自动设置SyncBN
    if args.sync_bn is None:
        args.sync_bn = True if args.cuda and len(args.gpu_ids) > 1 else False
    
    # 初始化训练器并启动训练
    trainer = Trainer(args)
    for epoch in range(trainer.args.epochs):
        trainer.training(epoch)
        if epoch % args.eval_interval == 0:
            trainer.validation(epoch)
    # 最终测试
    trainer.test(epoch, args.Is_GM)
    trainer.writer.close()

关键设计解析

  1. 多任务损失

    • 目标 :同时优化主伪掩码(target)及其两种增强版本(target_a, target_b),提升模型对不同噪声伪标签的鲁棒性。

    • 权重分配 :主损失占60%,增强损失各占20%(0.6*loss_o + 0.2*loss_a + 0.2*loss_b)。

  2. 门控机制(Gate Mechanism)

    • 作用:在测试阶段,利用 Stage1 的分类结果过滤分割预测,仅保留分类模型认为存在的类别。

    • 实现:若 Stage1 对某类别的预测概率 > 0.1,则保留该类的分割结果,否则置零。

  3. 类别4处理

    • 背景或忽略类 :在标签中,类别4可能表示背景或未标注区域,预测时直接继承真实标签的值(pred[target==4] = 4),避免错误优化。
  4. 模型初始化

    • 预训练权重 :加载 DeepLab 在 ImageNet 上的预训练权重(init_weights/deeplab-resnet.pth.tar),加速收敛。

    • 分层学习率 :骨干网络使用较低学习率(args.lr),分类头使用更高学习率(args.lr * 10)。

运行示例

复制代码
python train_stage2.py \
    --dataset bcss \
    --dataroot datasets/BCSS-WSSS/ \
    --backbone resnet \
    --Is_GM True \
    --batch-size 20 \
    --epochs 30

总结

该代码实现了弱监督语义分割的第二阶段训练,通过多任务损失融合多级伪标签,结合门控机制提升测试精度,最终生成高精度分割模型。训练过程支持多GPU加速、Poly学习率调度及多种评估指标监控,适用于医学图像(如BCSS)或自然场景图像的分割任务。

相关推荐
萧鼎13 分钟前
深度探索 Py2neo:用 Python 玩转图数据库 Neo4j
数据库·python·neo4j
华子w90892585929 分钟前
基于 Python Django 和 Spark 的电力能耗数据分析系统设计与实现7000字论文实现
python·spark·django
kikikidult39 分钟前
(2025.07)解决——ubuntu20.04系统开机黑屏,左上角光标闪烁
笔记·ubuntu
风铃喵游41 分钟前
让大模型调用MCP服务变得超级简单
前端·人工智能
旷世奇才李先生1 小时前
Pillow 安装使用教程
深度学习·microsoft·pillow
Rockson1 小时前
使用Ruby接入实时行情API教程
javascript·python
booooooty1 小时前
基于Spring AI Alibaba的多智能体RAG应用
java·人工智能·spring·多智能体·rag·spring ai·ai alibaba
近津薪荼1 小时前
初学者关于数据在内存中的储存的笔记
笔记
PyAIExplorer1 小时前
基于 OpenCV 的图像 ROI 切割实现
人工智能·opencv·计算机视觉
风口猪炒股指标1 小时前
技术分析、超短线打板模式与情绪周期理论,在市场共识的形成、分歧、瓦解过程中缘起性空的理解
人工智能·博弈论·群体博弈·人生哲学·自我引导觉醒