深度学习代码分析——自用

代码来自:https://github.com/ChuHan89/WSSS-Tissue?tab=readme-ov-file

借助了一些人工智能

1_train_stage1.py

代码功能总览

该代码是弱监督语义分割(WSSS)流程的 Stage1 训练与测试脚本 ,核心任务是通过 多标签分类模型 生成图像级标签,为后续生成伪掩码(Pseudo-Masks)提供基础。代码分为 train_phasetest_phase 两个阶段,支持 渐进式Dropout注意力(PDA)Visdom可视化监控

1. 依赖库导入

复制代码
import os
import numpy as np
import argparse
import importlib
from visdom import Visdom  # 可视化工具

import torch
import torch.nn.functional as F
from torch.backends import cudnn  # CUDA加速
from torch.utils.data import DataLoader
from torchvision import transforms  # 数据预处理
from tool import pyutils, torchutils  # 自定义工具包
from tool.GenDataset import Stage1_TrainDataset  # 自定义数据集类
from tool.infer_fun import infer  # 测试阶段推理函数

cudnn.enabled = True  # 启用CUDA加速(自动优化卷积算法)
  • 关键细节

    • cudnn.enabled=True:启用cuDNN加速,自动选择最优卷积实现。

    • pyutilstorchutils:项目自定义工具模块(包含优化器、计时器等)。

    • Visdom:用于实时可视化训练过程中的损失和准确率曲线。

2. 辅助函数 compute_acc

复制代码
def compute_acc(pred_labels, gt_labels):
    pred_correct_count = 0
    for pred_label in pred_labels:  # 遍历预测标签
        if pred_label in gt_labels:  # 判断是否在真实标签中
            pred_correct_count += 1
    union = len(gt_labels) + len(pred_labels) - pred_correct_count  # 并集大小
    acc = round(pred_correct_count/union, 4)  # 交并比(IoU)式准确率
    return acc
  • 功能 :计算预测标签与真实标签的 交并比准确率(IoU-like Accuracy)。

  • 数学公式

    Acc=预测正确的标签数预测标签数+真实标签数−预测正确的标签数Acc=预测标签数+真实标签数−预测正确的标签数预测正确的标签数

  • 示例

    • 预测标签:[0, 2],真实标签:[2, 3]

    • 正确数:1(标签2),并集:2 + 2 - 1 = 3 → Acc = 1/3 ≈ 0.333

3. 训练阶段 train_phase

3.1 初始化与模型加载
复制代码
def train_phase(args):
    viz = Visdom(env=args.env_name)  # 创建Visdom环境(用于可视化)
    model = getattr(importlib.import_module(args.network), 'Net')(args.init_gama, n_class=args.n_class)
    print(vars(args))  # 打印所有输入参数
  • 关键细节

    • 动态模型加载 :通过 importlib 从字符串 args.network(如 "network.resnet38_cls")动态加载模型类 Net

    • PDA参数args.init_gama 控制渐进式Dropout注意力的初始强度(值越大,注意力区域越集中)。

    • Visdom环境 :通过 env=args.env_name 隔离不同实验的可视化结果。

3.2 数据增强与加载
复制代码
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转
        transforms.RandomVerticalFlip(p=0.5),    # 50%概率垂直翻转
        transforms.ToTensor()                    # 转为Tensor(范围[0,1])
    ]) 
    train_dataset = Stage1_TrainDataset(
        data_path=args.trainroot,  # 训练集路径(如'datasets/BCSS-WSSS/train/')
        transform=transform_train, 
        dataset=args.dataset       # 数据集标识(如'bcss')
    )
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,  # 批大小(默认20)
        shuffle=True,                # 打乱数据顺序
        num_workers=args.num_workers,  # 数据加载子进程数(默认10)
        pin_memory=False,             # 不锁页内存(适用于小批量数据)
        drop_last=True                # 丢弃最后不足一个batch的数据
    )
  • 关键细节

    • 数据增强策略:仅使用翻转操作,避免复杂变换干扰分类模型的学习。

    • 自定义数据集类Stage1_TrainDataset 需实现图像和标签的加载逻辑(如解析XML或CSV文件)。

