在上一个关于3D
目标的任务,是基于普通CNN
网络的3D
分类任务。在这个任务中,分类数据采用的是CT
结节的LIDC-IDRI
数据集,其中对结节的良恶性、毛刺、分叶征等等特征进行了各自的等级分类。感兴趣的可以直接点击下方的链接,直达学习:
在开始本次关于3D
目标的分割任务前呢,我还是建议先去看看上述较为简单的分类任务,毕竟大多数是相似的,有很高的借鉴意义。
一、导言
准备一个训练,需要下面这些内容组成:
- 准备数据
- 准备网络
- 搭建训练主模型
train one epoch
valid one epoch
- 存储模型
- 存储指标
loss
函数dice coeff
评估指标optimizer
优化方式
其中,在本项目中:
- 网络采用
vnet 3d
模型 - 数据采用
patch
裁剪大小 loss
函数未dice loss
- 评价指标是
dice coeff
optimizer
优化方式是SGD
二、搭建主结构
训练的主体结构(骨架),总数包括几个部分:
config
:可调参数定义,包括数据路径、图像大小、类别数量、学习率、batch size
等等;main
:主函数,包括:- 构建模型
- 构建数据
- 优化器
- 学习率变化方式
- 损失函数
- 评估指标
- 训练
batch
循环 - 验证
batch
循环
- 后处理:包括模型参数存储,指标走势绘图等等。
上面这些个内容,基本上是囊括了深度学习模型训练的整体结构了,后面的工作就是对每一部分进行补充。就犹如已经有了骨架,后续就是补充肉身了。
后面给出的这个pytorch
骨架案例,也是后面再构建训练任务,一个可以参考的依据,可收藏。
2.1、导入库和配置参数
python
import os
import matplotlib.pyplot as plt
import torch.utils.data
import torch.optim as optim
from datasets.datasets import myDataset
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3" # 使用gpu0
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 没gpu就用cpu
print(DEVICE)
############################################################
# Configuration
############################################################
class Configuration(object):
train_path = r"./database/sk_output/train"
valid_path = r"./database/sk_output/valid"
model_path = r'./checkpoints'
Crop_Size = (48, 96, 96)
num_outs = 2
Batch_Train = 32
Batch_Test = 16
Max_epoch = 220
Num_Workers = 8
Dice_Best = 0
LR = 0.0003
momentum = 0.99
weight_decay = 1e-8
def display(self):
"""Display Configuration values."""
print("\nConfigurations:")
print("")
for a in dir(self):
if not a.startswith("__") and not callable(getattr(self, a)):
print("{:30} {}".format(a, getattr(self, a)))
print("\n")
2.2、构建main主函数
python
def main():
Config = Configuration()
Config.display()
train_loader, valid_loader = get_Dataloader(Config)
model = get_model(Config).to(DEVICE)
# ---- OPTIMIZER ----
optimizer = optim.SGD(model.parameters(), lr=Config.LR, momentum=Config.momentum, weight_decay=Config.weight_decay)
train_loss_list = [] # 用来记录训练损失
valid_loss_list = [] # 用来记录验证损失
valid_dice_list = []
epoch_list = []
for epoch in range(1, Config.Max_epoch + 1):
epoch_list.append(epoch)
train_loss = train_model(model, DEVICE, train_loader, optimizer, epoch) # 训练
valid_loss, valid_dice = valid_model(model, DEVICE, valid_loader, epoch) # 验证
train_loss_list.append(train_loss)
valid_loss_list.append(valid_loss)
valid_dice_list.append(valid_dice)
draw_plot(epoch_list, valid_dice_list, 'valid_dice')
draw_plot(epoch_list, valid_loss_list, 'valid_loss')
draw_plot(epoch_list, train_loss_list, 'train_loss')
if valid_dice > Config.Dice_Best:
path_ckpt = os.path.join(Config.model_path, 'best_model.pth')
save_model(path_ckpt, model)
Config.Dice_Best = valid_dice
else:
path_ckpt = os.path.join(Config.model_path, 'last_model.pth')
save_model(path_ckpt, model)
print('best val Dice is ', Config.Dice_Best)
if __name__ == '__main__':
main()
2.3、构建获取模型和数据的函数
python
def get_model(config):
from models.vnet3d import VNet3D
model = VNet3D(num_outs=config.num_outs, channels=16)
model = model.to(DEVICE) # 模型部署到gpu或cpu里
model = torch.nn.DataParallel(model).to(DEVICE)
return model
def get_Dataloader(config):
# get train data
dataset_train = myDataset(config.train_path, config.Crop_Size, isTrain=True)
print(len(dataset_train))
train_loader = torch.utils.data.DataLoader(dataset_train,
batch_size=config.Batch_Train, shuffle=True,
num_workers=config.Num_Workers, drop_last=False)
# get valid data
dataset_valid = myDataset(config.valid_path, config.Crop_Size, isTrain=False)
valid_loader = torch.utils.data.DataLoader(dataset_valid,
batch_size=config.Batch_Test, shuffle=False,
num_workers=config.Num_Workers, drop_last=False)
return train_loader, valid_loader
2.4、构建训练循环和验证循环
python
def train_model(model, device, train_loader, optimizer, epoch):
config = Configuration()
model.train()
for batch_index, (data, target) in enumerate(train_loader): # 取batch索引,(data,target),也就是图和标签
data, target = data.to(device), target.to(device)
output = model(data)
loss = Loss(output, target)
optimizer.zero_grad() # 梯度归零
loss.backward() # 反向传播
optimizer.step() # 优化器走一步
return losses.avg # 返回平均损失,损失列表
def valid_model(model, device, test_loader, epoch):
config = Configuration()
model.eval()
with torch.no_grad(): # 不进行 梯度计算(反向传播)
for batch_index, (data, target) in enumerate(test_loader): # 枚举batch索引,(图,标签)
data, target = data.to(device), target.to(device)
output = model(data)
loss = Loss(output, target) # 计算损失
return losses.avg, multi_dices.avg
2.5、后处理
保存模型的参数,和绘制训练过程中train loss、valid loss
,以及valid dice
走势图,如下:
python
def draw_plot(x_list, y_list, title_name):
plt.plot(x_list, y_list, label=title_name)
plt.xlabel('x', fontsize=15)
plt.ylabel('y', fontsize=15)
plt.title(title_name, fontsize=15)
plt.savefig('./logs/cure.png')
def save_model(path, model):
if isinstance(model, torch.nn.DataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, path)
至此,每一个模块都有了对应的归宿,后面就是如何将缺漏的地方,补全过程了。反倒是这部分的代码相对较少,两大需要单独验证的数据和模型是大头,其他就好办了。
三、总结
本文是关于Pytorch
的 VNet 3D
图像分割的第一篇,也就是一个综述篇,主要是对这个项目的任务目的,以及其中的一个流程进行了梳理。
上述的骨干代码还不能够作为训练使用,还需要补充进去骨肉,才能够适应不同的任务,这一块的内容将会在后面的几个篇章中,一一陈述。
如果你也在做类似的事情,欢迎点赞、收藏,mark住。对于这部分的内容可以一起交流,欢迎多多评论。