ISIC2018
ISIC 2018 (International Skin Imaging Collaboration 2018) 数据集主要包含大量经过专业皮肤科医生标注的皮肤镜(dermoscopy)图像,涵盖了多种皮肤病变类型。在分割任务(Task 1)中,提供了约 2594 张图像及其对应的像素级标签,用于训练和验证算法对皮肤病变区域进行精确分割的性能。
数据集通常以原始图像(如 PNG 格式)和对应的二值分割掩码(标记病变区域)形式提供。但在实际研究中,为了方便处理或满足特定框架的需求,常常需要转换成 .npy 或 .nii.gz 格式。
npy: NumPy 的原生二进制数组格式,适合高效存储和直接加载到 Python 深度学习框架中。
nii.gz: 常用于医学影像的 NIfTI 格式压缩文件,能保存图像数据及其空间信息,适用于兼容医学影像处理工具链的场景。
这是其中一个数据的图像以及标注:

代码讲解
train.py
python
import os
import torch
import math
import visdom
import torch.utils.data as Data
import argparse
import numpy as np
import sys
from tqdm import tqdm
import torch.nn as nn
from distutils.version import LooseVersion
from Datasets.ISIC2018 import ISIC2018_dataset
from utils.transform import ISIC2018_transform, ISIC2018_transform_320, ISIC2018_transform_newdata
from Models.compare_networks.unet import unet
from Models.compare_networks.AttUnet import AttUNet
from utils.dice_loss import get_soft_label, val_dice_isic, SoftDiceLoss
from utils.dice_loss import Intersection_over_Union_isic
from utils.dice_loss_github import SoftDiceLoss_git, CrossentropyND
from utils.evaluation import AverageMeter
from utils.binary import assd, dc, jc, precision, sensitivity, specificity, F1, ACC
from torch.optim import lr_scheduler
from time import *
from matplotlib import pyplot as plt
# Registry mapping the --id CLI string to a model constructor.
Test_Model = {
    "unet": unet,
    "AttUNet": AttUNet
}
# Registry mapping the --data CLI string to a Dataset class.
Test_Dataset = {'ISIC2018': ISIC2018_dataset}
# Registry of preprocessing/augmentation pipelines selectable via --transform.
Test_Transform = {'A': ISIC2018_transform, 'B': ISIC2018_transform_320, "C": ISIC2018_transform_newdata}
# Global loss selector consumed by train() and embedded in checkpoint/log paths.
criterion = "loss_D"  # loss_A-->SoftDiceLoss; loss_B-->softdice; loss_C--> CE + softdice; loss_D--> BCEWithLogitsLoss
class Logger(object):
    """Tee everything written to stdout into a persistent log file as well."""

    def __init__(self, logfile):
        self.terminal = sys.stdout
        # Append mode so repeated runs accumulate in the same log file.
        self.log = open(logfile, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Propagate flush to both streams. The original no-op `pass` left
        # log output buffered, so it could be lost on a crash.
        self.terminal.flush()
        self.log.flush()
def train(train_loader, model, criterion, scheduler, optimizer, args, epoch):
    """Run one training epoch and return the average loss.

    Args:
        train_loader: DataLoader yielding (image, mask) batches.
        model: segmentation network producing logits of shape [B, C, H, W].
        criterion: string selecting the loss variant ("loss_A".."loss_D").
        scheduler: LR scheduler (not used inside; kept for interface compatibility).
        optimizer: optimizer updating `model`'s parameters.
        args: parsed CLI namespace (uses num_classes and batch_size).
        epoch: current epoch index, for logging only.

    Returns:
        The running-average training loss over the epoch.
    """
    losses = AverageMeter()
    model.train()
    # Loss modules hold no per-batch state, so construct them once per epoch
    # instead of once per step as the original did.
    ca_soft_dice_loss = SoftDiceLoss()
    soft_dice_loss = SoftDiceLoss_git(batch_dice=True, dc_log=False)
    soft_dice_loss2 = SoftDiceLoss_git(batch_dice=False, dc_log=False)
    soft_dice_loss3 = SoftDiceLoss_git(batch_dice=True, dc_log=True)  # NOTE: not used by any branch below
    CE_loss_F = CrossentropyND()
    bce_logits_loss = nn.BCEWithLogitsLoss()
    for step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        image = x.float().cuda()
        target = y.float().cuda()
        output = model(image)
        # One-hot encode the target: [B, H, W, C], then a channel-first view.
        target_soft_a = get_soft_label(target, args.num_classes)
        target_soft = target_soft_a.permute(0, 3, 1, 2)
        if criterion == "loss_A":
            loss_ave, loss_lesion = ca_soft_dice_loss(output, target_soft_a, args.num_classes)
            loss = loss_ave
        elif criterion == "loss_B":
            loss = soft_dice_loss(output, target_soft)
        elif criterion == "loss_C":
            loss = soft_dice_loss2(output, target_soft) + CE_loss_F(output, target)
        elif criterion == "loss_D":
            loss = bce_logits_loss(output, target_soft)
        else:
            # The original silently fell through and crashed later with a
            # NameError; fail fast with a clear message instead.
            raise ValueError("Unknown criterion: {}".format(criterion))
        losses.update(loss.data, image.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The modulus equals the number of batches, so this only fires at
        # step 0: one progress line per epoch.
        if step % (math.ceil(float(len(train_loader.dataset)) / args.batch_size)) == 0:
            print('current lr: {} | Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                optimizer.state_dict()['param_groups'][0]['lr'],
                epoch, step * len(image), len(train_loader.dataset),
                100. * step / len(train_loader), losses=losses))
    print('The average loss:{losses.avg:.4f}'.format(losses=losses))
    return losses.avg
def valid_isic(valid_loader, model, criterion, optimizer, args, epoch, best_score, val_acc_log):
    """Validate on ISIC data; checkpoint the model when Dice + Jaccard improves.

    Args:
        valid_loader: DataLoader yielding (image, mask) pairs (batch size 1).
        model: trained segmentation network.
        criterion: unused here; kept for interface compatibility with train().
        optimizer: saved into the checkpoint alongside the model state.
        args: parsed CLI namespace (uses ckpt and data).
        epoch: current epoch index, written to the log and checkpoint.
        best_score: mutable list of historical net scores; max() is best so far.
        val_acc_log: path of the per-epoch validation metric log.

    Returns:
        (mean Jaccard, mean Dice, net score = Jaccard + Dice).
    """
    isic_Jaccard = []
    isic_dc = []
    model.eval()
    # Inference only: no_grad avoids building autograd graphs. The original
    # kept gradient bookkeeping alive during validation, wasting GPU memory.
    with torch.no_grad():
        for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            image = t.float().cuda()
            target = k.float().cuda()
            output = model(image)  # model output
            # Argmax over channels -> hard label map of shape [B, 1, H, W].
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_dis_test = output_dis.permute(0, 2, 3, 1).float()
            target_test = target.permute(0, 2, 3, 1).float()
            isic_b_Jaccard = jc(output_dis_test.cpu().numpy(), target_test.cpu().numpy())
            isic_b_dc = dc(output_dis_test.cpu().numpy(), target_test.cpu().numpy())
            isic_Jaccard.append(isic_b_Jaccard)
            isic_dc.append(isic_b_dc)
    isic_Jaccard_mean = np.average(isic_Jaccard)
    isic_dc_mean = np.average(isic_dc)
    net_score = isic_Jaccard_mean + isic_dc_mean
    print('The ISIC Dice score: {dice: .4f}; '
          'The ISIC JC score: {jc: .4f}'.format(
              dice=isic_dc_mean, jc=isic_Jaccard_mean))
    with open(val_acc_log, 'a') as vlog_file:
        line = "{} | {dice: .4f} | {jc: .4f}".format(epoch, dice=isic_dc_mean, jc=isic_Jaccard_mean)
        vlog_file.write(line + '\n')
    # Save a "best" checkpoint whenever the combined score beats the history.
    if net_score > max(best_score):
        best_score.append(net_score)
        print(best_score)
        modelname = args.ckpt + '/' + 'best_score' + '_' + args.data + '_checkpoint.pth.tar'
        print('the best model will be saved at {}'.format(modelname))
        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
        torch.save(state, modelname)
    return isic_Jaccard_mean, isic_dc_mean, net_score
def test_isic(test_loader, model, num_para, args, test_acc_log):
    """Load the best checkpoint and report segmentation metrics on the test set.

    Computes per-sample Dice/IoU/accuracy/sensitivity/specificity/precision/F1
    and per-class Jaccard, then prints mean and std of each, plus total
    inference time and the model's parameter count.
    """
    # Per-sample metric accumulators.
    isic_dice = []
    isic_iou = []
    # isic_assd = []
    isic_acc = []
    isic_sensitive = []
    isic_specificy = []
    isic_precision = []
    isic_f1_score = []
    isic_Jaccard_M = []
    isic_Jaccard_N = []
    isic_Jaccard = []
    isic_dc = []
    infer_time = []
    # Restore the best-scoring checkpoint written by valid_isic().
    modelname = args.ckpt + '/' + 'best_score' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        checkpoint = torch.load(modelname)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> No checkpoint found at '{}'".format(modelname))
    model.eval()
    for step, (name, img, lab) in tqdm(enumerate(test_loader), total=len(test_loader)):
        image = img.float().cuda()
        target = lab.float().cuda()  # [batch, 1, 224, 320]
        # Time the forward pass only.
        begin_time = time()
        output = model(image)
        end_time = time()
        pred_time = end_time - begin_time
        infer_time.append(pred_time)
        # Hard prediction: argmax over channels -> [B, 1, H, W].
        output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
        output_dis_test = output_dis.permute(0, 2, 3, 1).float()
        target_test = target.permute(0, 2, 3, 1).float()
        # One-hot versions for Dice / per-class Jaccard.
        output_soft = get_soft_label(output_dis, 2)
        target_soft = get_soft_label(target, 2)
        msk = target_test.squeeze(0).cpu().detach().numpy()
        out = output_dis_test.squeeze(0).cpu().detach().numpy()
        # Qualitative figure per sample.
        # NOTE(review): hard-coded absolute output path — adjust per machine.
        save_imgs(img, msk, out, step, '/xujiheng/ISICdemo/UNet-ISIC2018/outputflod1/fold1/')
        label_arr = np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)
        output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)
        isic_b_dice = val_dice_isic(output_soft, target_soft, 2)  # the dice
        isic_b_iou = Intersection_over_Union_isic(output_dis_test, target_test, 1)  # the iou
        # isic_b_asd = assd(output_arr[:, :, 1], label_arr[:, :, 1])  # the assd
        isic_b_acc = ACC(output_dis_test.cpu().numpy(), target_test.cpu().numpy())  # the accuracy
        isic_b_sensitive = sensitivity(output_dis_test.cpu().numpy(), target_test.cpu().numpy())  # the sensitivity
        isic_b_specificy = specificity(output_dis_test.cpu().numpy(), target_test.cpu().numpy())  # the specificity
        isic_b_precision = precision(output_dis_test.cpu().numpy(), target_test.cpu().numpy())  # the precision
        isic_b_f1_score = F1(output_dis_test.cpu().numpy(), target_test.cpu().numpy())  # the F1
        isic_b_Jaccard_m = jc(output_arr[:, :, 1], label_arr[:, :, 1])  # the Jaccard melanoma
        isic_b_Jaccard_n = jc(output_arr[:, :, 0], label_arr[:, :, 0])  # the Jaccard no-melanoma
        isic_b_Jaccard = jc(output_dis_test.cpu().numpy(), target_test.cpu().numpy())
        isic_b_dc = dc(output_dis_test.cpu().numpy(), target_test.cpu().numpy())
        dice_np = isic_b_dice.data.cpu().numpy()
        iou_np = isic_b_iou.data.cpu().numpy()
        isic_dice.append(dice_np)
        isic_iou.append(iou_np)
        # isic_assd.append(isic_b_asd)
        isic_acc.append(isic_b_acc)
        isic_sensitive.append(isic_b_sensitive)
        isic_specificy.append(isic_b_specificy)
        isic_precision.append(isic_b_precision)
        isic_f1_score.append(isic_b_f1_score)
        isic_Jaccard_M.append(isic_b_Jaccard_m)
        isic_Jaccard_N.append(isic_b_Jaccard_n)
        isic_Jaccard.append(isic_b_Jaccard)
        isic_dc.append(isic_b_dc)
    # Aggregate statistics over the whole test set.
    all_time = np.sum(infer_time)
    isic_dice_mean = np.average(isic_dice)
    isic_dice_std = np.std(isic_dice)
    isic_iou_mean = np.average(isic_iou)
    isic_iou_std = np.std(isic_iou)
    # isic_assd_mean = np.average(isic_assd)
    # isic_assd_std = np.std(isic_assd)
    isic_acc_mean = np.average(isic_acc)
    isic_acc_std = np.std(isic_acc)
    isic_sensitive_mean = np.average(isic_sensitive)
    isic_sensitive_std = np.std(isic_sensitive)
    isic_specificy_mean = np.average(isic_specificy)
    isic_specificy_std = np.std(isic_specificy)
    isic_precision_mean = np.average(isic_precision)
    isic_precision_std = np.std(isic_precision)
    isic_f1_score_mean = np.average(isic_f1_score)
    iisic_f1_score_std = np.std(isic_f1_score)  # NOTE(review): "iisic" typo kept; it is used consistently below
    isic_Jaccard_M_mean = np.average(isic_Jaccard_M)
    isic_Jaccard_M_std = np.std(isic_Jaccard_M)
    isic_Jaccard_N_mean = np.average(isic_Jaccard_N)
    isic_Jaccard_N_std = np.std(isic_Jaccard_N)
    isic_Jaccard_mean = np.average(isic_Jaccard)
    isic_Jaccard_std = np.std(isic_Jaccard)
    isic_dc_mean = np.average(isic_dc)
    isic_dc_std = np.std(isic_dc)
    print('The ISIC mean dice: {isic_dice_mean: .4f}; The ISIC dice std: {isic_dice_std: .4f}'.format(
        isic_dice_mean=isic_dice_mean, isic_dice_std=isic_dice_std))
    print('The ISIC mean IoU: {isic_iou_mean: .4f}; The ISIC IoU std: {isic_iou_std: .4f}'.format(
        isic_iou_mean=isic_iou_mean, isic_iou_std=isic_iou_std))
    # print('The ISIC mean assd: {isic_assd_mean: .4f}; The ISIC assd std: {isic_assd_std: .4f}'.format(
    #     isic_assd_mean=isic_assd_mean, isic_assd_std=isic_assd_std))
    print('The ISIC mean ACC: {isic_acc_mean: .4f}; The ISIC ACC std: {isic_acc_std: .4f}'.format(
        isic_acc_mean=isic_acc_mean, isic_acc_std=isic_acc_std))
    print('The ISIC mean sensitive: {isic_sensitive_mean: .4f}; The ISIC sensitive std: {isic_sensitive_std: .4f}'.format(
        isic_sensitive_mean=isic_sensitive_mean, isic_sensitive_std=isic_sensitive_std))
    print('The ISIC mean specificy: {isic_specificy_mean: .4f}; The ISIC specificy std: {isic_specificy_std: .4f}'.format(
        isic_specificy_mean=isic_specificy_mean, isic_specificy_std=isic_specificy_std))
    print('The ISIC mean precision: {isic_precision_mean: .4f}; The ISIC precision std: {isic_precision_std: .4f}'.format(
        isic_precision_mean=isic_precision_mean, isic_precision_std=isic_precision_std))
    print('The ISIC mean f1_score: {isic_f1_score_mean: .4f}; The ISIC f1_score std: {iisic_f1_score_std: .4f}'.format(
        isic_f1_score_mean=isic_f1_score_mean, iisic_f1_score_std=iisic_f1_score_std))
    print('The ISIC mean Jaccard_M: {isic_Jaccard_M_mean: .4f}; The ISIC Jaccard_M std: {isic_Jaccard_M_std: .4f}'.format(
        isic_Jaccard_M_mean=isic_Jaccard_M_mean, isic_Jaccard_M_std=isic_Jaccard_M_std))
    print('The ISIC mean Jaccard_N: {isic_Jaccard_N_mean: .4f}; The ISIC Jaccard_N std: {isic_Jaccard_N_std: .4f}'.format(
        isic_Jaccard_N_mean=isic_Jaccard_N_mean, isic_Jaccard_N_std=isic_Jaccard_N_std))
    print('The ISIC mean Jaccard: {isic_Jaccard_mean: .4f}; The ISIC Jaccard std: {isic_Jaccard_std: .4f}'.format(
        isic_Jaccard_mean=isic_Jaccard_mean, isic_Jaccard_std=isic_Jaccard_std))
    print('The ISIC mean dc: {isic_dc_mean: .4f}; The ISIC dc std: {isic_dc_std: .4f}'.format(
        isic_dc_mean=isic_dc_mean, isic_dc_std=isic_dc_std))
    print('The inference time: {time: .4f}'.format(time=all_time))
    print("Number of trainable parameters {0} in Model {1}".format(num_para, args.id))
