ISIC2018数据集训练框架讲解

ISIC2018

ISIC 2018 (International Skin Imaging Collaboration 2018) 数据集主要包含大量经过专业皮肤科医生标注的皮肤镜(dermoscopy)图像,涵盖了多种皮肤病变类型。在分割任务(Task 1)中,提供了约 2594 张图像及其对应的像素级标签,用于训练和验证算法对皮肤病变区域进行精确分割的性能。

数据集通常以原始皮肤镜图像(JPEG 格式)和对应的二值分割掩码(PNG 格式,标记病变区域)形式提供。但在实际研究中,为了方便处理或满足特定框架的需求,常常需要转换成 .npy 或 .nii.gz 格式。

npy: NumPy 的原生二进制数组格式,适合高效存储和直接加载到 Python 深度学习框架中。

nii.gz: 常用于医学影像的 NIfTI 格式压缩文件,能保存图像数据及其空间信息,适用于兼容医学影像处理工具链的场景。

这是其中一个数据的图像以及标注:

代码讲解

train.py

python 复制代码
import os
import torch
import math
import visdom
import torch.utils.data as Data
import argparse
import numpy as np
import sys
from tqdm import tqdm
import torch.nn as nn

from distutils.version import LooseVersion
from Datasets.ISIC2018 import ISIC2018_dataset
from utils.transform import ISIC2018_transform, ISIC2018_transform_320, ISIC2018_transform_newdata

from Models.compare_networks.unet import unet
from Models.compare_networks.AttUnet import AttUNet


from utils.dice_loss import get_soft_label, val_dice_isic, SoftDiceLoss
from utils.dice_loss import Intersection_over_Union_isic
from utils.dice_loss_github import SoftDiceLoss_git, CrossentropyND

from utils.evaluation import AverageMeter
from utils.binary import assd, dc, jc, precision, sensitivity, specificity, F1, ACC
from torch.optim import lr_scheduler

from time import *
from matplotlib import pyplot as plt

# Registry mapping the --id argument to a model constructor.
Test_Model = {
              "unet":unet,
              "AttUNet":AttUNet
             }


# Registry mapping the --data argument to a Dataset class.
Test_Dataset = {'ISIC2018': ISIC2018_dataset}

# Registry mapping the --transform argument to a preprocessing function.
Test_Transform = {'A': ISIC2018_transform, 'B':ISIC2018_transform_320, "C":ISIC2018_transform_newdata}

# Global loss selector consumed by train() and baked into checkpoint/log paths.
criterion = "loss_D"  # loss_A-->SoftDiceLoss;  loss_B-->softdice;  loss_C--> CE + softdice;   loss_D--> BCEWithLogitsLoss