3.3 优化器配置
复制代码
    max_step = (len(train_dataset) // args.batch_size) * args.max_epoches  # 总迭代次数
    param_groups = model.get_parameter_groups()  # 获取模型参数分组(通常按网络层分组)
    optimizer = torchutils.PolyOptimizer(
        [
            {'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec},  # 主干网络(低学习率)
            {'params': param_groups[1], 'lr': 2*args.lr, 'weight_decay': 0},         # 中间层(较高学习率)
            {'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec},  # 分类头(高学习率)
            {'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0}          # 特殊模块(最高学习率)
        ], 
        lr=args.lr, 
        weight_decay=args.wt_dec, 
        max_step=max_step  # 控制学习率衰减
    )
  • 关键细节

    • 参数分组:不同网络层(如ResNet38的卷积层、全连接层)使用不同的学习率,分类头通常需要更高学习率以快速适应新任务。

    • Poly学习率衰减 :学习率按公式 lr=base_lr×(1−stepmax_step)powerlr=base_lr×(1−max_stepstep​)power 衰减,默认 power=0.9

3.4 加载预训练权重
复制代码
    if args.weights[-7:] == '.params':  # MXNet格式权重(如'init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params')
        import network.resnet38d
        weights_dict = network.resnet38d.convert_mxnet_to_torch(args.weights)  # 转换权重格式
        model.load_state_dict(weights_dict, strict=False)  # 非严格加载(允许部分参数不匹配)
    elif args.weights[-4:] == '.pth':   # PyTorch格式权重
        weights_dict = torch.load(args.weights)
        model.load_state_dict(weights_dict, strict=False)
    else:
        print('random init')  # 随机初始化(无预训练)
  • 关键细节

    • MXNet转换:项目可能基于早期MXNet实现,需将预训练权重转换为PyTorch格式。

    • strict=False:允许模型结构与权重文件部分不匹配(如分类头维度不同)。

3.5 训练循环
复制代码
    model = model.cuda()  # 将模型移至GPU
    avg_meter = pyutils.AverageMeter('loss', 'avg_ep_EM', 'avg_ep_acc')  # 统计训练指标
    timer = pyutils.Timer("Session started: ")  # 计时器(计算剩余时间)

    for ep in range(args.max_epoches):  # 遍历每个epoch
        model.train()
        args.ep_index = ep  # 当前epoch索引(可能用于回调)
        ep_count = 0        # 当前epoch累计样本数
        ep_EM = 0           # 完全匹配(Exact Match)次数
        ep_acc = 0           # 累计准确率

        for iter, (filename, data, label) in enumerate(train_data_loader):  # 遍历每个batch
            img = data  # 图像数据(未使用filename)
            label = label.cuda(non_blocking=True)  # 标签移至GPU(异步传输)

            # 控制PDA的启用(前3个epoch禁用)
            enable_PDA = 1 if ep > 2 else 0

            # 前向传播(返回分类输出、特征图、概率)
            x, feature, y = model(img.cuda(), enable_PDA)

            # 转换为CPU numpy数组以计算指标
            prob = y.cpu().data.numpy()  # 预测概率(shape=[batch_size, n_class])
            gt = label.cpu().data.numpy()  # 真实标签(shape=[batch_size, n_class])

            # 遍历batch内每个样本计算指标
            for num, one in enumerate(prob):
                ep_count += 1
                pass_cls = np.where(one > 0.5)[0]  # 预测标签(概率>0.5的类别)
                true_cls = np.where(gt[num] == 1)[0]  # 真实标签(one-hot编码中为1的类别)

                # 统计Exact Match(完全匹配)
                if np.array_equal(pass_cls, true_cls):
                    ep_EM += 1

                # 计算交并比式准确率
                acc = compute_acc(pass_cls, true_cls)
                ep_acc += acc

            # 计算当前batch的平均指标
            avg_ep_EM = round(ep_EM / ep_count, 4)
            avg_ep_acc = round(ep_acc / ep_count, 4)

            # 计算多标签分类损失
            loss = F.multilabel_soft_margin_loss(x, label)  # x为模型原始输出(未经过sigmoid)

            # 更新统计指标
            avg_meter.add({
                'loss': loss.item(),
                'avg_ep_EM': avg_ep_EM,
                'avg_ep_acc': avg_ep_acc
            })

            # 反向传播与优化
            optimizer.zero_grad()  # 清空梯度
            loss.backward()        # 计算梯度
            optimizer.step()       # 更新参数
            torch.cuda.empty_cache()  # 清理GPU缓存(防止内存泄漏)

            # 每100步打印日志并更新Visdom
            if (optimizer.global_step) % 100 == 0 and (optimizer.global_step) != 0:
                timer.update_progress(optimizer.global_step / max_step)  # 更新剩余时间估计
                print(
                    'Epoch:%2d' % (ep),
                    'Iter:%5d/%5d' % (optimizer.global_step, max_step),
                    'Loss:%.4f' % (avg_meter.get('loss')),
                    'avg_ep_EM:%.4f' % (avg_meter.get('avg_ep_EM')),
                    'avg_ep_acc:%.4f' % (avg_meter.get('avg_ep_acc')),
                    'lr: %.4f' % (optimizer.param_groups[0]['lr']), 
                    'Fin:%s' % (timer.str_est_finish()),
                    flush=True
                )
                # 更新Visdom图表
                viz.line(
                    [avg_meter.pop('loss')],
                    [optimizer.global_step],
                    win='loss',
                    update='append',
                    opts=dict(title='loss')
                )
                # 同理更新 'Acc_exact' 和 'Acc' 图表...

        # 每epoch后调整PDA的gama参数
        if model.gama > 0.65:
            model.gama = model.gama * 0.98  # 逐步衰减注意力强度
        print('Gama of progressive dropout attention is: ', model.gama)

    # 保存最终模型
    torch.save(
        model.state_dict(), 
        os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth')
    )
  • 关键细节

    • 渐进式Dropout注意力(PDA)

      • 前3个epoch禁用(enable_PDA=0),让模型初步学习基础特征。

      • gama 初始值为1,逐渐衰减(gama *= 0.98),控制注意力区域的聚焦程度。

    • 损失函数F.multilabel_soft_margin_loss 结合Sigmoid和交叉熵,适用于多标签分类。

    • 指标计算

      • Exact Match (EM):预测标签与真实标签完全一致的样本比例(严格指标)。

      • IoU式准确率:反映预测与真实标签的重合程度(宽松指标)。

    • Visdom集成:实时可视化损失和准确率曲线,便于监控训练状态。

4. 测试阶段 test_phase

复制代码
def test_phase(args):
    # 加载生成CAM的模型变体(Net_CAM)
    model = getattr(importlib.import_module(args.network), 'Net_CAM')(n_class=args.n_class)
    model = model.cuda()

    # 加载训练阶段保存的权重
    args.weights = os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth')
    weights_dict = torch.load(args.weights)
    model.load_state_dict(weights_dict, strict=False)

    model.eval()  # 设置为评估模式(禁用Dropout和BatchNorm的随机性)

    # 调用自定义推理函数(评估模型在测试集上的性能)
    score = infer(model, args.testroot, args.n_class)
    print(score)  # 输出评估结果(如mAP、IoU等)

    # 可选:保存最终模型(可能包含CAM生成能力)
    torch.save(model.state_dict(), ...)
  • 关键细节

    • 模型变体Net_CAM 可能修改了网络结构以输出类别激活图(Class Activation Map)。

    • 评估指标infer 函数内部可能计算mAP(平均精度)、像素级IoU等指标。

    • 严格模式strict=False 允许加载部分权重(如分类头维度不同)。

5. 主函数与参数解析

复制代码
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # 训练参数
    parser.add_argument("--batch_size", default=20, type=int)
    parser.add_argument("--max_epoches", default=20, type=int)
    parser.add_argument("--network", default="network.resnet38_cls", type=str)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument("--wt_dec", default=5e-4, type=float)  # 权重衰减(L2正则化)
    
    # 实验命名与可视化
    parser.add_argument("--session_name", default="Stage 1", type=str)  # 实验名称(日志标识)
    parser.add_argument("--env_name", default="PDA", type=str)          # Visdom环境名
    parser.add_argument("--model_name", default='PDA', type=str)        # 模型保存名称
    
    # 数据集与模型结构
    parser.add_argument("--n_class", default=4, type=int)               # 类别数(如BCSS为4类)
    parser.add_argument("--weights", default='init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params', type=str)
    parser.add_argument("--trainroot", default='datasets/BCSS-WSSS/train/', type=str)
    parser.add_argument("--testroot", default='datasets/BCSS-WSSS/test/', type=str)
    parser.add_argument("--save_folder", default='checkpoints/', type=str)
    
    # PDA参数
    parser.add_argument("--init_gama", default=1, type=float)  # 初始注意力强度
    
    # 数据集标识
    parser.add_argument("--dataset", default='bcss', type=str)  # 数据集缩写(影响保存文件名)

    args = parser.parse_args()

    train_phase(args)  # 执行训练
    test_phase(args)   # 执行测试
  • 关键参数说明

    • --network:模型定义文件路径(如 network.resnet38_cls 对应 network/resnet38_cls.py)。

    • --init_gama:PDA的初始强度,影响注意力机制的随机丢弃率。

    • --weights:预训练权重路径(支持MXNet和PyTorch格式)。

相关推荐
乱世刀疤12 分钟前
AI绘画软件Stable Diffusion详解教程(6):文生图、提示词细说与绘图案例
人工智能·ai作画·stable diffusion
niu_sama19 分钟前
[杂学笔记] 封装、继承、多态,堆和栈的区别,堆和栈的区别 ,托管与非托管 ,c++的垃圾回收机制 , 实现一个单例模式 注意事项
c++·笔记·单例模式
AAA小肥杨20 分钟前
深度解析 | 2025 AI新突破,物理信息神经网络(PINN):Nature级顶刊的「科研加速器」,70份源码论文速取!
人工智能·深度学习·神经网络·pinn
Data-Miner24 分钟前
112页精品PPT | DeepSeek行业应用实践报告
人工智能·ai·数字化
业余小程序猿31 分钟前
图像处理中注意力机制的解析与代码详解
笔记
Loving_enjoy44 分钟前
DeepSeek、Grok与ChatGPT:AI三巨头的技术博弈与场景革命
人工智能
windwant1 小时前
神经网络:AI的网络神经
网络·人工智能·神经网络
平凡而伟大(心之所向)1 小时前
一文讲清楚自我学习和深度学习
人工智能·深度学习·机器学习
挣扎与觉醒中的技术人1 小时前
如何本地部署大模型及性能优化指南(附避坑要点)
人工智能·opencv·算法·yolo·性能优化·audiolm