def save_imgs(img, msk, msk_pred, i, save_path, threshold=0.5, test_data_name=None):
    """Save a 3-row figure (input image, ground-truth mask, prediction) as PNG.

    `threshold` is accepted for interface compatibility but is not used here.
    """
    picture = img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    if picture.max() > 1.1:  # heuristically detect 0-255 images and rescale
        picture = picture / 255.
    plt.figure(figsize=(7, 15))
    panels = [
        (picture, None),
        (msk.squeeze(2), 'gray'),
        (msk_pred.squeeze(2), 'gray'),
    ]
    for row, (content, cmap_name) in enumerate(panels, start=1):
        plt.subplot(3, 1, row)
        plt.imshow(content, cmap=cmap_name)
        plt.axis('off')
    if test_data_name is not None:
        save_path = save_path + test_data_name + '_'
    plt.savefig(save_path + str(i) + '.png')
    plt.close()
def main(args, val_acc_log, test_acc_log):
    """Build datasets, model and optimizer; run the train/val loop; then test."""
    best_score = [0]
    start_epoch = args.start_epoch
    print('loading the {0},{1},{2} dataset ...'.format('train', 'validation', 'test'))
    # Datasets for one cross-validation fold selected by args.val_folder.
    trainset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='train',
                                       with_name=False, transform=Test_Transform[args.transform])
    validset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='validation',
                                       with_name=False, transform=Test_Transform[args.transform])
    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
                                      with_name=True, transform=Test_Transform[args.transform])
    trainloader = Data.DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=6)
    validloader = Data.DataLoader(dataset=validset, batch_size=1, shuffle=False, pin_memory=True, num_workers=6)
    testloader = Data.DataLoader(dataset=testset, batch_size=1, shuffle=False, pin_memory=True, num_workers=6)
    print('Loading is done\n')
    # These overwrite whatever the CLI provided for the same fields.
    args.num_input = 3
    args.num_classes = 2  # binary segmentation (foreground / background)
    args.out_size = (224, 320)
    if args.id == "AttUNet":
        model = AttUNet(in_channel=3, out_channel=2)
    else:
        model = Test_Model[args.id](classes=2, channels=3)
    model = model.cuda()
    print("------------------------------------------")
    print("Network Architecture of Model {}:".format(args.id))
    # Count trainable parameters by multiplying out each tensor's shape.
    num_para = 0
    for name, param in model.named_parameters():
        num_mul = 1
        for x in param.size():
            num_mul *= x
        num_para += num_mul
    print(model)
    print("Number of trainable parameters {0} in Model {1}".format(num_para, args.id))
    print("------------------------------------------")
    # Define optimizers and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_rate, weight_decay=args.weight_decay)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=70, gamma=0.5)  # lr_1
    # scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min = 0.0001)  # lr_2
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min = 0.00001)  # lr_3
    # scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1, eta_min = 0.00001)  # lr_4
    # resume
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['opt_dict'])
            print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        else:
            print("=> No checkpoint found at '{}'".format(args.resume))
    print("Start training ...")
    for epoch in range(start_epoch + 1, args.epochs + 1):
        # NOTE(review): scheduler.step() before train() follows the pre-1.1
        # PyTorch ordering; recent versions expect it after optimizer steps —
        # confirm against the installed torch version.
        scheduler.step()
        train_avg_loss = train(trainloader, model, criterion, scheduler, optimizer, args, epoch)
        isic_Jaccard_mean, isic_dc_mean, net_score = valid_isic(validloader, model, criterion, optimizer, args, epoch, best_score, val_acc_log)
        # Periodic checkpointing once past args.particular_epoch.
        if epoch > args.particular_epoch:
            if epoch % args.save_epochs_steps == 0:
                filename = args.ckpt + '/' + str(epoch) + '_' + args.data + '_checkpoint.pth.tar'
                print('the model will be saved at {}'.format(filename))
                state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
                torch.save(state, filename)
    print('Training Done! Start testing')
    if args.data == 'ISIC2018':
        test_isic(testloader, model, num_para, args, test_acc_log)
    print('Testing Done!')
