第12章 PyTorch图像分割代码框架-2

模型模块

本书的第5-9章重点介绍了各种2D和3D的语义分割和实例分割网络模型,所以在模型模块中,我们需要做的事情就是将要实验的分割网络写在该目录下。有时候我们可能想尝试不同的分割网络结构,所以在该目录下可以存在多个想要实验的网络模型定义文件。对于PASCAL VOC这样的自然数据集,我们可能想实验Deeplab v3+、PSPNet、RefineNet等网络的训练效果。代码11-3给出了Deeplab v3+网络封装后的主体部分,完整网络搭建代码可参考本书配套代码对应章节。

代码11-3 Deeplab v3+网络的主体部分

ruby 复制代码
# 定义Deeplab V3+类
class DeepLabHeadV3Plus(nn.Module):
    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHeadV3Plus, self).__init__()


        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )
    # ASPP
        self.aspp = ASPP(in_channels, aspp_dilate)
    # classifier head
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )


        self._init_weight()
  # forward method
    def forward(self, feature):
        # print(feature['low_level'].shape)
        # print(feature['out'].shape)
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])
        output_feature = F.interpolate(
            output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
  # weight initilize
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

对于复杂网络搭建,一般都是采用自下而上的搭建方法,先搭建底层组件,再逐步向上封装,对于本例中的Deeplab v3+,可以先分别搭建backbone骨干网络、ASPP和编解码结构,最后再进行封装。

工具函数模块

工具函数是为项目完成各项功能所自定义的辅助函数,可以统一定义在utils文件夹下,根据实际项目的不同,工具函数也各不相同。常用的工具函数包括各种损失函数的定义loss.py训练可视化函数的定义visualize.py、用于记录训练日志的log.py等。代码11-4给出了一个关于Focal loss损失函数的定义,该损失函数作为工具函数可放在loss.py文件中。

代码11-4 工具函数示例:定义一个Focal loss

ruby 复制代码
# 导入相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义一个Focal loss类
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma


    def forward(self, inputs, targets):
        # Compute cross-entropy loss
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')


        # Compute the focal loss
        pt = torch.exp(-ce_loss)  
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

配置模块

配置模块是为项目模型训练传入各种参数而进行设置的模块,比如训练数据所在目录、训练所需要的各种参数、训练过程是否需要可视化等。一般来说,我们有两种方式来对项目执行参数进行配置管理,一种是直接在主函数main.py中使用argparse库对参数进行配置,然后再命令行中进行传入;另一种则是单独定义一个config.py或者config.yaml文件来对所有参数进行统一配置。基于argparse库的参数配置管理简单示例如代码11-5所示。

代码11-5 argparser参数配置管理

sql 复制代码
# 导入argparse库
import argparse
# 创建参数管理器
parser = argparse.ArgumentParser()
# 涉及数据相关的参数管理
parser.add_argument("--data_root", type=str, default='./dataset',
                     help="path to Dataset")
parser.add_argument("--save_root", type=str, default='./',
                     help="path to save result")
parser.add_argument("--dataset", type=str, default='voc',
                     choices=['voc', 'cityscapes', 'ade'], help='Name of dataset')
parser.add_argument("--num_classes", type=int, default=None,
                     help="num classes (default: None)")

在上述代码中,我们基于argparse给出了一小部分参数配置管理代码,涉及训练数据相关的部分参数,包括数据读取路径、存放路径、训练所用数据集、分割类别数量等。

主函数模块

主函数模块main.py是项目的启动模块,该模块将定义好的数据和模型模块进行组装,并结合损失函数、优化器、评估方法和可视化等组件,将config.py中配置好的项目参数传入,根据训练-验证的模式,执行图像分割项目模型训练和验证。代码11-6是VOC数据集训练验证部分代码。

代码11-6 主函数模块中的训练迭代部分

python 复制代码
# 初始化区间损失
interval_loss = 0
while True:  
  # 执行训练
  model.train()
  cur_epochs += 1
  for (images, labels) in train_loader:
    cur_itrs += 1
    images = images.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.long)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()


    np_loss = loss.detach().cpu().numpy()
    interval_loss += np_loss


    if vis is not None:
      vis.vis_scalar('Loss', cur_itrs, np_loss)
    # 打印训练信息
    if (cur_itrs) % opts.print_interval == 0:
      pass
    # 保存模型
    if (cur_itrs) % opts.val_interval == 0:
      pass
      # 日志记录
      logger.info("Save the latest model to %s" % save_path_checkpoints)
      # 模型验证
      print("validation...")
      model.eval()
      val_score, ret_samples = validate(
        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
        ret_samples_ids=vis_sample_id)
      logger.info("Validation performance: %s", val_score)
      
      # 保存最优模型
      if val_score['mean_dice'] > best_score:  
        best_score = val_score['mean_dice']
        save_ckpt(os.path.join(save_path_checkpoints, 'best_%s_%s_os%d.pth' %
                     (opts.model, opts.dataset, opts.output_stride)))
        logger.info("Save best-performance model so far to %s" % save_path_checkpoints)


      # 训练过程可视化
      if vis is not None:  
        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])


        for k, (img, target, lbl) in enumerate(ret_samples):
          img = (denorm(img) * 255).astype(np.uint8)
          target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
          lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
          concat_img = np.concatenate((img, target, lbl), axis=2)  
          vis.vis_image('Sample %d' % k, concat_img)
          
    scheduler.step()

在代码11-6中,我们展示了一个图像分割项目主函数模块中最核心的训练和验证部分。在训练时,按照指定迭代次数保存模型和对训练过程进行可视化展示。图11-2为训练打印的部分信息。

图11-2 VOC训练过程信息

图11-3为基于visdom的训练过程可视化展示,包括当前训练配置参数信息,训练损失函数变化曲线、验证集全局准确率、mIoU和类别IoU等指标变化曲线图。

图11-3 Deeplab v3+训练过程可视化

图11-4展示了两组训练过程中验证集的输入图像、标签图像和模型预测图像的对比图。可以看到,基于Deeplab v3+的分割模型在PASCAL VOC 2012上表现还不错。

图11-4 验证集模型效果图

后续全书内容和代码将在github上开源,请关注仓库:

https://github.com/luwill/Deep-Learning-Image-Segmentation

(未完待续)

相关推荐
Elastic 中国社区官方博客2 分钟前
使用 Elastic AI Assistant for Search 和 Azure OpenAI 实现从 0 到 60 的转变
大数据·人工智能·elasticsearch·microsoft·搜索引擎·ai·azure
江_小_白1 小时前
自动驾驶之激光雷达
人工智能·机器学习·自动驾驶
yusaisai大鱼2 小时前
TensorFlow如何调用GPU?
人工智能·tensorflow
湫ccc3 小时前
《Python基础》之字符串格式化输出
开发语言·python
mqiqe4 小时前
Python MySQL通过Binlog 获取变更记录 恢复数据
开发语言·python·mysql
AttackingLin4 小时前
2024强网杯--babyheap house of apple2解法
linux·开发语言·python
哭泣的眼泪4084 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
珠海新立电子科技有限公司5 小时前
FPC柔性线路板与智能生活的融合
人工智能·生活·制造
湫ccc5 小时前
《Python基础》之基本数据类型
开发语言·python
IT古董5 小时前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论