[ Pytorch教程 ] TensorBaord类

1、介绍

TensorBoard 是可视化工具,用于跟踪和可视化以下内容:

主要功能模块:

  1. 标量可视化 - 损失、准确率等指标

  2. 图像可视化 - 输入/输出图像、特征图

  3. 图表可视化 - 模型计算图

  4. 直方图分布 - 权重、偏置的分布变化

  5. 嵌入可视化 - 高维数据的降维展示

  6. 文本可视化 - 文本数据

  7. 超参数调优 - HParams 面板

2、 安装

bash 复制代码
pip3 install tensorboard==2.12.0 -i https://pypi.doubanio.com/simple/ --target=/home/qhr/anaconda3/envs/pytorch/lib/python3.8/site-packages

3、运行

bash 复制代码
 tensorboard --logdir=/home/qhr/PythonPorject/hymenoptera/logs --port=6007

如果报错:zsh: command not found: tensorboard

bash 复制代码
第一种解决方案:
  
alias tensorboard='python3 -m tensorboard.main'
  
第二种解决方案:
  
python3 -m  tensorboard.main --logdir=.
  
第三种解决方案:
  
将alias命令放到 .zshrc 里

4、函数详解

4.1. 标量记录 - add_scalar()

python 复制代码
writer.add_scalar(tag, scalar_value, global_step=None, walltime=None, new_style=False)

参数解释:

  • tag (string): 指标的标签名称,如 'train/loss'

  • scalar_value (float): 要记录的标量值

  • global_step (int): 全局步数,通常是训练步数或epoch数

  • walltime (float): 可选,覆盖默认的walltime

4.2. 图像记录 - add_image()

python 复制代码
writer.add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数解释:

  • tag (string): 图像标签

  • img_tensor (torch.Tensor, numpy.array): 图像数据

  • global_step (int): 全局步数

  • dataformats (string): 数据格式,如 'CHW', 'HWC', 'HW', 'WH'等

以及其他的函数

|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 多个图像记录 - add_images() | python writer.add_images(tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW') 参数解释: img_tensor 应该是批量的图像,形状为 [N, C, H, W] 或 [N, H, W, C] |
| 直方图记录 - add_histogram() | python writer.add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None) 参数解释: * tag (string): 直方图标签 * values (torch.Tensor, numpy.array): 要分析的值 * bins (string): 分桶策略 |

5、代码

创建一个Python文件

python 复制代码
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import cv2
from dataset import MyDataSet

if __name__ == "__main__":
    print("tensor board test")
    img_dir_path = "./hymenoptera_data/train/ants_img"
    label_dir_path = "./hymenoptera_data/train/ants_label"
    my_dataset = MyDataSet(img_dir_path, label_dir_path)

    ant_img, ant_label = my_dataset[0]
    ant_img.show()
    print(ant_label)

    print(np.array(ant_img).shape)

    writer = SummaryWriter("logs") # 存储目录

    # 可视化图片
    writer.add_image("test", np.array(ant_img), 1, dataformats='HWC')

    #可视化标量
    for i in range(100):
        writer.add_scalar("y=x", i, i)

    writer.close() # 关闭

代码中我们使用到了前文用到的MyDataSet,同时从

复制代码
torch.utils.tensorboard 导入 SummaryWriter
python 复制代码
writer = SummaryWriter("logs") # 存储目录

创建writer实例,并定义存储目录名称"logs"

示例中显示如何向tensorboard添加图片、标量数据

python 复制代码
  # 可视化图片
    writer.add_image("test", np.array(ant_img), 1, dataformats='HWC')

    #可视化标量
    for i in range(100):
        writer.add_scalar("y=x", i, i)

最后再关闭writer

python 复制代码
  writer.close() # 关闭

6、执行

上面的代码执行过后

在本地会有个logs目录,

然后启动tersorboard

bash 复制代码
 tensorboard --logdir=./logs --port=6007

启动后会有终端输出

复制这个网址在浏览器中打开

就能看到我们传入的图片和标量数据

相关推荐
koharu1231 分钟前
大模型后训练全解:SFT、RLHF/PPO、DPO 的原理、实践与选择
人工智能·llm·后训练
m0_377618236 分钟前
c++如何将双精度浮点数以科学计数法写入文件_scientific标志【详解】
jvm·数据库·python
weixin_424999369 分钟前
如何检测SQL注入风险_利用模糊测试技术发现漏洞
jvm·数据库·python
2301_7751481513 分钟前
如何用正则具名捕获组 (-) 提升复杂数据的提取效率
jvm·数据库·python
2501_9142459317 分钟前
Go语言如何在VSCode中开发_Go语言VSCode配置教程【避坑】.txt
jvm·数据库·python
Kel19 分钟前
LangChain.js 架构设计深度剖析
人工智能·设计模式·架构
百度Geek说20 分钟前
我把 Karpathy 的 AutoResearch 搬到了软件开发领域,效果炸了
人工智能
2301_7826591820 分钟前
MongoDB如果有一个分片完全宕机集群还能用吗_受影响数据的不可读与分片隔离感知
jvm·数据库·python
justjinji22 分钟前
JavaScript中严格模式use-strict对引擎解析的辅助
jvm·数据库·python
Absurd58724 分钟前
CSS如何使用-default获取默认选项样式_通过状态伪类突出预选表单项
jvm·数据库·python