if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # gpu-id
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), 'PyTorch>=0.4.0 is required'
    parser = argparse.ArgumentParser(description='Comprehensive attention network for biomedical Dataset')
    parser.add_argument('--id', default="unet",
                        help='')  # Select a loaded model name
    # Path related arguments
    parser.add_argument('--root_path', default='/xujiheng/ISICdemo/ISIC2018_npy_all_224_320',
                        help='root directory of data')  # The folder where the numpy data set is stored
    parser.add_argument('--ckpt', default='/xujiheng/ISICdemo/UNet-ISIC2018/saved_models/',
                        help='folder to output checkpoints')  # The folder in which the trained model is saved
    parser.add_argument('--transform', default='C', type=str,
                        help='which ISIC2018_transform to choose')
    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')
    parser.add_argument('--out_size', default=(224, 320), help='the output image size')
    parser.add_argument('--val_folder', default='folder3', type=str,
                        help='folder1、folder2、folder3、folder4、folder5')  # five-fold cross-validation
    # optimization related arguments
    parser.add_argument('--epochs', type=int, default=15, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='epoch to start training. useful if continue from a checkpoint')
    parser.add_argument('--batch_size', type=int, default=10, metavar='N',
                        help='input batch size for training (default: 12)')  # batch_size
    parser.add_argument('--lr_rate', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--num_classes', default=2, type=int,
                        help='number of classes')
    parser.add_argument('--num_input', default=3, type=int,
                        help='number of input image for each patient')
    parser.add_argument('--weight_decay', default=1e-8, type=float, help='weights regularizer')
    parser.add_argument('--particular_epoch', default=30, type=int,
                        help='after this number, we will save models more frequently')
    parser.add_argument('--save_epochs_steps', default=300, type=int,
                        help='frequency to save models after a particular number of epochs')
    parser.add_argument('--resume', default='',
                        help='the checkpoint that resumes from')
    args = parser.parse_args()
    # Checkpoints live under <ckpt>/<dataset>/<fold>/<model>_<criterion>.
    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id + "_{}".format(criterion))
    if not os.path.isdir(args.ckpt):
        os.makedirs(args.ckpt)
    logfile = os.path.join(args.ckpt, '{}_{}_{}.txt'.format(args.val_folder, args.id, criterion))  # path of the training log
    # Mirror every print into the log file for the rest of the run.
    sys.stdout = Logger(logfile)
    val_acc_log = os.path.join(args.ckpt, 'val_acc_{}_{}_{}.txt'.format(args.val_folder, args.id, criterion))
    test_acc_log = os.path.join(args.ckpt, 'test_acc_{}_{}_{}.txt'.format(args.val_folder, args.id, criterion))
    print('Models are saved at %s' % (args.ckpt))
    print("Input arguments:")
    for key, value in vars(args).items():
        print("{:16} {}".format(key, value))
    # Resuming from a given start_epoch implies loading that epoch's checkpoint.
    if args.start_epoch > 1:
        args.resume = args.ckpt + '/' + str(args.start_epoch) + '_' + args.data + '_checkpoint.pth.tar'
    main(args, val_acc_log, test_acc_log)
下面我将分模块进行讲解。
官方 / 第三方库
python
import os # Python 标准库
import torch # PyTorch 官方
import math # Python 标准库
import visdom # 第三方可视化工具(Facebook 开源)
import torch.utils.data as Data # PyTorch 官方
import argparse # Python 标准库
import numpy as np # 第三方科学计算库
import sys # Python 标准库
from tqdm import tqdm # 第三方进度条库
import torch.nn as nn # PyTorch 官方
from distutils.version import LooseVersion # Python 标准库(用于版本比较)
from torch.optim import lr_scheduler # PyTorch 官方
from time import * # Python 标准库
from matplotlib import pyplot as plt # 第三方绘图库
项目中的 .py 文件
位于 Datasets/、Models/、utils/下。
python
# 数据集类(你自己定义的 ISIC2018 数据加载器)
from Datasets.ISIC2018 import ISIC2018_dataset
# 数据增强/预处理变换(你自己写的 transform 函数)
from utils.transform import ISIC2018_transform, ISIC2018_transform_320, ISIC2018_transform_newdata
# 模型结构(你自己实现或集成的 U-Net 及 Attention U-Net)
from Models.compare_networks.unet import unet
from Models.compare_networks.AttUnet import AttUNet
# 自定义损失函数和评估指标
from utils.dice_loss import get_soft_label, val_dice_isic, SoftDiceLoss
from utils.dice_loss import Intersection_over_Union_isic
from utils.dice_loss_github import SoftDiceLoss_git, CrossentropyND
# 工具类:平均值记录器
from utils.evaluation import AverageMeter
# 二值分割评价指标(如 Dice、Jaccard、ASSD 等)
from utils.binary import assd, dc, jc, precision, sensitivity, specificity, F1, ACC
下面就这些来逐一讲解代码。
UNet-ISIC2018/Datasets/ISIC2018.py
python
import os
import PIL
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from os import listdir
from os.path import join
from PIL import Image
from utils.transform import itensity_normalize
from torch.utils.data.dataset import Dataset
class ISIC2018_dataset(Dataset):
    """ISIC 2018 segmentation dataset backed by pre-converted .npy files.

    Sample names are read from ./Datasets/<folder>/<folder>_<train_type>.list;
    images come from <dataset_folder>/image/<name> and masks from
    <dataset_folder>/label/<stem>_segmentation.npy.
    """

    def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',
                 folder='folder0', train_type='train', with_name=False, transform=None):
        """
        Args:
            dataset_folder: root directory holding the image/ and label/ dirs.
            folder: cross-validation split directory name (e.g. 'folder3').
            train_type: one of 'train' / 'validation' / 'test'.
            with_name: also return the file name from __getitem__ (for testing).
            transform: callable applied to each sample dict, given train_type.
        """
        self.transform = transform
        self.train_type = train_type
        self.with_name = with_name
        self.folder_file = './Datasets/' + folder
        if self.train_type not in ['train', 'validation', 'test']:
            # The original only printed a message and then crashed with an
            # AttributeError at the assert below; fail fast instead.
            raise ValueError(
                "Choosing type error, You have to choose the loading data type including: train, validation, test")
        # this is for cross validation
        with open(join(self.folder_file, self.folder_file.split('/')[-1] + '_' + self.train_type + '.list'),
                  'r') as f:
            self.image_list = f.readlines()
        self.image_list = [item.replace('\n', '') for item in self.image_list]
        self.folder = [join(dataset_folder, 'image', x) for x in self.image_list]
        self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy') for x in self.image_list]
        # Guard against image/mask list mismatches.
        assert len(self.folder) == len(self.mask)

    def __getitem__(self, item: int):
        """Return (image, label) — or (name, image, label) when with_name."""
        image = np.load(self.folder[item])
        label = np.load(self.mask[item])
        name = self.folder[item].split('/')[-1]
        sample = {'image': image, 'label': label}
        if self.transform is not None:
            # TODO: transformation to argument datasets
            sample = self.transform(sample, self.train_type)
        if self.with_name:
            return name, sample['image'], sample['label']
        else:
            return sample['image'], sample['label']

    def __len__(self):
        return len(self.folder)
