PyTorch中常用的工具(3)TensorBoard

文章目录

  • 前言
  • [3 可视化工具](#3 可视化工具)
    • [3.1 TensorBoard](#3.1 TensorBoard)

前言

在训练神经网络的过程中需要用到很多的工具,最重要的是数据处理、可视化和GPU加速。本章主要介绍PyTorch在这些方面常用的工具模块,合理使用这些工具可以极大地提高编程效率。

由于内容较多,本文分成了五篇文章(1)数据处理(2)预训练模型(3)TensorBoard(4)Visdom(5)CUDA与小结。

整体结构如下:

  • 1 数据处理
    • 1.1 Dataset
    • 1.2 DataLoader
  • 2 预训练模型
  • 3 可视化工具
  • 3.1 TensorBoard
  • 3.2 Visdom
  • 4 使用GPU加速:CUDA
  • 5 小结

全文链接:

  1. PyTorch中常用的工具(1)数据处理
  2. PyTorch常用工具(2)预训练模型
  3. PyTorch中常用的工具(3)TensorBoard
  4. PyTorch中常用的工具(4)Visdom
  5. PyTorch中常用的工具(5)使用GPU加速:CUDA

3 可视化工具

在训练神经网络时,通常希望能够更加直观地了解训练情况,例如损失函数曲线、输入图片、输出图片等信息。这些信息可以帮助读者更好地监督网络的训练过程,并为参数优化提供方向和依据。最简单的办法就是打印输出,这种方式只能打印数值信息,不够直观,同时无法查看分布、图片、声音等。本节介绍两个深度学习中常用的可视化工具:TensorBoard和Visdom。

3.1 TensorBoard

最初,TensorBoard是作为TensorFlow的可视化工具迅速流行开来的。作为和TensorFlow深度集成的工具,TensorBoard能够展示TensorFlow的网络计算图,绘制图像生成的定量指标图以及附加数据。同时,TensorBoard是一个相对独立的工具,只要用户保存的数据遵循相应的格式,TensorBoard就能读取这些数据,进行可视化。

在PyTorch 1.1.0版本之后,PyTorch已经内置了TensorBoard的相关接口,用户在手动安装TensorBoard后便可调用相关接口进行数据的可视化,TensorBoard的主界面如下图所示。

![使用add_scalar记录标量]](https://img-blog.csdnimg.cn/direct/864745746f6244e080a0793ae578e5a1.png#pic_center)

TensorBoard的使用非常简单,首先使用以下命令安装TensorBoard:

bash 复制代码
pip install tensorboard

待安装完成后,通过以下命令启动TensorBoard,其中path为log文件的保存路径:

bash 复制代码
tensorboard --logdir=path

TensorBoard的常见操作包括记录标量、显示图像、显示直方图、显示网络结构、可视化embedding等,下面逐一举例说明:

python 复制代码
In: import torch
    import torch.nn as nn
    import numpy as np
    from torchvision import models
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets,transforms
    from torch.utils.data import DataLoader
    # 构建logger对象,log_dir用来指定log文件的保存路径
    logger = SummaryWriter(log_dir='runs')
python 复制代码
In: # 使用add_scalar记录标量
    for n_iter in range(100):
        logger.add_scalar('Loss/train', np.random.random(), n_iter)
        logger.add_scalar('Loss/test', np.random.random(), n_iter)
        logger.add_scalar('Acc/train', np.random.random(), n_iter)
        logger.add_scalar('Acc/test', np.random.random(), n_iter)
python 复制代码
In: transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,),(0.5,))
    ])
    dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)
    dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
    images, labels = next(iter(dataloader))
    grid = torchvision.utils.make_grid(images)
python 复制代码
In: # 使用add_image显示图像
    logger.add_image('images', grid, 0)
python 复制代码
In: # 使用add_graph可视化网络
	class ToyModel(nn.Module):
    	def __init__(self, input_size=28, hidden_size=500, num_classes=10):
        	super().__init__()
        	self.fc1 = nn.Linear(input_size, hidden_size) 
        	self.relu = nn.ReLU()
        	self.fc2 = nn.Linear(hidden_size, num_classes)  
    	def forward(self, x):
        	out = self.fc1(x)
        	out = self.relu(out)
        	out = self.fc2(out)
        	return out
	model = ToyModel()
	logger.add_graph(model, images)
python 复制代码
In: # 使用add_histogram显示直方图
    logger.add_histogram('normal', np.random.normal(0,5,1000), global_step=1)
    logger.add_histogram('normal', np.random.normal(1,2,1000), global_step=10)
python 复制代码
In: # 使用add_embedding进行embedding可视化
    dataset = datasets.MNIST('data/', download=True, train=False)
    images = dataset.data[:100].float()
    label = dataset.targets[:100]
    features = images.view(100, 784)
    logger.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))

打开浏览器输入http://localhost:6006(其中,6006应改成读者TensorBoard所绑定的端口),就可以看到本文之前的可视化结果。

TensorBoard十分容易上手,读者可以根据个人需求灵活地使用上述函数进行可视化。本节介绍了TensorBoard的常见操作,更多详细内容读者可参考官方相关源码。

相关推荐
阿俊仔(摸鱼版)16 分钟前
Python 常用运维模块之OS模块篇
运维·开发语言·python·云服务器
lly_csdn1231 小时前
【Image Captioning】DynRefer
python·深度学习·ai·图像分类·多模态·字幕生成·属性识别
速融云1 小时前
汽车制造行业案例 | 发动机在制造品管理全解析(附解决方案模板)
大数据·人工智能·自动化·汽车·制造
西猫雷婶1 小时前
python学opencv|读取图像(四十一 )使用cv2.add()函数实现各个像素点BGR叠加
开发语言·python·opencv
金融OG1 小时前
99.11 金融难点通俗解释:净资产收益率(ROE)VS投资资本回报率(ROIC)VS总资产收益率(ROA)
大数据·python·算法·机器学习·金融
AI明说2 小时前
什么是稀疏 MoE?Doubao-1.5-pro 如何以少胜多?
人工智能·大模型·moe·豆包
XianxinMao2 小时前
重构开源LLM分类:从二分到三分的转变
人工智能·语言模型·开源
Elastic 中国社区官方博客2 小时前
使用 Elasticsearch 导航检索增强生成图表
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
小唐C++2 小时前
C++小病毒-1.0勒索
开发语言·c++·vscode·python·算法·c#·编辑器
云天徽上3 小时前
【数据可视化】全国星巴克门店可视化
人工智能·机器学习·信息可视化·数据挖掘·数据分析