pytorch 利用Tensorboar记录训练过程loss变化

文章目录

    • [1. LossHistory日志类定义](#1. LossHistory日志类定义)
    • [2. LossHistory类的使用](#2. LossHistory类的使用)
      • [2.1 实例化LossHistory](#2.1 实例化LossHistory)
      • [2.2 记录每个epoch的loss](#2.2 记录每个epoch的loss)
      • [2.3 训练结束close掉SummaryWriter](#2.3 训练结束close掉SummaryWriter)
    • [3. 利用Tensorboard 可视化](#3. 利用Tensorboard 可视化)
      • [3.1 显示可视化效果](#3.1 显示可视化效果)
    • 参考

利用Tensorboard记录训练过程中每个epoch的训练loss以及验证loss,便于及时了解网络的训练进展。

代码参考自 B导github仓库: https://github.com/bubbliiiing/deeplabv3-plus-pytorch

1. LossHistory日志类定义

py 复制代码
import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
class LossHistory():
    def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir
        self.losses     = []
        self.val_loss   = []
        
        os.makedirs(self.log_dir)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass

    def append_loss(self, epoch, loss, val_loss):
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15
            
            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except:
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
  • (1) 首先利用LossHistory类的构造函数__init__, 实例化TensorboardSummaryWriter对象self.writer,并将网络结构图添加到self.writer中。其中__init__方法接收的参数包括,保存log的路径log_dir以及模型model和输入的shape
py 复制代码
def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir
        self.losses     = []
        self.val_loss   = []
        
        os.makedirs(self.log_dir)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass
  • (2) 记录每个epoch的训练损失loss以及验证val_loss,并保存到tensorboar中显示
py 复制代码
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

同时将训练的loss以及验证val_loss逐行保存到.txt文件中

py 复制代码
 with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
      f.write(str(loss))
      f.write("\n")

并且在每个epoch时,调用loss_plot绘制历史的loss曲线,并保存为epoch_loss.png, 由于每个epoch保存的图片都是重名的,因此在训练结束时,会保存最新的所有epoch绘制的loss曲线

2. LossHistory类的使用

2.1 实例化LossHistory

在训练开始前,实例化LossHistory类,调用__init__实例化时,会创建SummaryWriter对象,用于记录训练的过程中的数据,比如loss, graph以及图片信息等

py 复制代码
local_rank  = int(os.environ["LOCAL_RANK"]) 
model   = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained)
input_shape     = [512, 512]

if local_rank == 0:
      time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
      log_dir         = os.path.join(save_dir, "loss_" + str(time_str))
      loss_history    = LossHistory(log_dir, model, input_shape=input_shape)
  else:
      loss_history    = None
  • 对于多GPU训练时,只在主进程(local_rank == 0)记录训练的日志信息
  • log 保存的路径log_dir,利用loss_ + 当前时间的形式记录
py 复制代码
log_dir         = os.path.join(save_dir, "loss_" + str(time_str))

2.2 记录每个epoch的loss

在每个epoch中,利用loss_history的append_loss方法,利用SummaryWriter对象保存loss:

py 复制代码
for epoch in range(start_epoch, total_epoch):
	...
	loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  • 记录了每个epoch的训练loss以及验证val_loss
  • 同时将最新的loss曲线,保存到本地epoch_loss.png
  • 并将历史的训练loss和val_loss保存为txt文件,方便查看

2.3 训练结束close掉SummaryWriter

py 复制代码
loss_history.writer.close()

3. 利用Tensorboard 可视化

  • Tensorboard最早是在Tensorflow中开发和应用的,pytorch 中也同样支持Tensorboard的使用,pytorch中的Tensorboard工具叫TensorboardX, 它需要依赖于tensorflow库中的一些组件支持。因此在安装Tensorboardx之前,需要先安装TensorFlow, 否则直接安装Tensorboardx运行会报错。
shell 复制代码
pip install tensorflow
pip install tensorboardX

3.1 显示可视化效果

训练结束后,cd到SummaryWriter中定义好日志保存目录log_dir下,执行如下指令

py 复制代码
cd log_dir # log_dir为定义的日志保存目录
tensorboard  --logdir=./     --port 6006 

然后会显示出访问的链接地址,点击链接就可以查看Tensorboard可视化效果

  • Scalar模块展示训练过程中,每个epoch的train_loss、Accuracy、Learn_Rating的数值变化
  • GRAPH模块展示的是模型的网络结构
  • HISTOGRAMS模块展示添加到tensorboard中各层的权重分布情况

参考

相关推荐
用户6915811416537 分钟前
Ascend Extension for PyTorch的源码解析
人工智能
努力的家伙是不讨厌的41 分钟前
解析json导出csv或者直接入库
开发语言·python·json
用户6915811416544 分钟前
Ascend C的编程模型
人工智能
云空1 小时前
《Python 与 SQLite:强大的数据库组合》
数据库·python·sqlite
成富2 小时前
文本转SQL(Text-to-SQL),场景介绍与 Spring AI 实现
数据库·人工智能·sql·spring·oracle
凤枭香2 小时前
Python OpenCV 傅里叶变换
开发语言·图像处理·python·opencv
CSDN云计算2 小时前
如何以开源加速AI企业落地,红帽带来新解法
人工智能·开源·openshift·红帽·instructlab
测试杂货铺2 小时前
外包干了2年,快要废了。。
自动化测试·软件测试·python·功能测试·测试工具·面试·职场和发展
艾派森2 小时前
大数据分析案例-基于随机森林算法的智能手机价格预测模型
人工智能·python·随机森林·机器学习·数据挖掘
hairenjing11232 小时前
在 Android 手机上从SD 卡恢复数据的 6 个有效应用程序
android·人工智能·windows·macos·智能手机