主数据集类 ISIC2018_dataset
类定义与初始化 init:
python
class ISIC2018_dataset(Dataset):
def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',
folder='folder0', train_type='train', with_name=False, transform=None):
self.transform = transform
self.train_type = train_type
self.with_name = with_name
self.folder_file = './Datasets/' + folder
| 参数 | 含义 |
|---|---|
dataset_folder |
存放 .npy 图像和标签的根目录 |
folder |
交叉验证划分文件夹名,如 'folder3' |
train_type |
数据类型:'train' / 'validation' / 'test' |
with_name |
是否返回图像文件名(测试时需要) |
transform |
数据增强/预处理函数 |
python
if self.train_type in ['train', 'validation', 'test']:
# this is for cross validation
with open(join(self.folder_file, self.folder_file.split('/')[-1] + '_' + self.train_type + '.list'),
'r') as f:
self.image_list = f.readlines()
self.image_list = [item.replace('\n', '') for item in self.image_list]
self.folder = [join(dataset_folder, 'image', x) for x in self.image_list]
self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy') for x in self.image_list]
# self.folder = sorted([join(dataset_folder, self.train_type, 'image', x) for x in
# listdir(join(dataset_folder, self.train_type, 'image'))])
# self.mask = sorted([join(dataset_folder, self.train_type, 'label', x) for x in
# listdir(join(dataset_folder, self.train_type, 'label'))])
else:
print("Choosing type error, You have to choose the loading data type including: train, validation, test")
assert len(self.folder) == len(self.mask)
self.folder_file 指向交叉验证划分列表(.list 文件)所在目录,例如 ./Datasets/folder3/;随后读取如 ./Datasets/folder3/folder3_train.list 的文件,获取图像文件名列表,并据此构建完整路径。
图像路径:/ISIC2018_Task1_npy_all/image/ISIC_xxx.npy
标签路径:/ISIC2018_Task1_npy_all/label/ISIC_xxx_segmentation.npy
assert len(self.folder) == len(self.mask)确保图像和标签数量一致,防止错位。
获取单个样本 getitem
python
def __getitem__(self, item: int):
image = np.load(self.folder[item]) # 加载图像 (H, W, 3)
label = np.load(self.mask[item]) # 加载标签 (H, W),值为 0/1
name = self.folder[item].split('/')[-1] # 如 'ISIC_0000000.npy'
sample = {'image': image, 'label': label}
if self.transform is not None:
sample = self.transform(sample, self.train_type)
if self.with_name:
return name, sample['image'], sample['label']
else:
return sample['image'], sample['label']
返回格式由 with_name 控制:训练/验证(with_name=False) → (image, label);测试 → (name, image, label)(便于保存结果)。
transform 接收整个 sample 字典和 train_type,可实现"训练时增强,测试时不增强"。
在训练主体代码中:
python
from Datasets.ISIC2018 import ISIC2018_dataset
Test_Dataset = {'ISIC2018': ISIC2018_dataset}
# 在 main() 中:
trainset = Test_Dataset[args.data](
dataset_folder=args.root_path, # '/xujiheng/ISICdemo/ISIC2018_npy_all_224_320'
folder=args.val_folder, # 如 'folder3'(五折交叉验证)
train_type='train',
with_name=False,
transform=Test_Transform[args.transform] # 如 ISIC2018_transform_newdata
)
validset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='validation',
with_name=False, transform=Test_Transform[args.transform])
testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
with_name=True, transform=Test_Transform[args.transform])
ISIC2018_dataset 是连接文件夹中.npy 数据与训练 pipeline 的桥梁,负责按交叉验证划分加载样本,并应用指定的图像/标签预处理,为模型训练/验证/测试提供标准化输入。
trainset是一个 PyTorch Dataset 对象,当调用 trainset[i] 时,它会返回一个由一张图像及其对应的标签组成的元组 (image_tensor, label_tensor)。
image_tensor: 形状为 (3, H, W) 的张量,表示输入图像(RGB三通道)。
label_tensor: 形状为 (1, H, W) 的张量,表示分割标签(通常是二值化的掩码)。
UNet-ISIC2018/utils/transform.py
python
import torch
import random
import PIL
import numbers
import numpy as np
import torch.nn as nn
import collections
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw
# Map PIL interpolation constants to printable names (used by resize.__repr__).
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
}
def ISIC2018_transform(sample, train_type):
    """Augment (train) or resize (eval) a sample to 224x300 and tensorize it.

    Returns a dict with 'image' normalized to [-1, 1] and 'label' in [0, 1].
    """
    pil_image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    pil_label = Image.fromarray(np.uint8(sample['label']), mode='L')
    if train_type == 'train':
        # Random crop, then random flips and a +/-30 degree rotation.
        pil_image, pil_label = randomcrop(size=(224, 300))(pil_image, pil_label)
        pil_image, pil_label = randomflip_rotate(pil_image, pil_label, p=0.5, degrees=30)
    else:
        pil_image, pil_label = resize(size=(224, 300))(pil_image, pil_label)
    normalize = ts.Compose([ts.ToTensor(),
                            ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return {'image': normalize(pil_image), 'label': ts.ToTensor()(pil_label)}
def ISIC2018_transform_320(sample, train_type):
    """Same pipeline as ISIC2018_transform, but with a 224x320 target size."""
    pil_image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    pil_label = Image.fromarray(np.uint8(sample['label']), mode='L')
    if train_type == 'train':
        pil_image, pil_label = randomcrop(size=(224, 320))(pil_image, pil_label)
        pil_image, pil_label = randomflip_rotate(pil_image, pil_label, p=0.5, degrees=30)
    else:
        pil_image, pil_label = resize(size=(224, 320))(pil_image, pil_label)
    normalize = ts.Compose([ts.ToTensor(),
                            ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return {'image': normalize(pil_image), 'label': ts.ToTensor()(pil_label)}
def ISIC2018_transform_newdata(sample, train_type):
    """Tensorize a pre-sized sample; flips/rotation applied only when training.

    No crop or resize here — the data is expected to already be 224x320.
    """
    pil_image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    pil_label = Image.fromarray(np.uint8(sample['label']), mode='L')
    if train_type == 'train':
        # image, label = randomcrop(size=(224, 320))(image, label)
        pil_image, pil_label = randomflip_rotate(pil_image, pil_label, p=0.5, degrees=30)
    normalize = ts.Compose([ts.ToTensor(),
                            ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return {'image': normalize(pil_image), 'label': ts.ToTensor()(pil_label)}
# these are functional helper routines used by the transforms above
def randomflip_rotate(img, lab, p=0.5, degrees=0):
    """Apply identical random flips and rotation to an image/mask pair.

    Horizontal and vertical flips each happen independently with probability
    p; the rotation angle is drawn uniformly from [-degrees, degrees] when
    degrees is a number, or from the given 2-tuple range.
    """
    for flip in (TF.hflip, TF.vflip):
        if random.random() < p:
            img, lab = flip(img), flip(lab)
    if isinstance(degrees, numbers.Number):
        if degrees < 0:
            raise ValueError("If degrees is a single number, it must be positive.")
        angle_range = (-degrees, degrees)
    else:
        if len(degrees) != 2:
            raise ValueError("If degrees is a sequence, it must be of len 2.")
        angle_range = degrees
    angle = random.uniform(angle_range[0], angle_range[1])
    return TF.rotate(img, angle), TF.rotate(lab, angle)
class randomcrop(object):
    """Crop a PIL image and its mask at the same random location.

    Args:
        size (sequence or int): target crop size (h, w); an int makes a
            square (size, size) crop.
        padding (int or sequence, optional): border padding applied before
            cropping. Default 0 (no padding).
        pad_if_needed (boolean): pad inputs smaller than the desired size
            instead of letting the random-crop bounds raise.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Draw random crop coordinates (i, j, h, w) for ``crop``.

        Args:
            img (PIL Image): image to be cropped.
            output_size (tuple): expected (h, w) of the crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        return random.randint(0, h - th), random.randint(0, w - tw), th, tw

    def __call__(self, img, lab):
        """Crop image and mask with shared random parameters; returns both."""
        if self.padding > 0:
            img, lab = TF.pad(img, self.padding), TF.pad(lab, self.padding)
        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = TF.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
            lab = TF.pad(lab, (int((1 + self.size[1] - lab.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = TF.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
            lab = TF.pad(lab, (0, int((1 + self.size[0] - lab.size[1]) / 2)))
        top, left, height, width = self.get_params(img, self.size)
        return TF.crop(img, top, left, height, width), TF.crop(lab, top, left, height, width)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class resize(object):
    """Resize a PIL image and mask to the given size with one interpolation.

    Args:
        size (sequence or int): desired output size. A sequence (h, w) is
            matched exactly; an int matches the smaller edge, scaling the
            other proportionally.
        interpolation (int, optional): PIL interpolation flag. Default
            ``PIL.Image.BILINEAR``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; the ABCs live in
        # collections.abc, so the original isinstance check raises there.
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, lab):
        """Return (rescaled image, rescaled mask)."""
        return TF.resize(img, self.size, self.interpolation), TF.resize(lab, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
def itensity_normalize(volume):
    """
    Normalize the intensity of an nd volume to zero mean / unit std, then
    replace originally-zero voxels with standard-normal noise.

    inputs:
        volume: the input nd volume (numpy array)
    outputs:
        out: the normalized nd volume
    """
    mean = volume.mean()
    std = volume.std()
    # Guard against a constant volume: the original divided by zero here,
    # turning every voxel into NaN/inf.
    if std == 0:
        std = 1.0
    out = (volume - mean) / std
    # Background (zero) voxels get random noise so they do not form a
    # constant block after normalization.
    out_random = np.random.normal(0, 1, size=volume.shape)
    out[volume == 0] = out_random[volume == 0]
    return out
三个主变换函数(核心接口)
ISIC2018_transform(sample, train_type)
python
def ISIC2018_transform(sample, train_type):
    """Preprocess one ISIC2018 sample dict into normalized tensors.

    Training: random crop to (224, 300) then random flip/rotate augmentation.
    Validation/test: plain resize to (224, 300).
    The image is normalized to [-1, 1]; the mask is only converted to a tensor.
    """
    image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    label = Image.fromarray(np.uint8(sample['label']), mode='L')
    if train_type == 'train':
        image, label = randomcrop(size=(224, 300))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    else:
        image, label = resize(size=(224, 300))(image, label)
    normalize = ts.Compose([ts.ToTensor(),
                            ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return {'image': normalize(image), 'label': ts.ToTensor()(label)}
输入:sample 是一个字典,包含 'image'(H×W×3 numpy array)和 'label'(H×W 单通道 mask)。
训练模式 (train_type == 'train'):先随机裁剪 到 (224, 300)再随机水平/垂直翻转 + 随机旋转 ±30°。
验证/测试模式:直接 resize 到 (224, 300)。
最后统一:图像:转为 Tensor 并归一化到 [-1, 1](因为 mean=std=0.5);标签:仅转为 Tensor(值域 [0, 1]);输出尺寸是 (224, 300),不是正方形。
ISIC2018_transform_320(sample, train_type)
所有尺寸改为 (224, 320)(更宽),其他逻辑完全一致。
ISIC2018_transform_newdata(sample, train_type)
数据已处理成(224,320)前提下,可以选择在训练时候使用这个。不进行 resize 或 crop!保留原始尺寸。仅在训练时做翻转+旋转增强。
python
def ISIC2018_transform_newdata(sample, train_type):
    """Preprocess a sample that is already at the target size (e.g. 224x320).

    No resize/crop is performed; only flip/rotate augmentation is applied in
    training mode. The image is normalized to [-1, 1]; the mask is only
    converted to a tensor.
    """
    image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    label = Image.fromarray(np.uint8(sample['label']), mode='L')
    # image, label = randomcrop(size=(224, 320))(image, label)
    if train_type == 'train':
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    normalize = ts.Compose([ts.ToTensor(),
                            ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return {'image': normalize(image), 'label': ts.ToTensor()(label)}
功能性增强函数randomflip_rotate(img, lab, p=0.5, degrees=0)
python
def randomflip_rotate(img, lab, p=0.5, degrees=0):
    """Apply identical random flips and rotation to an image and its mask.

    Horizontal and vertical flips are each applied independently with
    probability p. The rotation angle is drawn uniformly from
    [-degrees, degrees] when degrees is a number, or from the given
    (min, max) pair. Exposed corners are filled with 0 (background),
    which is correct for the label as well.
    """
    if random.random() < p:
        img, lab = TF.hflip(img), TF.hflip(lab)
    if random.random() < p:
        img, lab = TF.vflip(img), TF.vflip(lab)
    if isinstance(degrees, numbers.Number):
        degrees = (-degrees, degrees)
    angle = random.uniform(degrees[0], degrees[1])
    return TF.rotate(img, angle, fill=0), TF.rotate(lab, angle, fill=0)
在训练主脚本中:
python
from utils.transform import ISIC2018_transform, ISIC2018_transform_320, ISIC2018_transform_newdata
Test_Transform = {
'A': ISIC2018_transform,
'B': ISIC2018_transform_320,
'C': ISIC2018_transform_newdata
}
# 在 main() 中:
trainset = Test_Dataset[args.data](..., transform=Test_Transform[args.transform])
通过命令行参数 --transform C(默认)选择使用哪一种预处理策略。parser.add_argument('--transform', default='C', type=str, ...)。通过使用 ISIC2018_transform_newdata,在保留原始图像分辨率(实际为预处理后的统一尺寸 224×320)的前提下,仅对训练数据施加翻转和旋转增强,实现了有效防止过拟合的同时确保测试阶段分割结果的精细度。
UNet-ISIC2018/Models/compare_networks/unet.py
代码:
python
import torch
import torch.nn as nn
class conv_block(nn.Module):
    """Standard U-Net double convolution: two 3x3 Conv-BN-ReLU layers.

    Spatial size is preserved (stride=1, padding=1); channels go ch_in -> ch_out.
    """
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        layers = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class up_conv(nn.Module):
    """U-Net decoder upsampling step: 2x upsample then a 3x3 Conv-BN-ReLU.

    Doubles H and W and maps channels ch_in -> ch_out.
    """
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        out = self.up(x)
        return out
class unet(nn.Module):
    """Plain 2-D U-Net: 4-level encoder/decoder with skip connections.

    (No attention modules are present despite the original comment.)
    Input (B, channels, H, W) with H and W divisible by 16; output is
    per-class logits (B, classes, H, W).
    """
    def __init__(self, classes=2, channels=3):
        super(unet, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # encoder: channels -> 64 -> 128 -> 256 -> 512 -> 1024
        self.Conv1 = conv_block(ch_in=channels, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)
        # decoder: each level upsamples, concatenates the matching encoder
        # feature (doubling channels), then fuses back down
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
        # 1x1 projection to per-class logits
        self.Conv_1x1 = nn.Conv2d(64, classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder path, keeping each level's output for the skip connections
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool(e1))
        e3 = self.Conv3(self.Maxpool(e2))
        e4 = self.Conv4(self.Maxpool(e3))
        e5 = self.Conv5(self.Maxpool(e4))
        # decoder path: upsample, concat skip, fuse
        d = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d = self.Up_conv4(torch.cat((e3, self.Up4(d)), dim=1))
        d = self.Up_conv3(torch.cat((e2, self.Up3(d)), dim=1))
        d = self.Up_conv2(torch.cat((e1, self.Up2(d)), dim=1))
        return self.Conv_1x1(d)
| 参数 | 含义 | 你任务中的正确值 |
|---|---|---|
channels |
输入图像的通道数 | 3(RGB 彩色图) |
classes |
模型最终输出的类别数(通道数) | 单通道用法为 1(配合 BCEWithLogitsLoss + sigmoid);注意本仓库 train.py 实际使用 classes=2(背景+前景双通道,配合 one-hot 标签与 argmax)
注:
U-Net 能支持 (224, 320) 是因为它是全卷积网络,没有全连接层(Fully Connected Layer)。
例子:
输入 224×224 → 最后卷积输出 512×7×7 → 展平为 25088 维
输入 320×320 → 最后卷积输出 512×10×10 → 展平为 51200 维
但 Linear(25088, 1000) 无法接收 51200 维 → 崩溃!
同时该尺寸能被下采样因子 16 整除,确保编码器与解码器特征图尺寸对齐;若输入尺寸不能被 16 整除(如 225×321),则上采样后无法恢复原尺寸,导致跳跃连接时张量形状不匹配而报错。
在主训练脚本中:
python
# ==============================
# 1. 导入模型
# ==============================
from Models.compare_networks.unet import unet
Test_Model = {
'UNet': unet,
# 其他模型...
}
# ==============================
# 2. 初始化模型
# ==============================
model = Test_Model[args.model](classes=1, channels=3).cuda()
model.train()
# ==============================
# 3. 定义优化器 & 损失函数
# ==============================
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.BCEWithLogitsLoss() # 或自定义 Dice+BCE
# ==============================
# 4. 训练循环(核心)
# ==============================
for epoch in range(args.epochs):
for step, (x, y) in enumerate(train_loader):
# x: (B, 3, 224, 320), y: (B, 1, 224, 320)
image = x.float().cuda() # 输入图像
target = y.float().cuda() # 标签(0/1)
# 前向传播
pred = model(image) # 输出: (B, 1, 224, 320),logits
# 计算损失
loss = criterion(pred, target)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# ==============================
# 5. 验证/测试(推理)
# ==============================
model.eval()
with torch.no_grad():
for name, x, y in test_loader:
image = x.float().cuda()
target = y.float().cuda()
pred = model(image) # logits
pred_sigmoid = torch.sigmoid(pred) # 转为概率 [0,1]
pred_binary = (pred_sigmoid > 0.5).float() # 二值 mask
# 保存或评估 pred_binary
**(B, C, H, W):**B 表示 batch size(批大小),即 一次前向/反向传播中同时处理的样本数量。
| 阶段 | 操作 | 输入张量形状 (B, C, H, W) | 输出张量形状 (B, C, H, W) | 说明 |
|---|---|---|---|---|
| Input | - | (B, 3, 224, 320) |
- | RGB 图像,已归一化到 [-1, 1] |
| Encoder ↓ | ||||
| Conv1 | conv_block(3→64) |
(B, 3, 224, 320) |
(B, 64, 224, 320) |
双卷积 + ReLU + BN |
| MaxPool1 | MaxPool2d(2) |
(B, 64, 224, 320) |
(B, 64, 112, 160) |
下采样 |
| Conv2 | conv_block(64→128) |
(B, 64, 112, 160) |
(B, 128, 112, 160) |
|
| MaxPool2 | MaxPool2d(2) |
(B, 128, 112, 160) |
(B, 128, 56, 80) |
|
| Conv3 | conv_block(128→256) |
(B, 128, 56, 80) |
(B, 256, 56, 80) |
|
| MaxPool3 | MaxPool2d(2) |
(B, 256, 56, 80) |
(B, 256, 28, 40) |
|
| Conv4 | conv_block(256→512) |
(B, 256, 28, 40) |
(B, 512, 28, 40) |
|
| MaxPool4 | MaxPool2d(2) |
(B, 512, 28, 40) |
(B, 512, 14, 20) |
|
| Conv5 | conv_block(512→1024) |
(B, 512, 14, 20) |
(B, 1024, 14, 20) |
编码器最深层 |
| Decoder ↑ | ||||
| Up5 | up_conv(1024→512) (Upsample×2 + Conv) |
(B, 1024, 14, 20) |
(B, 512, 28, 40) |
上采样 |
| Concat x4 | torch.cat([x4, d5], dim=1) |
x4: (B, 512, 28, 40) d5: (B, 512, 28, 40) |
(B, 1024, 28, 40) |
跳跃连接 |
| Up_conv5 | conv_block(1024→512) |
(B, 1024, 28, 40) |
(B, 512, 28, 40) |
融合特征 |
| Up4 | up_conv(512→256) |
(B, 512, 28, 40) |
(B, 256, 56, 80) |
|
| Concat x3 | torch.cat([x3, d4], dim=1) |
x3: (B, 256, 56, 80) d4: (B, 256, 56, 80) |
(B, 512, 56, 80) |
|
| Up_conv4 | conv_block(512→256) |
(B, 512, 56, 80) |
(B, 256, 56, 80) |
|
| Up3 | up_conv(256→128) |
(B, 256, 56, 80) |
(B, 128, 112, 160) |
|
| Concat x2 | torch.cat([x2, d3], dim=1) |
x2: (B, 128, 112, 160) d3: (B, 128, 112, 160) |
(B, 256, 112, 160) |
|
| Up_conv3 | conv_block(256→128) |
(B, 256, 112, 160) |
(B, 128, 112, 160) |
|
| Up2 | up_conv(128→64) |
(B, 128, 112, 160) |
(B, 64, 224, 320) |
|
| Concat x1 | torch.cat([x1, d2], dim=1) |
x1: (B, 64, 224, 320) d2: (B, 64, 224, 320) |
(B, 128, 224, 320) |
|
| Up_conv2 | conv_block(128→64) |
(B, 128, 224, 320) |
(B, 64, 224, 320) |
|
| Final Conv | Conv2d(64→1, 1×1) |
(B, 64, 224, 320) |
(B, 1, 224, 320) |
输出 logits |
模型最终返回的是一个单通道的"灰度图"形式的分割结果。
UNet-ISIC2018/utils/dice_loss.py
python
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
class SoftDiceLoss(_Loss):
    """Soft Dice loss wrapper.

    Soft_Dice = 2*|dot(A, B)| / (|dot(A, A)| + |dot(B, B)| + eps),
    where eps is a small constant to avoid division by zero.
    The actual computation is delegated to soft_dice_loss().
    """
    def __init__(self, *args, **kwargs):
        # extra args are accepted for interface compatibility but ignored
        super(SoftDiceLoss, self).__init__()

    def forward(self, prediction, soft_ground_truth, num_class=3, weight_map=None, eps=1e-8):
        # eps is kept for signature compatibility; smoothing is handled
        # inside soft_dice_loss itself
        avg_loss, lesion_score = soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map)
        return avg_loss, lesion_score
def get_soft_label(input_tensor, num_class):
    """Convert an integer label tensor to one-hot ("soft") form.

    input_tensor: integer-valued labels with shape [N, C, H, W]
    returns: float one-hot tensor with shape [N, H, W, num_class]
    """
    labels = input_tensor.permute(0, 2, 3, 1)  # channel-last layout
    one_hot = [torch.eq(labels, i * torch.ones_like(labels)) for i in range(num_class)]
    return torch.cat(one_hot, dim=-1).float()
def soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=None):
    """Negative-log Dice loss over flattened voxels.

    prediction: [N, C, H, W] per-class scores; soft_ground_truth: one-hot
    labels reshapeable to (-1, num_class); weight_map: optional per-voxel
    weights. Returns (mean -log(dice) over classes, -log(dice) of class 1,
    i.e. the lesion class).
    """
    pred = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    if weight_map is None:
        ref_vol = ground.sum(0)
        intersect = (ground * pred).sum(0)
        seg_vol = pred.sum(0)
    else:
        w = weight_map.view(-1).repeat(num_class).view_as(pred)
        ref_vol = (w * ground).sum(0)
        intersect = (w * ground * pred).sum(0)
        seg_vol = (w * pred).sum(0)
    # smoothed per-class dice; the +1.0 in the denominator matches the
    # original formulation (biases the score slightly below 1.0)
    dice_score = (2.0 * intersect + 1e-5) / (ref_vol + seg_vol + 1.0 + 1e-5)
    dice_loss = -torch.log(dice_score)
    return torch.mean(dice_loss), dice_loss[1]
def IOU_loss(prediction, soft_ground_truth, num_class):
    """Negative-log IoU loss over flattened voxels.

    prediction: [N, C, H, W] per-class scores; soft_ground_truth: one-hot
    labels reshapeable to (-1, num_class). Returns the mean -log(IoU)
    across classes.
    """
    # BUGFIX: flatten the channel-last permuted tensor (predict), not the raw
    # NCHW prediction — viewing NCHW directly as (-1, num_class) scrambles the
    # class columns. soft_dice_loss already does this correctly.
    predict = prediction.permute(0, 2, 3, 1)
    pred = predict.contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    # +1.0 smoothing in the denominator avoids division by zero
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    iou_loss = torch.mean(-torch.log(iou_score))
    return iou_loss
def jc_loss(prediction, soft_ground_truth, num_class):
    """Jaccard (IoU) based loss computed on the lesion (class-1) channel only.

    NOTE(review): `predict[:, :, :, 1]` selects the 3-D class-1 channel and
    then views it as (-1, num_class); for num_class > 1 this packs pixels of a
    single channel into num_class columns, which looks unintended — confirm
    the intended shape against callers before reusing this function.
    """
    predict = prediction.permute(0, 2, 3, 1)
    # lesion channel only, reshaped into num_class columns (see NOTE above)
    pred = predict[:,:,:,1].contiguous().view(-1, num_class)
    # pred = prediction[:,:,:,1].view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    # assumes soft_ground_truth is channel-last [N, H, W, C] — TODO confirm
    ground = soft_ground_truth[:,:,:,1].view(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    # +1.0 smoothing in the denominator avoids division by zero
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    #jc = 10*(1-iou_score)
    # scaled -log(IoU); the 20x factor weights this term relative to other losses
    jc = 20*torch.mean(-torch.log(iou_score))
    return jc
def val_dice_fetus(prediction, soft_ground_truth, num_class):
    """Per-class Dice for the fetus task.

    Both inputs must be one-hot and reshapeable to (-1, num_class).
    Returns (placenta_dice, brain_dice) — classes 1 and 2.
    """
    pred = prediction.contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    ref_vol = ground.sum(0)
    intersect = (ground * pred).sum(0)
    seg_vol = pred.sum(0)
    # +1.0 smoothing in the denominator avoids division by zero
    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1.0)
    return dice_score[1], dice_score[2]
def Intersection_over_Union_fetus(prediction, soft_ground_truth, num_class):
    """Per-class IoU for the fetus task.

    Both inputs must be one-hot and reshapeable to (-1, num_class).
    Returns (placenta_iou, brain_iou) — classes 1 and 2.
    """
    pred = prediction.contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    ref_vol = ground.sum(0)
    intersect = (ground * pred).sum(0)
    seg_vol = pred.sum(0)
    # +1.0 smoothing in the denominator avoids division by zero
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    return iou_score[1], iou_score[2]
def val_dice_isic(prediction, soft_ground_truth, num_class):
    """Per-class Dice for ISIC2018.

    Both inputs must be one-hot and reshapeable to (-1, num_class).
    Returns a tensor of num_class scores (index 0 = background,
    index 1 = lesion).
    """
    pred = prediction.view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    ref_vol = ground.sum(0)
    intersect = (ground * pred).sum(0)
    seg_vol = pred.sum(0)
    # only 1e-6 smoothing (no +1.0), so perfect overlap scores ~1.0
    return 2.0 * intersect / (ref_vol + seg_vol + 1e-6)
def val_dice_isic_raw(prediction, soft_ground_truth, num_class):
    """Mean Dice across classes for ISIC2018 (scalar).

    Both inputs must be one-hot and reshapeable to (-1, num_class).
    """
    pred = prediction.contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    ref_vol = ground.sum(0)
    intersect = (ground * pred).sum(0)
    seg_vol = pred.sum(0)
    # only 1e-6 smoothing (no +1.0), so perfect overlap scores ~1.0
    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1e-6)
    return dice_score.mean()
def Intersection_over_Union_isic(prediction, soft_ground_truth, num_class):
    """Mean IoU (Jaccard) across classes for ISIC2018 (scalar).

    Both inputs must be one-hot and reshapeable to (-1, num_class).
    The +1.0 smoothing keeps the score finite but biases it slightly
    below the true IoU.
    """
    pred = prediction.contiguous().view(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.view(-1, num_class)
    ref_vol = ground.sum(0)
    intersect = (ground * pred).sum(0)
    seg_vol = pred.sum(0)
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    return iou_score.mean()
SoftDiceLoss 类(主损失类)
python
class SoftDiceLoss(_Loss):
'''
Soft_Dice = 2*|dot(A, B)| / (|dot(A, A)| + |dot(B, B)| + eps)
eps is a small constant to avoid zero division,
'''
def __init__(self, *args, **kwargs):
super(SoftDiceLoss, self).__init__()
def forward(self, prediction, soft_ground_truth, num_class=3, weight_map=None, eps=1e-8):
dice_loss_ave, dice_score_lesion = soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map)
return dice_loss_ave, dice_score_lesion
继承 PyTorch 的 _Loss,可直接用于训练。注:默认 num_class=3 是为多类任务设计的,但 ISIC2018 是二值任务,需传 num_class=2
get_soft_label:硬标签 → 软标签(one-hot)
python
def get_soft_label(input_tensor, num_class):
"""
convert a label tensor to soft label
input_tensor: tensor with shape [N, C, H, W]
output_tensor: shape [N, H, W, num_class]
"""
tensor_list = []
input_tensor = input_tensor.permute(0, 2, 3, 1)
for i in range(num_class):
temp_prob = torch.eq(input_tensor, i * torch.ones_like(input_tensor))
tensor_list.append(temp_prob)
output_tensor = torch.cat(tensor_list, dim=-1)
output_tensor = output_tensor.float()
return output_tensor
将整数标签(如 0/1)转为 one-hot 编码。在本仓库的 ISIC2018 流程中实际会用到:因为模型输出是 2 通道(classes=2),需要先把单通道 0/1 标签转成 one-hot 双通道,再送入损失函数和评估指标。
soft_dice_loss:核心 Dice 损失计算
python
def soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=None):
predict = prediction.permute(0, 2, 3, 1)
pred = predict.contiguous().view(-1, num_class)
# pred = F.softmax(pred, dim=1)
ground = soft_ground_truth.view(-1, num_class)
n_voxels = ground.size(0)
if weight_map is not None:
weight_map = weight_map.view(-1)
weight_map_nclass = weight_map.repeat(num_class).view_as(pred)
ref_vol = torch.sum(weight_map_nclass * ground, 0)
intersect = torch.sum(weight_map_nclass * ground * pred, 0)
seg_vol = torch.sum(weight_map_nclass * pred, 0)
else:
ref_vol = torch.sum(ground, 0)
intersect = torch.sum(ground * pred, 0)
seg_vol = torch.sum(pred, 0)
dice_score = (2.0 * intersect + 1e-5) / (ref_vol + seg_vol + 1.0 + 1e-5)
# dice_loss = 1.0 - torch.mean(dice_score)
# return dice_loss
dice_loss = -torch.log(dice_score)
dice_loss_ave = torch.mean(dice_loss)
dice_score_lesion = dice_loss[1]
return dice_loss_ave, dice_score_lesion
在主训练代码中:
python
from utils.dice_loss import get_soft_label, val_dice_isic, SoftDiceLoss
from utils.dice_loss import Intersection_over_Union_isic
| 函数/类 | 用途 | 调用阶段 |
|---|---|---|
get_soft_label |
将硬标签(0/1)转为 one-hot 软标签(如 [0,1] → [1,0], [0,1]) |
训练 + 测试 |
val_dice_isic |
计算 ISIC2018 的 Dice 系数(用于测试评估) | 测试阶段 |
SoftDiceLoss |
自定义的 Dice 损失函数 | 训练阶段(理论上) |
Intersection_over_Union_isic |
计算 IoU(Jaccard Index) | 测试阶段 |
get_soft_label:
python
# train() 函数内
target = y.float().cuda() # [B, 1, H, W]
target_soft_a = get_soft_label(target, args.num_classes) # → [B, H, W, 2]
target_soft = target_soft_a.permute(0, 3, 1, 2) # → [B, 2, H, W]
因为模型输出是 2 通道(classes=2),所以将单通道标签转为 one-hot 双通道,以匹配损失函数输入。
python
# test_isic() 函数内
output_dis = torch.max(output, 1)[1].unsqueeze(dim=1) # [B,1,H,W], 整数预测 (0/1)
target = lab.float().cuda() # [B,1,H,W], float 标签 (0.0/1.0)
output_soft = get_soft_label(output_dis, 2) # → [B, H, W, 2]
target_soft = get_soft_label(target, 2) # → [B, H, W, 2]
为 val_dice_isic 提供符合其接口的 one-hot 输入。
val_dice_isic
python
# test_isic() 函数内
isic_b_dice = val_dice_isic(output_soft, target_soft, 2) # num_class=2
dice_np = isic_b_dice.data.cpu().numpy() # shape: (2,)
isic_dice.append(dice_np)
dice_np[0]:背景 Dice
dice_np[1]:病灶(前景)Dice ← 这才是你要的指标
Intersection_over_Union_isic
python
# test_isic() 函数内
output_dis_test = output_dis.permute(0, 2, 3, 1).float() # [B, H, W, 1]
target_test = target.permute(0, 2, 3, 1).float() # [B, H, W, 1]
isic_b_iou = Intersection_over_Union_isic(output_dis_test, target_test, 1)
iou_np = isic_b_iou.data.cpu().numpy()
isic_iou.append(iou_np)
其他的指标库
| 函数 | 全称 | 用途 | 输入格式 | 输出 |
|---|---|---|---|---|
dc |
Dice Coefficient | 衡量预测与真值重叠度(最常用) | (H,W) 或 (B,H,W) 的 0/1 array |
float ∈ [0,1] |
jc |
Jaccard Index (IoU) | 交并比 | 同上 | float ∈ [0,1] |
precision |
Precision (Positive Predictive Value) | 预测为病灶的像素中有多少是真的 | 同上 | float |
sensitivity |
Sensitivity (Recall / True Positive Rate) | 真实病灶中有多少被找出来 | 同上 | float |
specificity |
Specificity (True Negative Rate) | 真实背景中有多少被正确识别 | 同上 | float |
F1 |
F1-Score | Precision 和 Recall 的调和平均 | 同上 | float |
ACC |
Accuracy | 整体像素分类准确率 | 同上 | float |
assd |
Average Symmetric Surface Distance | 边界距离误差(更精细的几何指标) | 同上 | float (单位:像素) |
最后再来逐句的理解一下训练代码:
导入语句不多解释,先看全局配置。
python
Test_Model = {
"unet":unet,
"AttUNet":AttUNet
}
Test_Dataset = {'ISIC2018': ISIC2018_dataset}
Test_Transform = {'A': ISIC2018_transform, 'B':ISIC2018_transform_320, "C":ISIC2018_transform_newdata}
criterion = "loss_D" # loss_A-->SoftDiceLoss; loss_B-->softdice; loss_C--> CE + softdice; loss_D--> BCEWithLogitsLoss
日志重定向,将 print() 同时输出到控制台和日志文件。
python
class Logger(object):
def __init__(self,logfile):
self.terminal = sys.stdout
self.log = open(logfile, "a")
def write(self, message):
self.terminal.write(message)
self.log.write(message)
def flush(self):
pass
训练函数
python
def train(train_loader, model, criterion, scheduler, optimizer, args, epoch):
初始化,每个 epoch 重置 loss 记录器,设为训练模式。
losses = AverageMeter()
model.train()
数据加载,图像转 float,标签也转 float(因为 BCE 要求 float)。
for step, (x, y) in tqdm(enumerate(train_loader)):
image = x.float().cuda() # [B, 3, H, W]
target = y.float().cuda() # [B, 1, H, W], 值为 0.0/1.0
前向传播,ISIC2018 是二值分割,理想输出应为 [B,1,H,W],但这里用了 2 通道(背景+前景),导致后续必须用 one-hot 标签。
output = model(image) # [B, 2, H, W] (因为 classes=2)
构造 one-hot 标签,为了匹配 2 通道输出,把 [B,1,H,W] 标签转成 one-hot。
target_soft_a = get_soft_label(target, args.num_classes) # [B, H, W, 2]
target_soft = target_soft_a.permute(0, 3, 1, 2) # [B, 2, H, W]
损失函数选择(执行 loss_D)
if criterion == "loss_D":
dice_loss = nn.BCEWithLogitsLoss()
loss = dice_loss(output, target_soft) # 注意:target_soft 是 [B,2,H,W]
反向传播 & 优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.update(loss.data, image.size(0))
日志打印,每个 epoch 打印一次:step 从 0 计数且始终小于 total_steps,所以 step % (math.ceil(len(dataset)/batch_size)) == 0 只在 step == 0(每个 epoch 的第一个 batch)成立;若希望在 epoch 末尾打印,应改用 step == len(train_loader) - 1。
if step % (math.ceil(len(dataset)/batch_size)) == 0:
print('Epoch: {} Loss: {:.6f}'.format(...))
验证函数 valid_isic()
推理模式
model.eval()
注意:model.eval() 只切换 BatchNorm/Dropout 的行为,并不会关闭梯度计算;推理时仍应显式包裹 with torch.no_grad(): 以节省显存并加速。
预测 & 二值化,用 argmax 得到预测类别。
output = model(image) # [B,2,H,W]
output_dis = torch.max(output, 1)[1].unsqueeze(1) # [B,1,H,W], int64 (0 or 1)
转换为 NumPy 用于评估
output_dis_test = output_dis.permute(0,2,3,1).float() # [B,H,W,1]
target_test = target.permute(0,2,3,1).float() # [B,H,W,1]
计算指标,使用 utils.binary 中的函数,输入是 .cpu().numpy() 的二值 mask。
isic_b_Jaccard = jc(pred, target) # Jaccard = IoU
isic_b_dc = dc(pred, target) # Dice
保存最佳模型
net_score = dice + iou
if net_score > best: save model
测试函数 test_isic()
加载最佳模型,逐样本推理 & 评估。
checkpoint = torch.load('best_score_...pth.tar')
model.load_state_dict(checkpoint['state_dict'])
对每个样本:
推理得到 output ([B,2,H,W])
torch.max → output_dis ([B,1,H,W])
转 one-hot:get_soft_label(..., 2) → 用于 val_dice_isic
转二值 mask:.permute → [B,H,W,1] → 用于 dc, jc, ACC 等
多种指标计算
# 方法1:用 val_dice_isic(返回 [bg_dice, fg_dice])
isic_b_dice = val_dice_isic(output_soft, target_soft, 2)
# 方法2:用 binary 模块(直接传二值 mask)
isic_b_dc = dc(output_dis_test.numpy(), target_test.numpy())
isic_b_acc = ACC(...)
...
特殊指标:Jaccard_M / Jaccard_N,分别计算病灶和背景的 IoU。
isic_b_Jaccard_m = jc(output_arr[:, :, 1], label_arr[:, :, 1]) # 前景 IoU
isic_b_Jaccard_n = jc(output_arr[:, :, 0], label_arr[:, :, 0]) # 背景 IoU
保存可视化结果,保存原图、真值、预测三联图。
save_imgs(img, msk, msk_pred, ...)
主函数 main()
数据集加载,五折交叉验证支持(通过 --val_folder folder3 指定哪一折做验证)。
trainset = ISIC2018_dataset(..., train_type='train', transform=Test_Transform[args.transform])
validset = ... 'validation'
testset = ... 'test', with_name=True
模型初始化
if args.id == "AttUNet":
model = AttUNet(in_channel=3, out_channel=2)
else:
model = unet(classes=2, channels=3)
优化器 & 学习率调度,使用余弦退火带重启,适合小数据集。
optimizer = Adam(..., weight_decay=1e-8)
scheduler = CosineAnnealingWarmRestarts(T_0=10, T_mult=2, eta_min=1e-5)
断点续训
if args.resume: load checkpoint
训练循环
for epoch in 1..epochs:
train()
valid() → save best if improved
if epoch > 30 and epoch % 300 == 0: save checkpoint
测试,训练结束后自动测试最佳模型。
test_isic(testloader, ...)
命令行参数(if __name__ == '__main__')
python
parser = argparse.ArgumentParser(description='Comprehensive attention network for biomedical Dataset')
parser.add_argument('--id', default="unet",
help='') # Select a loaded model name
# Path related arguments
parser.add_argument('--root_path', default='/xujiheng/ISICdemo/ISIC2018_npy_all_224_320',
help='root directory of data') # The folder where the numpy data set is stored
parser.add_argument('--ckpt', default='/xujiheng/ISICdemo/UNet-ISIC2018/saved_models/',
help='folder to output checkpoints') # The folder in which the trained model is saved
parser.add_argument('--transform', default='C', type=str,
help='which ISIC2018_transform to choose')
parser.add_argument('--data', default='ISIC2018', help='choose the dataset')
parser.add_argument('--out_size', default=(224, 320), help='the output image size')
parser.add_argument('--val_folder', default='folder3', type=str,
help='folder1、folder2、folder3、folder4、folder5') # five-fold cross-validation
# optimization related arguments
parser.add_argument('--epochs', type=int, default=15, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--start_epoch', default=0, type=int,
help='epoch to start training. useful if continue from a checkpoint')
parser.add_argument('--batch_size', type=int, default=10, metavar='N',
help='input batch size for training (default: 12)') # batch_size
parser.add_argument('--lr_rate', type=float, default=1e-4, metavar='LR',
help='learning rate (default: 0.001)')
parser.add_argument('--num_classes', default=2, type=int,
help='number of classes')
parser.add_argument('--num_input', default=3, type=int,
help='number of input image for each patient')
parser.add_argument('--weight_decay', default=1e-8, type=float, help='weights regularizer')
parser.add_argument('--particular_epoch', default=30, type=int,
help='after this number, we will save models more frequently')
parser.add_argument('--save_epochs_steps', default=300, type=int,
help='frequency to save models after a particular number of epochs')
parser.add_argument('--resume', default='',
help='the checkpoint that resumes from')
args = parser.parse_args()
--id: 模型名(unet / AttUNet)
--root_path: 数据路径(npy 格式)
--val_folder: 五折中的哪一折(folder1~5)
--transform: 预处理方式(A/B/C)
--epochs, --batch_size, --lr_rate: 训练超参
--resume: 断点路径
总结
整体训练流程总览(按执行顺序)
| 阶段 | 调用位置 / 模块 | 实现功能 | 关键作用说明 |
|---|---|---|---|
| 1. 初始化与参数解析 | if __name__ == '__main__' + argparse |
解析命令行参数,设置 GPU、路径、模型、数据等 | 控制实验配置(如模型名、数据路径、五折验证 folder、损失类型等) |
| 2. 日志系统设置 | Logger 类 + sys.stdout = Logger(...) |
将所有 print() 输出同时写入控制台和日志文件 |
自动记录训练过程,便于复现实验 |
| 3. 数据集加载 | main() → ISIC2018_dataset(...) |
加载 train / valid / test 三部分数据 | 支持五折交叉验证(通过 --val_folder 指定) |
| 4. 数据预处理 | Test_Transform[args.transform] |
应用图像增强(如归一化、裁剪、翻转等) | 提升泛化能力;transform='C' 表示使用 ISIC2018_transform_newdata |
| 5. 模型构建 | main() → Test_Model[args.id] 或 AttUNet(...) |
实例化 UNet 或 Attention UNet | 输出通道数为 2(背景+前景),用于多类 Dice/BCE 计算 |
| 6. 优化器 & 学习率调度 | torch.optim.Adam + CosineAnnealingWarmRestarts |
设置优化器和动态学习率 | 使用余弦退火带重启,适合小数据集微调 |
| 7. 断点续训(可选) | if args.resume: torch.load(...) |
从 checkpoint 恢复模型和优化器状态 | 支持中断后继续训练 |
| 8. 训练循环(每 epoch) | for epoch in range(...): |
执行训练 → 验证 → 保存模型 | 主训练流程 |
| 9. 训练步骤(train 函数) | train(train_loader, ...) |
前向传播 + 损失计算 + 反向传播 | 使用 criterion="loss_D" → BCEWithLogitsLoss |
| 10. 标签转 one-hot | get_soft_label(target, 2) |
将 [B,1,H,W] 标签转为 [B,2,H,W] one-hot |
匹配 2 通道模型输出,供 BCE/Dice 使用 |
| 11. 损失计算 | nn.BCEWithLogitsLoss()(output, target_soft) |
计算二值交叉熵损失 | 当前实际使用的损失函数 |
| 12. 损失跟踪 | AverageMeter() |
动态计算 batch loss 的平均值 | 用于日志打印和监控收敛 |
| 13. 验证(valid_isic) | valid_isic(valid_loader, ...) |
在验证集上评估 Dice 和 IoU | 决定是否保存"最佳模型" |
| 14. 预测二值化 | torch.max(output, 1)[1].unsqueeze(1) |
将 logits 转为 0/1 预测 mask | 用于指标计算 |
| 15. 评估指标计算 | dc(), jc() from utils.binary |
计算 Dice 系数和 Jaccard (IoU) | 核心分割性能指标 |
| 16. 最佳模型保存 | if net_score > best: torch.save(...) |
保存验证集上 Dice+IoU 最高的模型 | 文件名:best_score_ISIC2018_checkpoint.pth.tar |
| 17. 定期模型保存 | if epoch % save_epochs_steps == 0 |
每 N 轮保存一次 checkpoint | 用于后续分析或断点恢复 |
| 18. 测试(test_isic) | test_isic(test_loader, ...) |
加载最佳模型,在测试集全面评估 | 报告 10+ 种指标 |
| 19. 多指标评估 | ACC, sensitivity, specificity, F1, precision, dc, jc 等 |
计算全面的二值分割性能 | 包括病灶(M)和背景(N)分别的 IoU |
| 20. 推理时间统计 | time() before/after model(image) |
记录单样本推理耗时 | 评估模型效率 |
| 21. 可视化结果保存 | save_imgs(...) |
保存原图、真值、预测三联图 | 用于人工检查分割质量 |
| 22. 参数量统计 | for param in model.parameters(): num_para += ... |
计算模型可训练参数总数 | 用于模型复杂度分析 |