class Logger(object):
    """Tee-style logger: mirrors everything written to stdout into a file.

    Install with ``sys.stdout = Logger(path)`` so that all ``print`` output
    appears on the terminal and is appended to ``path``.
    """

    def __init__(self, logfile):
        self.terminal = sys.stdout       # handle on the real stdout
        self.log = open(logfile, "a")    # append so resumed runs keep history

    def write(self, message):
        """Write *message* to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Forward flushes to both sinks. The original no-op could lose
        # buffered log text on a crash and confused tqdm's flushing.
        self.terminal.flush()
        self.log.flush()
        
        
def train(train_loader, model, criterion, scheduler, optimizer, args, epoch):
    """Run one training epoch and return the average loss.

    Args:
        train_loader: DataLoader yielding (image, mask) float batches.
        model: network being trained (already on GPU).
        criterion: string key selecting the loss ('loss_A'..'loss_D').
        scheduler: LR scheduler (not stepped here; the caller steps it).
        optimizer: optimizer whose parameters are updated.
        args: namespace providing num_classes and batch_size.
        epoch: current epoch number (logging only).

    Returns:
        Average loss over the epoch (``AverageMeter.avg``).

    Raises:
        ValueError: if ``criterion`` is not one of the known keys (the
        original would hit a NameError on ``loss`` instead).
    """
    losses = AverageMeter()

    # Loss modules are stateless: build them once per epoch instead of
    # re-instantiating all of them on every batch as the original did.
    ca_soft_dice_loss = SoftDiceLoss()
    soft_dice_loss = SoftDiceLoss_git(batch_dice=True, dc_log=False)
    soft_dice_loss2 = SoftDiceLoss_git(batch_dice=False, dc_log=False)
    CE_loss_F = CrossentropyND()
    bce_loss = nn.BCEWithLogitsLoss()

    model.train()
    for step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        image = x.float().cuda()
        target = y.float().cuda()

        output = model(image)
        # One-hot encode the mask: (B, H, W, C), then channels-first copy.
        target_soft_a = get_soft_label(target, args.num_classes)
        target_soft = target_soft_a.permute(0, 3, 1, 2)

        if criterion == "loss_A":
            loss_ave, loss_lesion = ca_soft_dice_loss(output, target_soft_a, args.num_classes)
            loss = loss_ave
        elif criterion == "loss_B":
            loss = soft_dice_loss(output, target_soft)
        elif criterion == "loss_C":
            loss = soft_dice_loss2(output, target_soft) + CE_loss_F(output, target)
        elif criterion == "loss_D":
            loss = bce_loss(output, target_soft)
        else:
            raise ValueError("Unknown criterion: {}".format(criterion))

        losses.update(loss.data, image.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ceil(len(dataset)/batch_size) == steps per epoch, so this condition
        # only fires at step 0 — i.e. one progress line per epoch.
        if step % (math.ceil(float(len(train_loader.dataset)) / args.batch_size)) == 0:
            print('current lr: {} | Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                optimizer.state_dict()['param_groups'][0]['lr'],
                epoch, step * len(image), len(train_loader.dataset),
                100. * step / len(train_loader), losses=losses))

    print('The average loss:{losses.avg:.4f}'.format(losses=losses))
    return losses.avg


def valid_isic(valid_loader, model, criterion, optimizer, args, epoch, best_score, val_acc_log):
    """Evaluate on the validation set; checkpoint when the score improves.

    Computes per-batch Jaccard and Dice, appends the epoch means to
    ``val_acc_log``, and saves a 'best_score' checkpoint whenever
    Jaccard + Dice beats every value already in ``best_score`` (which the
    caller shares across epochs and this function mutates).

    Returns:
        (mean Jaccard, mean Dice, their sum as the model-selection score).
    """
    isic_Jaccard = []
    isic_dc = []

    model.eval()
    # no_grad: validation needs no autograd graph; the original built
    # gradient bookkeeping for every forward pass, wasting GPU memory.
    with torch.no_grad():
        for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            image = t.float().cuda()
            target = k.float().cuda()

            output = model(image)                                          # model output
            # Hard class prediction as (B, 1, H, W), then channels-last.
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_dis_test = output_dis.permute(0, 2, 3, 1).float()
            target_test = target.permute(0, 2, 3, 1).float()
            isic_Jaccard.append(jc(output_dis_test.cpu().numpy(), target_test.cpu().numpy()))
            isic_dc.append(dc(output_dis_test.cpu().numpy(), target_test.cpu().numpy()))

    isic_Jaccard_mean = np.average(isic_Jaccard)
    isic_dc_mean = np.average(isic_dc)
    net_score = isic_Jaccard_mean + isic_dc_mean

    print('The ISIC Dice score: {dice: .4f}; '
          'The ISIC JC score: {jc: .4f}'.format(
           dice=isic_dc_mean, jc=isic_Jaccard_mean))

    with open(val_acc_log, 'a') as vlog_file:
        line = "{} | {dice: .4f} | {jc: .4f}".format(epoch, dice=isic_dc_mean, jc=isic_Jaccard_mean)
        vlog_file.write(line + '\n')

    if net_score > max(best_score):
        best_score.append(net_score)
        print(best_score)
        modelname = args.ckpt + '/' + 'best_score' + '_' + args.data + '_checkpoint.pth.tar'
        print('the best model will be saved at {}'.format(modelname))
        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
        torch.save(state, modelname)

    return isic_Jaccard_mean, isic_dc_mean, net_score


def test_isic(test_loader, model, num_para, args, test_acc_log):
    """Evaluate the best checkpoint on the test set and print summary metrics.

    Loads the 'best_score' checkpoint saved by valid_isic (if present), runs
    inference over ``test_loader``, saves qualitative triptychs via
    ``save_imgs``, and prints mean/std for Dice, IoU, ACC, sensitivity,
    specificity, precision, F1, per-class Jaccard, overall Jaccard/DC, and
    the total inference time.
    """
    isic_dice = []
    isic_iou = []
    isic_acc = []
    isic_sensitive = []
    isic_specificy = []
    isic_precision = []
    isic_f1_score = []
    isic_Jaccard_M = []
    isic_Jaccard_N = []
    isic_Jaccard = []
    isic_dc = []
    infer_time = []

    # Restore the best model selected during validation.
    modelname = args.ckpt + '/' + 'best_score' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        checkpoint = torch.load(modelname)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> No checkpoint found at '{}'".format(modelname))

    model.eval()
    # Pure inference: skip autograd bookkeeping entirely.
    with torch.no_grad():
        for step, (name, img, lab) in tqdm(enumerate(test_loader), total=len(test_loader)):
            image = img.float().cuda()
            target = lab.float().cuda()  # [batch, 1, 224, 320]

            begin_time = time()
            output = model(image)
            end_time = time()
            infer_time.append(end_time - begin_time)

            # Hard prediction: argmax over classes, kept as (B, 1, H, W).
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_dis_test = output_dis.permute(0, 2, 3, 1).float()
            target_test = target.permute(0, 2, 3, 1).float()
            output_soft = get_soft_label(output_dis, 2)
            target_soft = get_soft_label(target, 2)

            msk = target_test.squeeze(0).cpu().detach().numpy()
            out = output_dis_test.squeeze(0).cpu().detach().numpy()
            # NOTE(review): hard-coded output directory; consider promoting
            # this to a command-line argument.
            save_imgs(img, msk, out, step, '/xujiheng/ISICdemo/UNet-ISIC2018/outputflod1/fold1/')

            label_arr = np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)
            output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)

            isic_b_dice = val_dice_isic(output_soft, target_soft, 2)                                  # the dice
            isic_b_iou = Intersection_over_Union_isic(output_dis_test, target_test, 1)                # the iou
            isic_b_acc = ACC(output_dis_test.cpu().numpy(), target_test.cpu().numpy())                # the accuracy
            isic_b_sensitive = sensitivity(output_dis_test.cpu().numpy(), target_test.cpu().numpy())  # the sensitivity
            isic_b_specificy = specificity(output_dis_test.cpu().numpy(), target_test.cpu().numpy())  # the specificity
            isic_b_precision = precision(output_dis_test.cpu().numpy(), target_test.cpu().numpy())    # the precision
            isic_b_f1_score = F1(output_dis_test.cpu().numpy(), target_test.cpu().numpy())            # the F1
            isic_b_Jaccard_m = jc(output_arr[:, :, 1], label_arr[:, :, 1])                            # Jaccard, lesion channel
            isic_b_Jaccard_n = jc(output_arr[:, :, 0], label_arr[:, :, 0])                            # Jaccard, background channel
            isic_b_Jaccard = jc(output_dis_test.cpu().numpy(), target_test.cpu().numpy())
            isic_b_dc = dc(output_dis_test.cpu().numpy(), target_test.cpu().numpy())

            isic_dice.append(isic_b_dice.data.cpu().numpy())
            isic_iou.append(isic_b_iou.data.cpu().numpy())
            isic_acc.append(isic_b_acc)
            isic_sensitive.append(isic_b_sensitive)
            isic_specificy.append(isic_b_specificy)
            isic_precision.append(isic_b_precision)
            isic_f1_score.append(isic_b_f1_score)
            isic_Jaccard_M.append(isic_b_Jaccard_m)
            isic_Jaccard_N.append(isic_b_Jaccard_n)
            isic_Jaccard.append(isic_b_Jaccard)
            isic_dc.append(isic_b_dc)

    all_time = np.sum(infer_time)

    # Print "mean; std" per metric with the exact wording of the original
    # hand-written prints (this also fixes the original's misspelled
    # ``iisic_f1_score_std`` variable).
    metrics = [
        ('dice', isic_dice), ('IoU', isic_iou), ('ACC', isic_acc),
        ('sensitive', isic_sensitive), ('specificy', isic_specificy),
        ('precision', isic_precision), ('f1_score', isic_f1_score),
        ('Jaccard_M', isic_Jaccard_M), ('Jaccard_N', isic_Jaccard_N),
        ('Jaccard', isic_Jaccard), ('dc', isic_dc),
    ]
    for metric_name, values in metrics:
        print('The ISIC mean {0}: {1: .4f}; The ISIC {0} std: {2: .4f}'.format(
            metric_name, np.average(values), np.std(values)))
    print('The inference time: {time: .4f}'.format(time=all_time))
    print("Number of trainable parameters {0} in Model {1}".format(num_para, args.id))
    
    
def save_imgs(img, msk, msk_pred, i, save_path,  threshold=0.5, test_data_name=None):
    """Save a vertical triptych PNG: input image, ground truth, prediction.

    ``img`` is a (1, C, H, W) tensor; ``msk``/``msk_pred`` are (H, W, 1)
    arrays. The file is written to ``save_path`` as ``<i>.png`` (prefixed
    with ``test_data_name`` when given).
    """
    arr = img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    if arr.max() > 1.1:
        # Heuristic: values above 1 mean a 0-255 image; rescale for imshow.
        arr = arr / 255.
    img = arr

    plt.figure(figsize=(7, 15))

    panels = [(img, None), (msk.squeeze(2), 'gray'), (msk_pred.squeeze(2), 'gray')]
    for panel_idx, (panel, colormap) in enumerate(panels, start=1):
        plt.subplot(3, 1, panel_idx)
        plt.imshow(panel, cmap=colormap)
        plt.axis('off')

    if test_data_name is not None:
        save_path = save_path + test_data_name + '_'
    plt.savefig(save_path + str(i) + '.png')
    plt.close()
    

def main(args, val_acc_log, test_acc_log):
    """Experiment driver: build data and model, train with validation-based
    checkpointing, then evaluate on the test split.

    Args:
        args: parsed command-line namespace (paths, model id, hyperparameters).
        val_acc_log: text file receiving per-epoch validation metrics.
        test_acc_log: text file path passed through to test_isic.
    """
    best_score = [0]   # shared, mutated by valid_isic with each improvement
    start_epoch = args.start_epoch
    print('loading the {0},{1},{2} dataset ...'.format('train', 'validation', 'test'))
    trainset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='train',
                                       with_name=False, transform=Test_Transform[args.transform])
    validset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='validation',
                                       with_name=False, transform=Test_Transform[args.transform])
    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
                                      with_name=True, transform=Test_Transform[args.transform])

    trainloader = Data.DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=6)
    validloader = Data.DataLoader(dataset=validset, batch_size=1, shuffle=False, pin_memory=True, num_workers=6)
    testloader = Data.DataLoader(dataset=testset, batch_size=1, shuffle=False, pin_memory=True, num_workers=6)
    print('Loading is done\n')

    args.num_input = 3
    args.num_classes = 2  # binary segmentation: lesion vs background
    args.out_size = (224, 320)

    # AttUNet uses a different constructor signature than the other models.
    if args.id == "AttUNet":
        model = AttUNet(in_channel=3, out_channel=2)
    else:
        model = Test_Model[args.id](classes=2, channels=3)

    model = model.cuda()

    print("------------------------------------------")
    print("Network Architecture of Model {}:".format(args.id))
    # Total trainable parameter count (numel == product of the tensor shape,
    # replacing the original hand-rolled nested loop).
    num_para = sum(param.numel() for param in model.parameters())
    print(model)
    print("Number of trainable parameters {0} in Model {1}".format(num_para, args.id))
    print("------------------------------------------")

    # Define optimizer and LR schedule (cosine annealing with warm restarts).
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_rate, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.00001)    # lr_3

    # resume from an existing checkpoint if requested
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['opt_dict'])
            print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        else:
            print("=> No checkpoint found at '{}'".format(args.resume))

    print("Start training ...")
    for epoch in range(start_epoch + 1, args.epochs + 1):
        train_avg_loss = train(trainloader, model, criterion, scheduler, optimizer, args, epoch)
        # Step the scheduler AFTER the epoch's optimizer updates (PyTorch >= 1.1
        # ordering); the original stepped before training, which skipped the
        # schedule's initial learning rate.
        scheduler.step()
        isic_Jaccard_mean, isic_dc_mean, net_score = valid_isic(validloader, model, criterion, optimizer, args, epoch, best_score, val_acc_log)
        if epoch > args.particular_epoch:
            if epoch % args.save_epochs_steps == 0:
                filename = args.ckpt + '/' + str(epoch) + '_' + args.data + '_checkpoint.pth.tar'
                print('the model will be saved at {}'.format(filename))
                state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
                torch.save(state, filename)

    print('Training Done! Start testing')
    if args.data == 'ISIC2018':
        test_isic(testloader, model, num_para, args, test_acc_log)
    print('Testing Done!')
    
    
if __name__ == '__main__':

    
    os.environ['CUDA_VISIBLE_DEVICES']= '0'                                                 # gpu-id: pin the run to GPU 0
    
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), 'PyTorch>=0.4.0 is required'
    parser = argparse.ArgumentParser(description='Comprehensive attention network for biomedical Dataset')
    
    parser.add_argument('--id', default="unet",
                        help='')                                                   # Select a loaded model name (key of Test_Model)

    # Path related arguments
    parser.add_argument('--root_path', default='/xujiheng/ISICdemo/ISIC2018_npy_all_224_320',
                        help='root directory of data')                                      # The folder where the numpy data set is stored
    parser.add_argument('--ckpt', default='/xujiheng/ISICdemo/UNet-ISIC2018/saved_models/',
                        help='folder to output checkpoints')                                # The folder in which the trained model is saved
    parser.add_argument('--transform', default='C', type=str,
                        help='which ISIC2018_transform to choose')                          # key of Test_Transform ('A'/'B'/'C')
    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')            # key of Test_Dataset
    parser.add_argument('--out_size', default=(224, 320), help='the output image size')
    parser.add_argument('--val_folder', default='folder3', type=str,
                        help='folder1、folder2、folder3、folder4、folder5')                 # five-fold cross-validation split to use
    
    # optimization related arguments
    parser.add_argument('--epochs', type=int, default=15, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='epoch to start training. useful if continue from a checkpoint')
    parser.add_argument('--batch_size', type=int, default=10, metavar='N',              
                        help='input batch size for training (default: 12)')                 # batch_size
    parser.add_argument('--lr_rate', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 0.001)')                              
    parser.add_argument('--num_classes', default=2, type=int,
                        help='number of classes')
    parser.add_argument('--num_input', default=3, type=int,
                        help='number of input image for each patient')
    parser.add_argument('--weight_decay', default=1e-8, type=float, help='weights regularizer')
    parser.add_argument('--particular_epoch', default=30, type=int,
                        help='after this number, we will save models more frequently')
    parser.add_argument('--save_epochs_steps', default=300, type=int,
                        help='frequency to save models after a particular number of epochs')
    parser.add_argument('--resume', default='',
                        help='the checkpoint that resumes from')

    args = parser.parse_args()
    
    # Checkpoints are grouped by dataset / CV fold / model+loss combination.
    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id + "_{}".format(criterion))    
    if not os.path.isdir(args.ckpt):
        os.makedirs(args.ckpt)
    logfile = os.path.join(args.ckpt,'{}_{}_{}.txt'.format(args.val_folder, args.id, criterion))          # path of the training log
    # Mirror all stdout (prints and tqdm output) into the training log file.
    sys.stdout = Logger(logfile)  
    
    val_acc_log = os.path.join(args.ckpt,'val_acc_{}_{}_{}.txt'.format(args.val_folder, args.id, criterion))   
    test_acc_log = os.path.join(args.ckpt,'test_acc_{}_{}_{}.txt'.format(args.val_folder, args.id, criterion))    
    
    print('Models are saved at %s' % (args.ckpt))
    print("Input arguments:")
    for key, value in vars(args).items():
        print("{:16} {}".format(key, value))

    # Resuming past epoch 1: point --resume at the matching saved checkpoint.
    if args.start_epoch > 1:
        args.resume = args.ckpt + '/' + str(args.start_epoch) + '_' + args.data + '_checkpoint.pth.tar'

    main(args, val_acc_log, test_acc_log)

下面我将分模块进行讲解。

官方 / 第三方库

python 复制代码
import os                          # Python 标准库
import torch                       # PyTorch 官方
import math                        # Python 标准库
import visdom                      # 第三方可视化工具(Facebook 开源)
import torch.utils.data as Data    # PyTorch 官方
import argparse                    # Python 标准库
import numpy as np                 # 第三方科学计算库
import sys                         # Python 标准库
from tqdm import tqdm              # 第三方进度条库
import torch.nn as nn              # PyTorch 官方

from distutils.version import LooseVersion  # Python 标准库(用于版本比较)

from torch.optim import lr_scheduler       # PyTorch 官方

from time import *                # Python 标准库
from matplotlib import pyplot as plt       # 第三方绘图库

项目中的 .py 文件

位于 Datasets/、Models/、utils/下。

python 复制代码
# 数据集类(你自己定义的 ISIC2018 数据加载器)
from Datasets.ISIC2018 import ISIC2018_dataset

# 数据增强/预处理变换(你自己写的 transform 函数)
from utils.transform import ISIC2018_transform, ISIC2018_transform_320, ISIC2018_transform_newdata

# 模型结构(你自己实现或集成的 U-Net 及 Attention U-Net)
from Models.compare_networks.unet import unet
from Models.compare_networks.AttUnet import AttUNet

# 自定义损失函数和评估指标
from utils.dice_loss import get_soft_label, val_dice_isic, SoftDiceLoss
from utils.dice_loss import Intersection_over_Union_isic
from utils.dice_loss_github import SoftDiceLoss_git, CrossentropyND

# 工具类:平均值记录器
from utils.evaluation import AverageMeter

# 二值分割评价指标(如 Dice、Jaccard、ASSD 等)
from utils.binary import assd, dc, jc, precision, sensitivity, specificity, F1, ACC

下面就这些来逐一讲解代码。

UNet-ISIC2018/Datasets/ISIC2018.py

python 复制代码
import os
import PIL
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from os import listdir
from os.path import join
from PIL import Image
from utils.transform import itensity_normalize
from torch.utils.data.dataset import Dataset


class ISIC2018_dataset(Dataset):
    """ISIC2018 segmentation dataset backed by pre-converted .npy files.

    Reads a cross-validation split list from
    ``./Datasets/<folder>/<folder>_<train_type>.list`` (one image filename
    per line) and resolves each name to
    ``<dataset_folder>/image/<name>`` and
    ``<dataset_folder>/label/<stem>_segmentation.npy``.
    """

    def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',
                 folder='folder0', train_type='train', with_name=False, transform=None):
        if train_type not in ('train', 'validation', 'test'):
            # Fail fast with a clear error; the original printed a message and
            # then crashed with AttributeError at the assert below.
            raise ValueError(
                "Choosing type error, You have to choose the loading data type "
                "including: train, validation, test")

        self.transform = transform
        self.train_type = train_type
        self.with_name = with_name
        self.folder_file = './Datasets/' + folder

        # e.g. ./Datasets/folder3/folder3_train.list — this is for cross validation.
        list_path = join(self.folder_file,
                         self.folder_file.split('/')[-1] + '_' + self.train_type + '.list')
        with open(list_path, 'r') as f:
            self.image_list = [item.replace('\n', '') for item in f.readlines()]
        self.folder = [join(dataset_folder, 'image', x) for x in self.image_list]
        self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy')
                     for x in self.image_list]

        # Image and mask lists are built from the same names, so they must match.
        assert len(self.folder) == len(self.mask)

    def __getitem__(self, item: int):
        image = np.load(self.folder[item])   # image array — assumes (H, W, 3); TODO confirm
        label = np.load(self.mask[item])     # mask array for the same sample
        # basename is portable, unlike the original split('/')[-1].
        name = os.path.basename(self.folder[item])

        sample = {'image': image, 'label': label}

        if self.transform is not None:
            # transform receives train_type so it can augment training data only
            sample = self.transform(sample, self.train_type)

        if self.with_name:
            return name, sample['image'], sample['label']
        return sample['image'], sample['label']

    def __len__(self):
        return len(self.folder)
        

主数据集类 ISIC2018_dataset

类定义与初始化 `__init__`

python 复制代码
class ISIC2018_dataset(Dataset):
    def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',
                 folder='folder0', train_type='train', with_name=False, transform=None):
        self.transform = transform
        self.train_type = train_type
        self.with_name = with_name
        self.folder_file = './Datasets/' + folder
| 参数 | 含义 |
| --- | --- |
| dataset_folder | 存放 .npy 图像和标签的根目录 |
| folder | 交叉验证划分文件夹名,如 'folder3' |
| train_type | 数据类型:'train' / 'validation' / 'test' |
| with_name | 是否返回图像文件名(测试时需要) |
| transform | 数据增强/预处理函数 |
python 复制代码
if self.train_type in ['train', 'validation', 'test']:
            # this is for cross validation
            with open(join(self.folder_file, self.folder_file.split('/')[-1] + '_' + self.train_type + '.list'),
                      'r') as f:
                self.image_list = f.readlines()
            self.image_list = [item.replace('\n', '') for item in self.image_list]
            self.folder = [join(dataset_folder, 'image', x) for x in self.image_list]
            self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy') for x in self.image_list]
            # self.folder = sorted([join(dataset_folder, self.train_type, 'image', x) for x in
            #                       listdir(join(dataset_folder, self.train_type, 'image'))])
            # self.mask = sorted([join(dataset_folder, self.train_type, 'label', x) for x in
            #                     listdir(join(dataset_folder, self.train_type, 'label'))])
        else:
            print("Choosing type error, You have to choose the loading data type including: train, validation, test")

        assert len(self.folder) == len(self.mask)

指向划分列表 list 所在目录,例如 ./Datasets/folder3/读取如 ./Datasets/folder3/folder3_train.list 文件,获取图像文件名列表。构建完整路径。

图像路径:/ISIC2018_Task1_npy_all/image/ISIC_xxx.npy

标签路径:/ISIC2018_Task1_npy_all/label/ISIC_xxx_segmentation.npy

assert len(self.folder) == len(self.mask)确保图像和标签数量一致,防止错位。

获取单个样本 `__getitem__`

python 复制代码
def __getitem__(self, item: int):
    image = np.load(self.folder[item])      # 加载图像 (H, W, 3)
    label = np.load(self.mask[item])        # 加载标签 (H, W),值为 0/1
    name = self.folder[item].split('/')[-1] # 如 'ISIC_0000000.npy'

    sample = {'image': image, 'label': label}

    if self.transform is not None:
        sample = self.transform(sample, self.train_type)

    if self.with_name:
        return name, sample['image'], sample['label']
    else:
        return sample['image'], sample['label']

返回格式由 with_name 控制:训练/验证(with_name=False) → (image, label);测试 → (name, image, label)(便于保存结果)。

transform 接收整个 sample 字典和 train_type,可实现"训练时增强,测试时不增强"。

在训练主体代码中:

python 复制代码
from Datasets.ISIC2018 import ISIC2018_dataset

Test_Dataset = {'ISIC2018': ISIC2018_dataset}

# 在 main() 中:
trainset = Test_Dataset[args.data](
    dataset_folder=args.root_path,      # '/xujiheng/ISICdemo/ISIC2018_npy_all_224_320'
    folder=args.val_folder,             # 如 'folder3'(五折交叉验证)
    train_type='train',
    with_name=False,
    transform=Test_Transform[args.transform]  # 如 ISIC2018_transform_newdata
)

validset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='validation',
                                       with_name=False, transform=Test_Transform[args.transform])

testset =  Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
                                       with_name=True, transform=Test_Transform[args.transform])

ISIC2018_dataset 是连接文件夹中.npy 数据与训练 pipeline 的桥梁,负责按交叉验证划分加载样本,并应用指定的图像/标签预处理,为模型训练/验证/测试提供标准化输入。

trainset是一个 PyTorch Dataset 对象,当调用 trainset[i] 时,它会返回一个由一张图像及其对应的标签组成的元组 (image_tensor, label_tensor)。

image_tensor: 形状为 (3, H, W) 的张量,表示输入图像(RGB三通道)。

label_tensor: 形状为 (1, H, W) 的张量,表示分割标签(通常是二值化的掩码)。

UNet-ISIC2018/utils/transform.py

python 复制代码
import torch
import random
import PIL
import numbers
import numpy as np
import torch.nn as nn
import collections
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw


# Map PIL interpolation constants to their printable names
# (not referenced by any code shown in this file).
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
}


def ISIC2018_transform(sample, train_type):
    """Preprocess a sample to 224x300: paired crop+flip/rotate augmentation
    for training, plain resize otherwise, then normalized tensors."""
    image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    label = Image.fromarray(np.uint8(sample['label']), mode='L')

    if train_type == 'train':
        # Identical geometric transforms are applied to image and mask.
        image, label = randomcrop(size=(224, 300))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    else:
        image, label = resize(size=(224, 300))(image, label)

    to_tensor = ts.ToTensor()
    normalize = ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    image = normalize(to_tensor(image))
    label = to_tensor(label)

    return {'image': image, 'label': label}


def ISIC2018_transform_320(sample, train_type):
    """Preprocess a sample to 224x320: paired crop+flip/rotate augmentation
    for training, plain resize otherwise, then normalized tensors."""
    image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    label = Image.fromarray(np.uint8(sample['label']), mode='L')

    if train_type == 'train':
        # Identical geometric transforms are applied to image and mask.
        image, label = randomcrop(size=(224, 320))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    else:
        image, label = resize(size=(224, 320))(image, label)

    to_tensor = ts.ToTensor()
    normalize = ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    image = normalize(to_tensor(image))
    label = to_tensor(label)

    return {'image': image, 'label': label}
    
    
def ISIC2018_transform_newdata(sample, train_type):
    """Preprocess a sample that is already at the target size (e.g. 224x320).

    No crop or resize is applied — the input resolution is kept. Only the
    training mode adds flip/rotate augmentation; evaluation leaves the
    sample geometry untouched. Image is normalized to [-1, 1]; mask stays
    in [0, 1].
    """
    image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    label = Image.fromarray(np.uint8(sample['label']), mode='L')

    # Augment only during training; other modes pass through unchanged.
    if train_type == 'train':
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)

    to_tensor = ts.ToTensor()
    normalize = ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    return {'image': normalize(to_tensor(image)), 'label': to_tensor(label)}
    

# These are functional helper utilities used by the transforms above.
def randomflip_rotate(img, lab, p=0.5, degrees=0):
    """Apply identical random flips and a random rotation to image and mask.

    The horizontal and vertical flips each fire independently with
    probability ``p``. ``degrees`` is either a single non-negative number
    (angle drawn from [-degrees, +degrees]) or a (min, max) pair.
    Image and mask always receive exactly the same transform so they stay
    pixel-aligned.
    """
    for flip in (TF.hflip, TF.vflip):
        if random.random() < p:
            img, lab = flip(img), flip(lab)

    if isinstance(degrees, numbers.Number):
        if degrees < 0:
            raise ValueError("If degrees is a single number, it must be positive.")
        lo, hi = -degrees, degrees
    elif len(degrees) != 2:
        raise ValueError("If degrees is a sequence, it must be of len 2.")
    else:
        lo, hi = degrees[0], degrees[1]

    angle = random.uniform(lo, hi)
    return TF.rotate(img, angle), TF.rotate(lab, angle)


class randomcrop(object):
    """Crop the given PIL Image and mask at a random location.

    Image and mask are cropped with the SAME (i, j, h, w) window so they
    stay pixel-aligned.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            # Exact fit: no randomness needed, the crop is the whole image.
            return 0, 0, h, w

        # NOTE(review): random.randint raises ValueError when the image is
        # smaller than the target size; callers must use pad_if_needed=True
        # to avoid that.
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, lab):
        """
        Args:
            img (PIL Image): Image to be cropped.
            lab (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image and mask.
        """
        if self.padding > 0:
            img = TF.pad(img, self.padding)
            lab = TF.pad(lab, self.padding)

        # pad the width if needed (the +1 rounds the per-side deficit up;
        # a 2-tuple pads left/right symmetrically)
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = TF.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
            lab = TF.pad(lab, (int((1 + self.size[1] - lab.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = TF.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
            lab = TF.pad(lab, (0, int((1 + self.size[0] - lab.size[1]) / 2)))

        i, j, h, w = self.get_params(img, self.size)

        return TF.crop(img, i, j, h, w), TF.crop(lab, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)


class resize(object):
    """Resize the input PIL Image and mask to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # Bug fix: ``collections.Iterable`` was removed in Python 3.10;
        # the ABC lives in collections.abc.
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, lab):
        """Return ``(resized_img, resized_lab)``.

        NOTE(review): the mask is resized with the same (bilinear by
        default) interpolation as the image, which yields non-binary edge
        values in the label; ``Image.NEAREST`` would keep it binary —
        confirm downstream code does not rely on soft edges before changing.
        """
        return TF.resize(img, self.size, self.interpolation), TF.resize(lab, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


def itensity_normalize(volume):
    """Normalize the intensity of an nd volume to zero mean / unit variance.

    The mean and std are computed over the WHOLE volume — the commented-out
    line below suggests the original intent may have been the nonzero region
    only (TODO confirm). Voxels that are exactly 0 in the input are replaced
    with N(0, 1) noise so the background does not form a constant block.

    inputs:
        volume: the input nd volume (numpy array)
    outputs:
        out: the normalized nd volume (float array, same shape)
    """

    # pixels = volume[volume > 0]
    mean = volume.mean()
    std = volume.std()
    if std == 0:
        # Constant volume: avoid division by zero (would produce inf/NaN).
        std = 1.0
    out = (volume - mean) / std
    out_random = np.random.normal(0, 1, size=volume.shape)
    out[volume == 0] = out_random[volume == 0]

    return out

三个主变换函数(核心接口)

ISIC2018_transform(sample, train_type)

python 复制代码
# (Excerpt of utils/transform.py, repeated here for discussion.)
def ISIC2018_transform(sample, train_type):
    # numpy HxWx3 image / HxW mask -> PIL images
    image, label = Image.fromarray(np.uint8(sample['image']), mode='RGB'),\
                   Image.fromarray(np.uint8(sample['label']), mode='L')

    if train_type == 'train':
        # training: random crop to (224, 300), then flip/rotate augmentation
        image, label = randomcrop(size=(224, 300))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    else:
        # validation/test: deterministic resize
        image, label = resize(size=(224, 300))(image, label)

    # image -> tensor normalized to [-1, 1]; mask -> tensor in [0, 1]
    image = ts.Compose([ts.ToTensor(),
                        ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image)
    label = ts.ToTensor()(label)

    return {'image': image, 'label': label}

输入:sample 是一个字典,包含 'image'(H×W×3 numpy array)和 'label'(H×W 单通道 mask)。

训练模式 (train_type == 'train'):先随机裁剪 到 (224, 300)再随机水平/垂直翻转 + 随机旋转 ±30°。

验证/测试模式:直接 resize 到 (224, 300)。

最后统一:图像:转为 Tensor 并归一化到 [-1, 1](因为 mean=std=0.5);标签:仅转为 Tensor(值域 [0, 1]);输出尺寸是 (224, 300),不是正方形。

ISIC2018_transform_320(sample, train_type)

所有尺寸改为 (224, 320)(更宽),其他逻辑完全一致。

ISIC2018_transform_newdata(sample, train_type)

数据已处理成(224,320)前提下,可以选择在训练时候使用这个。不进行 resize 或 crop!保留原始尺寸。仅在训练时做翻转+旋转增强。

python 复制代码
# (Excerpt of utils/transform.py, repeated here for discussion.)
def ISIC2018_transform_newdata(sample, train_type):
    image, label = Image.fromarray(np.uint8(sample['image']), mode='RGB'),\
                   Image.fromarray(np.uint8(sample['label']), mode='L')

    if train_type == 'train':
        # no crop/resize — the data is assumed to already be at target size
        # image, label = randomcrop(size=(224, 320))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    else:
        # evaluation: geometry left untouched (no-op branch)
        image = image
        label = label
        
    image = ts.Compose([ts.ToTensor(),
                        ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image)
    label = ts.ToTensor()(label)

    return {'image': image, 'label': label}

功能性增强函数randomflip_rotate(img, lab, p=0.5, degrees=0)

python 复制代码
# (Simplified excerpt for discussion; the real implementation above also
# validates ``degrees`` and calls TF.rotate without an explicit fill arg.)
def randomflip_rotate(img, lab, p=0.5, degrees=0):
    # horizontal flip with probability p
    if random.random() < p:
        img = TF.hflip(img); lab = TF.hflip(lab)
    # vertical flip with probability p
    if random.random() < p:
        img = TF.vflip(img); lab = TF.vflip(lab)

    # normalize the rotation range to (min, max)
    if isinstance(degrees, numbers.Number):
        degrees = (-degrees, degrees)
    angle = random.uniform(degrees[0], degrees[1])
    img = TF.rotate(img, angle, fill=0)      # corners padded with 0 (black)
    lab = TF.rotate(lab, angle, fill=0)      # mask also padded with 0 (background)

    return img, lab

在训练主脚本中:

python 复制代码
from utils.transform import ISIC2018_transform, ISIC2018_transform_320, ISIC2018_transform_newdata

Test_Transform = {
    'A': ISIC2018_transform,
    'B': ISIC2018_transform_320,
    'C': ISIC2018_transform_newdata
}

# 在 main() 中:
trainset = Test_Dataset[args.data](..., transform=Test_Transform[args.transform])

通过命令行参数 --transform C(默认)选择使用哪一种预处理策略。parser.add_argument('--transform', default='C', type=str, ...)。通过使用 ISIC2018_transform_newdata,在保留原始图像分辨率(实际为预处理后的统一尺寸 224×320)的前提下,仅对训练数据施加翻转和旋转增强,实现了有效防止过拟合的同时确保测试阶段分割结果的精细度。

UNet-ISIC2018/Models/compare_networks/unet.py

代码:

python 复制代码
import torch
import torch.nn as nn


class conv_block(nn.Module):
    """Two stacked 3x3 Conv-BN-ReLU layers (the standard U-Net double conv).

    Spatial size is preserved (3x3 kernel, padding 1); channels go
    ch_in -> ch_out on the first conv and stay at ch_out on the second.
    """

    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        layers = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class up_conv(nn.Module):
    """2x upsampling (nn.Upsample default: nearest) followed by 3x3 Conv-BN-ReLU.

    Doubles H and W; channels go ch_in -> ch_out.
    """

    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.up(x)

class unet(nn.Module):   # NOTE(review): plain U-Net — no spatial/channel attention exists below, despite the original comment claiming it
    """Standard U-Net.

    Encoder: five conv_blocks (channels -> 64 -> 128 -> 256 -> 512 -> 1024)
    separated by 2x2 max-pooling. Decoder: four up_conv stages, each
    followed by concatenation with the matching encoder feature map (skip
    connection) and a fusing conv_block. A final 1x1 conv maps 64 channels
    to ``classes`` logits at the input resolution, so input H and W must be
    divisible by 16.
    """
    def __init__(self,classes=2,channels=3):
        super(unet,self).__init__()
        
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)

        # Encoder double-conv blocks.
        self.Conv1 = conv_block(ch_in=channels,ch_out=64)
        self.Conv2 = conv_block(ch_in=64,ch_out=128)
        self.Conv3 = conv_block(ch_in=128,ch_out=256)
        self.Conv4 = conv_block(ch_in=256,ch_out=512)
        self.Conv5 = conv_block(ch_in=512,ch_out=1024)

        # Decoder: each Up_convN takes 2x channels because of the skip concat.
        self.Up5 = up_conv(ch_in=1024,ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)  

        self.Up4 = up_conv(ch_in=512,ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)  
        
        self.Up3 = up_conv(ch_in=256,ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 
        
        self.Up2 = up_conv(ch_in=128,ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)  

        # 1x1 projection to per-class logits (no softmax/sigmoid applied here).
        self.Conv_1x1 = nn.Conv2d(64,classes,kernel_size=1,stride=1,padding=0)


    def forward(self,x):
        # encoding path (each Maxpool halves H and W)
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path (skip connections restore spatial detail)
        d5 = self.Up5(x5)
        d5 = torch.cat((x4,d5),dim=1)
        
        d5 = self.Up_conv5(d5)
        
        d4 = self.Up4(d5)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)

        # per-class logits, same H and W as the input
        d1 = self.Conv_1x1(d2)

        return d1
参数 含义 你任务中的正确值
channels 输入图像的通道数 3(RGB 彩色图)
classes 模型最终输出的类别数(通道数) 1(二值分割)

注:

U-Net 能支持 (224, 320) 是因为它是全卷积网络,没有全连接层(Fully Connected Layer)。

例子:

输入 224×224 → 最后卷积输出 512×7×7 → 展平为 25088 维

输入 320×320 → 最后卷积输出 512×10×10 → 展平为 51200 维

但 Linear(25088, 1000) 无法接收 51200 维 → 崩溃!

同时该尺寸能被下采样因子 16 整除,确保编码器与解码器特征图尺寸对齐;若输入尺寸不能被 16 整除(如 225×321),则上采样后无法恢复原尺寸,导致跳跃连接时张量形状不匹配而报错。

在主训练脚本中:

python 复制代码
# ==============================
# 1. Import the model
# ==============================
from Models.compare_networks.unet import unet

Test_Model = {
    'UNet': unet,
    # other models...
}

# ==============================
# 2. Initialize the model
# ==============================
model = Test_Model[args.model](classes=1, channels=3).cuda()
model.train()

# ==============================
# 3. Optimizer & loss function
# ==============================
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.BCEWithLogitsLoss()  # or a custom Dice+BCE

# ==============================
# 4. Training loop (core)
# ==============================
for epoch in range(args.epochs):
    for step, (x, y) in enumerate(train_loader):
        # x: (B, 3, 224, 320), y: (B, 1, 224, 320)
        image = x.float().cuda()      # input image
        target = y.float().cuda()     # label (0/1)

        # forward pass
        pred = model(image)           # output: (B, 1, 224, 320), logits

        # compute the loss
        loss = criterion(pred, target)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# ==============================
# 5. Validation / test (inference)
# ==============================
model.eval()
with torch.no_grad():
    for name, x, y in test_loader:
        image = x.float().cuda()
        target = y.float().cuda()
        pred = model(image)                     # logits
        pred_sigmoid = torch.sigmoid(pred)      # probabilities in [0, 1]
        pred_binary = (pred_sigmoid > 0.5).float()  # binary mask
        # save or evaluate pred_binary

**(B, C, H, W):**B 表示 batch size(批大小),即 一次前向/反向传播中同时处理的样本数量。

阶段 操作 输入张量形状 (B, C, H, W) 输出张量形状 (B, C, H, W) 说明
Input - (B, 3, 224, 320) - RGB 图像,已归一化到 [-1, 1]
Encoder ↓
Conv1 conv_block(3→64) (B, 3, 224, 320) (B, 64, 224, 320) 双卷积(每层 Conv→BN→ReLU)
MaxPool1 MaxPool2d(2) (B, 64, 224, 320) (B, 64, 112, 160) 下采样
Conv2 conv_block(64→128) (B, 64, 112, 160) (B, 128, 112, 160)
MaxPool2 MaxPool2d(2) (B, 128, 112, 160) (B, 128, 56, 80)
Conv3 conv_block(128→256) (B, 128, 56, 80) (B, 256, 56, 80)
MaxPool3 MaxPool2d(2) (B, 256, 56, 80) (B, 256, 28, 40)
Conv4 conv_block(256→512) (B, 256, 28, 40) (B, 512, 28, 40)
MaxPool4 MaxPool2d(2) (B, 512, 28, 40) (B, 512, 14, 20)
Conv5 conv_block(512→1024) (B, 512, 14, 20) (B, 1024, 14, 20) 编码器最深层
Decoder ↑
Up5 up_conv(1024→512) (Upsample×2 + Conv) (B, 1024, 14, 20) (B, 512, 28, 40) 上采样
Concat x4 torch.cat([x4, d5], dim=1) x4: (B, 512, 28, 40) d5: (B, 512, 28, 40) (B, 1024, 28, 40) 跳跃连接
Up_conv5 conv_block(1024→512) (B, 1024, 28, 40) (B, 512, 28, 40) 融合特征
Up4 up_conv(512→256) (B, 512, 28, 40) (B, 256, 56, 80)
Concat x3 torch.cat([x3, d4], dim=1) x3: (B, 256, 56, 80) d4: (B, 256, 56, 80) (B, 512, 56, 80)
Up_conv4 conv_block(512→256) (B, 512, 56, 80) (B, 256, 56, 80)
Up3 up_conv(256→128) (B, 256, 56, 80) (B, 128, 112, 160)
Concat x2 torch.cat([x2, d3], dim=1) x2: (B, 128, 112, 160) d3: (B, 128, 112, 160) (B, 256, 112, 160)
Up_conv3 conv_block(256→128) (B, 256, 112, 160) (B, 128, 112, 160)
Up2 up_conv(128→64) (B, 128, 112, 160) (B, 64, 224, 320)
Concat x1 torch.cat([x1, d2], dim=1) x1: (B, 64, 224, 320) d2: (B, 64, 224, 320) (B, 128, 224, 320)
Up_conv2 conv_block(128→64) (B, 128, 224, 320) (B, 64, 224, 320)
Final Conv Conv2d(64→1, 1×1) (B, 64, 224, 320) (B, 1, 224, 320) 输出 logits

模型最终返回的是一个单通道的"灰度图"形式的分割结果。

UNet-ISIC2018/utils/dice_loss.py

python 复制代码
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class SoftDiceLoss(_Loss):
    '''
    Soft_Dice = 2*|dot(A, B)| / (|dot(A, A)| + |dot(B, B)| + eps)
    eps is a small constant to avoid zero division,

    Thin _Loss wrapper around soft_dice_loss. Returns the tuple
    (mean of -log(dice) over classes, -log(dice) of class 1).
    '''
    def __init__(self, *args, **kwargs):
        # *args/**kwargs are accepted for interface compatibility but ignored.
        super(SoftDiceLoss, self).__init__()

    def forward(self, prediction, soft_ground_truth, num_class=3, weight_map=None, eps=1e-8):
        # NOTE(review): ``eps`` is accepted but never forwarded —
        # soft_dice_loss uses its own hard-coded smoothing constants.
        dice_loss_ave, dice_score_lesion = soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map)
        return dice_loss_ave, dice_score_lesion


def get_soft_label(input_tensor, num_class):
    """One-hot encode an integer-valued label tensor.

    input_tensor: tensor of shape [N, C, H, W] with values in
    0..num_class-1 (C is 1 in practice).
    Returns a float tensor of shape [N, H, W, C*num_class] — i.e.
    [N, H, W, num_class] for a single-channel label.
    """
    channels_last = input_tensor.permute(0, 2, 3, 1)
    one_hot_planes = [
        torch.eq(channels_last, cls * torch.ones_like(channels_last))
        for cls in range(num_class)
    ]
    return torch.cat(one_hot_planes, dim=-1).float()


def soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=None):
    """Per-class soft Dice; returns (mean -log(dice), -log(dice) of class 1).

    prediction: [N, num_class, H, W] scores (softmax is NOT applied here).
    soft_ground_truth: one-hot ground truth reshapeable to [-1, num_class].
    weight_map: optional per-voxel weights covering the same voxels.
    """
    pred = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_class)
    ground = soft_ground_truth.view(-1, num_class)

    if weight_map is None:
        ref_vol = torch.sum(ground, 0)
        intersect = torch.sum(ground * pred, 0)
        seg_vol = torch.sum(pred, 0)
    else:
        w = weight_map.view(-1).repeat(num_class).view_as(pred)
        ref_vol = torch.sum(w * ground, 0)
        intersect = torch.sum(w * ground * pred, 0)
        seg_vol = torch.sum(w * pred, 0)

    # 1e-5 smooths 0/0; the extra +1.0 in the denominator damps tiny classes.
    dice_score = (2.0 * intersect + 1e-5) / (ref_vol + seg_vol + 1.0 + 1e-5)
    neg_log_dice = -torch.log(dice_score)
    # Index 1 is the lesion/foreground class; note this is its -log(dice)
    # (a loss term), not the raw dice score, despite the historical name.
    return torch.mean(neg_log_dice), neg_log_dice[1]

def IOU_loss(prediction, soft_ground_truth, num_class):
    """Mean -log(IoU) over classes.

    prediction: [N, num_class, H, W] scores; soft_ground_truth: one-hot
    ground truth reshapeable to [-1, num_class].

    Bug fix: the original computed ``predict = prediction.permute(...)`` but
    then flattened the un-permuted ``prediction``, so for any spatial size
    > 1 the per-class columns were scrambled (values along W landed in the
    class dimension). Flatten the permuted tensor instead, matching
    soft_dice_loss.
    """
    predict = prediction.permute(0, 2, 3, 1)
    pred = predict.contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    # +1.0 in the denominator guards against empty classes (0/0).
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    iou_loss = torch.mean(-torch.log(iou_score))

    return iou_loss

def jc_loss(prediction, soft_ground_truth, num_class):
    """Scaled -log(IoU) computed from the foreground (index 1) channel only.

    NOTE(review): ``predict[:, :, :, 1]`` keeps only channel 1 and is then
    re-viewed into ``num_class`` columns, so for num_class == 2 the
    foreground values are folded into half as many rows split across two
    columns. It is unclear whether this folding is intentional — confirm
    against the caller before reusing this function.
    """
    predict = prediction.permute(0, 2, 3, 1)
    pred = predict[:,:,:,1].contiguous().view(-1, num_class)
   # pred = prediction[:,:,:,1].view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth[:,:,:,1].view(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    #jc = 10*(1-iou_score)
    # The factor 20 appears to be an empirical loss weight — TODO confirm.
    jc = 20*torch.mean(-torch.log(iou_score))

    return jc


def val_dice_fetus(prediction, soft_ground_truth, num_class):
    """Per-class Dice for the fetal task; returns (placenta, brain) scores.

    Both inputs must be one-hot / score tensors reshapeable to
    [-1, num_class]; class 1 is placenta, class 2 is brain. The +1.0 in the
    denominator guards against empty classes.
    """
    pred = prediction.contiguous().view(-1, num_class)
    ground = soft_ground_truth.view(-1, num_class)

    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)

    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1.0)
    return dice_score[1], dice_score[2]


def Intersection_over_Union_fetus(prediction, soft_ground_truth, num_class):
    """Per-class IoU for the fetal task; returns (placenta, brain) IoU.

    Inputs are one-hot / score tensors reshapeable to [-1, num_class];
    class 1 is placenta, class 2 is brain. The +1.0 in the denominator
    guards against empty classes.
    """
    pred = prediction.contiguous().view(-1, num_class)
    ground = soft_ground_truth.view(-1, num_class)

    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)

    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    return iou_score[1], iou_score[2]


def val_dice_isic(prediction, soft_ground_truth, num_class):
    """Per-class Dice for ISIC; returns a tensor of ``num_class`` scores.

    Index 0 is background, index 1 the lesion. Inputs are one-hot tensors
    reshapeable to [-1, num_class]; the 1e-6 smoothing only prevents 0/0,
    so a perfect match scores ~1.0 per class.
    """
    pred = prediction.view(-1, num_class)
    ground = soft_ground_truth.view(-1, num_class)

    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)

    return 2.0 * intersect / (ref_vol + seg_vol + 1e-6)


def val_dice_isic_raw(prediction, soft_ground_truth, num_class):
    """Mean Dice over all ``num_class`` classes (1e-6-smoothed)."""
    pred = prediction.contiguous().view(-1, num_class)
    ground = soft_ground_truth.view(-1, num_class)

    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)

    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1e-6)
    return torch.mean(dice_score)


def Intersection_over_Union_isic(prediction, soft_ground_truth, num_class):
    """Mean IoU over all ``num_class`` classes.

    Inputs are one-hot tensors reshapeable to [-1, num_class]. The +1.0 in
    the denominator guards against empty classes but also biases scores
    slightly below the exact IoU.
    """
    pred = prediction.contiguous().view(-1, num_class)
    ground = soft_ground_truth.view(-1, num_class)

    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)

    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    return torch.mean(iou_score)

SoftDiceLoss 类(主损失类)

python 复制代码
# (Excerpt of utils/dice_loss.py, repeated here for discussion.)
class SoftDiceLoss(_Loss):
    '''
    Soft_Dice = 2*|dot(A, B)| / (|dot(A, A)| + |dot(B, B)| + eps)
    eps is a small constant to avoid zero division,
    '''
    def __init__(self, *args, **kwargs):
        super(SoftDiceLoss, self).__init__()

    def forward(self, prediction, soft_ground_truth, num_class=3, weight_map=None, eps=1e-8):
        # NOTE(review): ``eps`` is accepted but never forwarded to soft_dice_loss.
        dice_loss_ave, dice_score_lesion = soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map)
        return dice_loss_ave, dice_score_lesion

继承 PyTorch 的 _Loss,可直接用于训练。注:默认 num_class=3 是为多类任务设计的,但 ISIC2018 是二值任务,需传 num_class=2

get_soft_label:硬标签 → 软标签(one-hot)

python 复制代码
# (Excerpt of utils/dice_loss.py, repeated here for discussion.)
def get_soft_label(input_tensor, num_class):
    """
        convert a label tensor to soft label
        input_tensor: tensor with shape [N, C, H, W]
        output_tensor: shape [N, H, W, num_class]
    """
    tensor_list = []
    input_tensor = input_tensor.permute(0, 2, 3, 1)
    for i in range(num_class):
        # boolean plane marking the voxels belonging to class i
        temp_prob = torch.eq(input_tensor, i * torch.ones_like(input_tensor))
        tensor_list.append(temp_prob)
    output_tensor = torch.cat(tensor_list, dim=-1)
    output_tensor = output_tensor.float()
    return output_tensor

将整数标签(如 0/1)转为 one-hot 编码。但在 ISIC2018 中通常不用,因为标签已是 float [0,1] 单通道,且模型输出也是单通道 logits。

soft_dice_loss:核心 Dice 损失计算

python 复制代码
# (Excerpt of utils/dice_loss.py, repeated here for discussion.)
def soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=None):
    # flatten to [voxels, num_class] with classes in the last dimension
    predict = prediction.permute(0, 2, 3, 1)
    pred = predict.contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    n_voxels = ground.size(0)  # NOTE(review): unused
    if weight_map is not None:
        weight_map = weight_map.view(-1)
        weight_map_nclass = weight_map.repeat(num_class).view_as(pred)
        ref_vol = torch.sum(weight_map_nclass * ground, 0)
        intersect = torch.sum(weight_map_nclass * ground * pred, 0)
        seg_vol = torch.sum(weight_map_nclass * pred, 0)
    else:
        ref_vol = torch.sum(ground, 0)
        intersect = torch.sum(ground * pred, 0)
        seg_vol = torch.sum(pred, 0)
    dice_score = (2.0 * intersect + 1e-5) / (ref_vol + seg_vol + 1.0 + 1e-5)
    # dice_loss = 1.0 - torch.mean(dice_score)
    # return dice_loss
    dice_loss = -torch.log(dice_score)
    dice_loss_ave = torch.mean(dice_loss)
    # NOTE(review): despite the name, this is the -log(dice) LOSS of class 1.
    dice_score_lesion = dice_loss[1]
    return dice_loss_ave, dice_score_lesion

在主训练代码中:

python 复制代码
from utils.dice_loss import get_soft_label, val_dice_isic, SoftDiceLoss
from utils.dice_loss import Intersection_over_Union_isic
函数/类 用途 调用阶段
get_soft_label 将硬标签(0/1)转为 one-hot 软标签(如 0→[1,0]、1→[0,1]) 训练 + 测试
val_dice_isic 计算 ISIC2018 的 Dice 系数(用于测试评估) 测试阶段
SoftDiceLoss 自定义的 Dice 损失函数 训练阶段(理论上)
Intersection_over_Union_isic 计算 IoU(Jaccard Index) 测试阶段

get_soft_label:

python 复制代码
# train() 函数内
target = y.float().cuda()                                   # [B, 1, H, W]
target_soft_a = get_soft_label(target, args.num_classes)    # → [B, H, W, 2]
target_soft = target_soft_a.permute(0, 3, 1, 2)             # → [B, 2, H, W]

因为模型输出是 2 通道(classes=2),所以将单通道标签转为 one-hot 双通道,以匹配损失函数输入。

python 复制代码
# test_isic() 函数内
output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)   # [B,1,H,W], 整数预测 (0/1)
target = lab.float().cuda()                             # [B,1,H,W], float 标签 (0.0/1.0)

output_soft = get_soft_label(output_dis, 2)             # → [B, H, W, 2]
target_soft = get_soft_label(target, 2)                 # → [B, H, W, 2]

为 val_dice_isic 提供符合其接口的 one-hot 输入。

val_dice_isic

python 复制代码
# test_isic() 函数内
isic_b_dice = val_dice_isic(output_soft, target_soft, 2)  # num_class=2
dice_np = isic_b_dice.data.cpu().numpy()                  # shape: (2,)
isic_dice.append(dice_np)                                

dice_np[0]:背景 Dice

dice_np[1]:病灶(前景)Dice ← 这才是你要的指标

Intersection_over_Union_isic

python 复制代码
# test_isic() 函数内
output_dis_test = output_dis.permute(0, 2, 3, 1).float()  # [B, H, W, 1]
target_test = target.permute(0, 2, 3, 1).float()          # [B, H, W, 1]

isic_b_iou = Intersection_over_Union_isic(output_dis_test, target_test, 1)
iou_np = isic_b_iou.data.cpu().numpy()
isic_iou.append(iou_np)

其他的指标库

函数 全称 用途 输入格式 输出
dc Dice Coefficient 衡量预测与真值重叠度(最常用) (H,W)(B,H,W) 的 0/1 array float ∈ [0,1]
jc Jaccard Index (IoU) 交并比 同上 float ∈ [0,1]
precision Precision (Positive Predictive Value) 预测为病灶的像素中有多少是真的 同上 float
sensitivity Sensitivity (Recall / True Positive Rate) 真实病灶中有多少被找出来 同上 float
specificity Specificity (True Negative Rate) 真实背景中有多少被正确识别 同上 float
F1 F1-Score Precision 和 Recall 的调和平均 同上 float
ACC Accuracy 整体像素分类准确率 同上 float
assd Average Symmetric Surface Distance 边界距离误差(更精细的几何指标) 同上 float (单位:像素)

最后再来逐句的理解一下训练代码:

导入语句不多解释,先看全局配置

python 复制代码
# Registries mapping CLI option values to implementations (excerpt from train.py).
Test_Model = {
              "unet":unet,
              "AttUNet":AttUNet
             }


Test_Dataset = {'ISIC2018': ISIC2018_dataset}

Test_Transform = {'A': ISIC2018_transform, 'B':ISIC2018_transform_320, "C":ISIC2018_transform_newdata}

# Selects which loss branch the train() function executes.
criterion = "loss_D"  # loss_A-->SoftDiceLoss;  loss_B-->softdice;  loss_C--> CE + softdice;   loss_D--> BCEWithLogitsLoss

日志重定向,将 print() 同时输出到控制台和日志文件。

python 复制代码
class Logger(object):
    """Tee object: duplicates everything written to stdout into a log file."""
    def __init__(self,logfile):
        self.terminal = sys.stdout   # original stdout, kept for echoing
        self.log = open(logfile, "a")  # append mode so reruns keep history

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # No-op so the object satisfies the stream interface; note the log
        # file buffer is therefore only flushed when the process exits.
        pass

训练函数

python 复制代码
def train(train_loader, model, criterion, scheduler, optimizer, args, epoch):
    

初始化,每个 epoch 重置 loss 记录器,设为训练模式。

复制代码
losses = AverageMeter()
model.train()

数据加载,图像转 float,标签也转 float(因为 BCE 要求 float)。

复制代码
for step, (x, y) in tqdm(enumerate(train_loader)):
    image = x.float().cuda()      # [B, 3, H, W]
    target = y.float().cuda()     # [B, 1, H, W], 值为 0.0/1.0

前向传播,ISIC2018 是二值分割,理想输出应为 [B,1,H,W],但这里用了 2 通道(背景+前景),导致后续必须用 one-hot 标签。

复制代码
output = model(image)  # [B, 2, H, W] (因为 classes=2)

构造 one-hot 标签,为了匹配 2 通道输出,把 [B,1,H,W] 标签转成 one-hot。

复制代码
target_soft_a = get_soft_label(target, args.num_classes)  # [B, H, W, 2]
target_soft = target_soft_a.permute(0, 3, 1, 2)           # [B, 2, H, W]

损失函数选择(执行 loss_D)

复制代码
if criterion == "loss_D":
    dice_loss = nn.BCEWithLogitsLoss()
    loss = dice_loss(output, target_soft)  # 注意:target_soft 是 [B,2,H,W]

反向传播 & 优化

复制代码
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.update(loss.data, image.size(0))

日志打印,每个 epoch 打印一次(注意 step % total_steps == 0 实际在 step=0、即每个 epoch 的第一个 batch 时成立;若想在最后一个 batch 打印,应改用 step == len(train_loader) - 1)。

复制代码
if step % (math.ceil(len(dataset)/batch_size)) == 0:
    print('Epoch: {} Loss: {:.6f}'.format(...))

验证函数 valid_isic()

推理模式

复制代码
model.eval()
# 注意:model.eval() 只切换 BN/Dropout 的行为,并不会关闭梯度;若要节省显存/避免建图,需显式包裹 with torch.no_grad():

预测 & 二值化,用 argmax 得到预测类别。

复制代码
output = model(image)                          # [B,2,H,W]
output_dis = torch.max(output, 1)[1].unsqueeze(1)  # [B,1,H,W], int64 (0 or 1)

转换为 NumPy 用于评估

复制代码
output_dis_test = output_dis.permute(0,2,3,1).float()  # [B,H,W,1]
target_test = target.permute(0,2,3,1).float()          # [B,H,W,1]

计算指标,使用 utils.binary 中的函数,输入是 .cpu().numpy() 的二值 mask。

复制代码
isic_b_Jaccard = jc(pred, target)  # Jaccard = IoU
isic_b_dc = dc(pred, target)       # Dice

保存最佳模型

复制代码
net_score = dice + iou
if net_score > best: save model

测试函数 test_isic()

加载最佳模型,逐样本推理 & 评估。

复制代码
checkpoint = torch.load('best_score_...pth.tar')
model.load_state_dict(checkpoint['state_dict'])

对每个样本:

推理得到 output ([B,2,H,W])

torch.max → output_dis ([B,1,H,W])

转 one-hot:get_soft_label(..., 2) → 用于 val_dice_isic

转二值 mask:.permute → [B,H,W,1] → 用于 dc, jc, ACC 等

多种指标计算

复制代码
# 方法1:用 val_dice_isic(返回 [bg_dice, fg_dice])
isic_b_dice = val_dice_isic(output_soft, target_soft, 2)

# 方法2:用 binary 模块(直接传二值 mask)
isic_b_dc = dc(output_dis_test.numpy(), target_test.numpy())
isic_b_acc = ACC(...)
...

特殊指标:Jaccard_M / Jaccard_N,分别计算病灶和背景的 IoU。

复制代码
isic_b_Jaccard_m = jc(output_arr[:, :, 1], label_arr[:, :, 1])  # 前景 IoU
isic_b_Jaccard_n = jc(output_arr[:, :, 0], label_arr[:, :, 0])  # 背景 IoU

保存可视化结果,保存原图、真值、预测三联图。

复制代码
save_imgs(img, msk, msk_pred, ...)

主函数 main()

数据集加载,五折交叉验证支持(通过 --val_folder folder3 指定哪一折做验证)。

复制代码
trainset = ISIC2018_dataset(..., train_type='train', transform=Test_Transform[args.transform])
validset = ... 'validation'
testset = ... 'test', with_name=True

模型初始化

复制代码
if args.id == "AttUNet":
    model = AttUNet(in_channel=3, out_channel=2)
else:
    model = unet(classes=2, channels=3)

优化器 & 学习率调度,使用余弦退火带重启,适合小数据集。

复制代码
optimizer = Adam(..., weight_decay=1e-8)
scheduler = CosineAnnealingWarmRestarts(T_0=10, T_mult=2, eta_min=1e-5)

断点续训

复制代码
if args.resume: load checkpoint

训练循环

复制代码
for epoch in 1..epochs:
    train()
    valid() → save best if improved
    if epoch > 30 and epoch % 300 == 0: save checkpoint

测试,训练结束后自动测试最佳模型。

复制代码
test_isic(testloader, ...)

命令行参数(if __name__ == '__main__'

python 复制代码
# Command-line interface for the training script.
# NOTE(review): the original paste mixed a column-0 first line with
# 4-space-indented follow-up lines (a fragment lifted out of
# `if __name__ == '__main__':`); reformatted flat so it parses.
# Several help strings also disagreed with the real defaults — fixed.
parser = argparse.ArgumentParser(description='Comprehensive attention network for biomedical Dataset')

parser.add_argument('--id', default="unet",
                    help='model architecture to load: unet or AttUNet')   # Select a loaded model name

# Path related arguments
parser.add_argument('--root_path', default='/xujiheng/ISICdemo/ISIC2018_npy_all_224_320',
                    help='root directory of data')                        # The folder where the numpy data set is stored
parser.add_argument('--ckpt', default='/xujiheng/ISICdemo/UNet-ISIC2018/saved_models/',
                    help='folder to output checkpoints')                  # The folder in which the trained model is saved
parser.add_argument('--transform', default='C', type=str,
                    help='which ISIC2018_transform to choose')
parser.add_argument('--data', default='ISIC2018', help='choose the dataset')
parser.add_argument('--out_size', default=(224, 320), help='the output image size')
parser.add_argument('--val_folder', default='folder3', type=str,
                    help='folder1、folder2、folder3、folder4、folder5')     # five-fold cross-validation

# optimization related arguments
parser.add_argument('--epochs', type=int, default=15, metavar='N',
                    help='number of epochs to train (default: 15)')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='epoch to start training. useful if continue from a checkpoint')
parser.add_argument('--batch_size', type=int, default=10, metavar='N',
                    help='input batch size for training (default: 10)')   # batch_size
parser.add_argument('--lr_rate', type=float, default=1e-4, metavar='LR',
                    help='learning rate (default: 1e-4)')
parser.add_argument('--num_classes', default=2, type=int,
                    help='number of classes')
parser.add_argument('--num_input', default=3, type=int,
                    help='number of input image for each patient')
parser.add_argument('--weight_decay', default=1e-8, type=float, help='weights regularizer')
parser.add_argument('--particular_epoch', default=30, type=int,
                    help='after this number, we will save models more frequently')
parser.add_argument('--save_epochs_steps', default=300, type=int,
                    help='frequency to save models after a particular number of epochs')
parser.add_argument('--resume', default='',
                    help='the checkpoint that resumes from')

args = parser.parse_args()

--id: 模型名(unet / AttUNet)

--root_path: 数据路径(npy 格式)

--val_folder: 五折中的哪一折(folder1~5)

--transform: 预处理方式(A/B/C)

--epochs, --batch_size, --lr_rate: 训练超参

--resume: 断点路径

总结

整体训练流程总览(按执行顺序)

阶段 调用位置 / 模块 实现功能 关键作用说明
1. 初始化与参数解析 if __name__ == '__main__' + argparse 解析命令行参数,设置 GPU、路径、模型、数据等 控制实验配置(如模型名、数据路径、五折验证 folder、损失类型等)
2. 日志系统设置 Logger 类 + sys.stdout = Logger(...) 将所有 print() 输出同时写入控制台和日志文件 自动记录训练过程,便于复现实验
3. 数据集加载 main() → ISIC2018_dataset(...) 加载 train / valid / test 三部分数据 支持五折交叉验证(通过 --val_folder 指定)
4. 数据预处理 Test_Transform[args.transform] 应用图像增强(如归一化、裁剪、翻转等) 提升泛化能力;transform='C' 表示使用 ISIC2018_transform_newdata
5. 模型构建 main() → Test_Model[args.id] → AttUNet(...) 实例化 UNet 或 Attention UNet 输出通道数为 2(背景+前景),用于多类 Dice/BCE 计算
6. 优化器 & 学习率调度 torch.optim.Adam + CosineAnnealingWarmRestarts 设置优化器和动态学习率 使用余弦退火带重启,适合小数据集微调
7. 断点续训(可选) if args.resume: torch.load(...) 从 checkpoint 恢复模型和优化器状态 支持中断后继续训练
8. 训练循环(每 epoch) for epoch in range(...): 执行训练 → 验证 → 保存模型 主训练流程
9. 训练步骤(train 函数) train(train_loader, ...) 前向传播 + 损失计算 + 反向传播 使用 criterion="loss_D" → BCEWithLogitsLoss
10. 标签转 one-hot get_soft_label(target, 2) 将 [B,1,H,W] 标签转为 [B,2,H,W] one-hot 匹配 2 通道模型输出,供 BCE/Dice 使用
11. 损失计算 nn.BCEWithLogitsLoss()(output, target_soft) 计算二值交叉熵损失 当前实际使用的损失函数
12. 损失跟踪 AverageMeter() 动态计算 batch loss 的平均值 用于日志打印和监控收敛
13. 验证(valid_isic) valid_isic(valid_loader, ...) 在验证集上评估 Dice 和 IoU 决定是否保存"最佳模型"
14. 预测二值化 torch.max(output, 1)[1].unsqueeze(1) 将 logits 转为 0/1 预测 mask 用于指标计算
15. 评估指标计算 dc(), jc() from utils.binary 计算 Dice 系数和 Jaccard (IoU) 核心分割性能指标
16. 最佳模型保存 if net_score > best: torch.save(...) 保存验证集上 Dice+IoU 最高的模型 文件名:best_score_ISIC2018_checkpoint.pth.tar
17. 定期模型保存 if epoch % save_epochs_steps == 0 每 N 轮保存一次 checkpoint 用于后续分析或断点恢复
18. 测试(test_isic) test_isic(test_loader, ...) 加载最佳模型,在测试集全面评估 报告 10+ 种指标
19. 多指标评估 ACC, sensitivity, specificity, F1, precision, dc, jc 计算全面的二值分割性能 包括病灶(M)和背景(N)分别的 IoU
20. 推理时间统计 time() before/after model(image) 记录单样本推理耗时 评估模型效率
21. 可视化结果保存 save_imgs(...) 保存原图、真值、预测三联图 用于人工检查分割质量
22. 参数量统计 for param in model.parameters(): num_para += ... 计算模型可训练参数总数 用于模型复杂度分析
相关推荐
九.九6 小时前
ops-transformer:AI 处理器上的高性能 Transformer 算子库
人工智能·深度学习·transformer
春日见6 小时前
拉取与合并:如何让个人分支既包含你昨天的修改,也包含 develop 最新更新
大数据·人工智能·深度学习·elasticsearch·搜索引擎
偷吃的耗子6 小时前
【CNN算法理解】:三、AlexNet 训练模块(附代码)
深度学习·算法·cnn
Faker66363aaa8 小时前
【深度学习】YOLO11-BiFPN多肉植物检测分类模型,从0到1实现植物识别系统,附完整代码与教程_1
人工智能·深度学习·分类
大江东去浪淘尽千古风流人物10 小时前
【SLAM】Hydra-Foundations 层次化空间感知:机器人如何像人类一样理解3D环境
深度学习·算法·3d·机器人·概率论·slam
小刘的大模型笔记10 小时前
大模型微调参数设置 —— 从入门到精通的调参指南
人工智能·深度学习·机器学习
LaughingZhu11 小时前
Product Hunt 每日热榜 | 2026-02-10
人工智能·经验分享·深度学习·神经网络·产品运营
千里马也想飞11 小时前
公共管理新题解:信息化条件下文化治理类论文,如何用AI把“大空题目”做成“落地案例库”?(附三级提纲+指令包)
人工智能·深度学习·机器学习·论文笔记
软件算法开发11 小时前
基于鲸鱼优化的LSTM深度学习网络模型(WOA-LSTM)的一维时间序列预测算法matlab仿真
深度学习·lstm·鲸鱼优化·一维时间序列预测·woa-lstm
技术传感器12 小时前
大模型从0到精通:对齐之心 —— 人类如何教会AI“好“与“坏“ | RLHF深度解析
人工智能·深度学习·神经网络·架构