Pytorch中Tensorboard的学习

1、Tensorboard介绍

TensorBoard 是 TensorFlow 开发的一个可视化工具,用于帮助用户理解和调试机器学习模型的训练过程。尽管它最初是为 TensorFlow 设计的,但通过 PyTorch 的 torch.utils.tensorboard 模块,PyTorch 用户也可以方便地使用 TensorBoard 来记录和可视化模型训练中的各种数据(记得先安装tensorboard包,pytorch不自带)。

SummaryWriter 是 PyTorch 中与 TensorBoard 交互的核心类,用于将数据写入日志文件,供 TensorBoard 解析和展示。

1.1 TensorBoard 的核心功能

  1. 训练指标可视化:

    记录损失(Loss)、准确率(Accuracy)、学习率(Learning Rate)等标量数据,并绘制曲线。

  2. 模型结构可视化:

    展示神经网络的计算图(模型结构)。

  3. 直方图和分布:

    可视化权重、梯度等张量的分布变化。

  4. 图像和音频:

    记录输入图像、生成样本或中间特征图。

  5. 嵌入向量可视化:

    对高维数据进行降维(如 PCA 或 t-SNE),展示在 2D/3D 空间中的分布。

1.2 SummaryWriter 的基本用法

  1. 初始化 SummaryWriter
python 复制代码
from torch.utils.tensorboard import SummaryWriter

# 创建 SummaryWriter 对象
writer = SummaryWriter(log_dir="runs/experiment_1")

参数:

  • log_dir:日志文件的保存路径(默认生成 runs/ 目录下的时间戳文件夹)。
  1. 记录标量数据(Scalars)
python 复制代码
for epoch in range(100):
    loss = train_model()
    accuracy = calculate_accuracy()
    
    # 记录损失和准确率
    writer.add_scalar("Loss/train", loss, epoch)
    writer.add_scalar("Accuracy/train", accuracy, epoch)
  1. 记录图像(Images)
python 复制代码
images, _ = next(iter(dataloader))
writer.add_images("Input_images", images, epoch)
  1. 记录直方图(Histograms)
python 复制代码
weights = model.layer.weight.data
writer.add_histogram("Weights/layer", weights, epoch)
  1. 记录模型结构(Graph)
python 复制代码
dummy_input = torch.randn(1, 3, 224, 224)  # 输入示例
writer.add_graph(model, dummy_input)
  1. 关闭 Writer
    writer.close()

1.3 启动 TensorBoard

在命令行中运行以下命令启动 TensorBoard:

python 复制代码
tensorboard --logdir=runs/

然后通过浏览器访问 http://localhost:6006 查看可视化结果。

1.4 关键方法详解

常用方法

方法 功能
add_scalar(tag, scalar_value, global_step) 记录单个标量(如 Loss)
add_scalars(main_tag, tag_scalar_dict, global_step) 记录多个标量(如 Loss 和 Accuracy)
add_image(tag, img_tensor, global_step) 记录单张图像(格式需为 CxHxW)
add_images(tag, img_tensor, global_step) 记录多张图像(格式为 NxCxHxW)
add_histogram(tag, values, global_step) 记录张量分布直方图
add_graph(model, input_to_model) 记录模型计算图

1.5 使用示例

python 复制代码
import torch
from torch.utils.tensorboard import SummaryWriter

# 初始化
writer = SummaryWriter("runs/demo")

# 模拟训练过程
for epoch in range(100):
    # 假设训练逻辑
    loss = 1.0 / (epoch + 1)
    accuracy = 0.8 - 0.002 * epoch
    
    # 记录标量
    writer.add_scalar("Loss/train", loss, epoch)
    writer.add_scalar("Accuracy/train", accuracy, epoch)
    
    # 记录随机直方图
    weights = torch.randn(100)
    writer.add_histogram("Random_weights", weights, epoch)

# 记录模型结构(假设 model 已定义)
dummy_input = torch.randn(1, 3, 224, 224)
writer.add_graph(model, dummy_input)

# 关闭 Writer
writer.close()

1.6 注意事项

  1. 数据频率:

    避免在每个训练步(step)都记录大量数据,否则日志文件会过大。

  2. 实验管理:

    使用不同的 log_dir 区分不同实验(如 runs/exp1, runs/exp2)。

  3. 张量格式:

    图像数据需符合 CxHxW 格式(通道优先)。

  4. 性能影响:

    高频记录可能影响训练速度,需权衡监控需求与效率。

1.7 总结

  • TensorBoard 是模型训练可视化的标准工具,支持标量、图像、直方图等多种数据。
  • SummaryWriter 是 PyTorch 与 TensorBoard 的桥梁,通过简单的方法将数据写入日志。
  • 典型工作流程:
    记录数据 → 启动 TensorBoard → 浏览器查看 → 分析优化模型。

通过结合 TensorBoard 和 SummaryWriter,可以直观地监控模型训练过程,快速定位问题(如过拟合、梯度消失等),是深度学习开发中的必备技能。

2、Tensorboard实操

2.1 测试Tensorboard

2.1.1 add_scalar 函数介绍

将单个标量值(如损失、准确率、学习率)记录到 TensorBoard 日志中,生成随时间/步骤变化的曲线图,便于可视化分析。

基本语法

python 复制代码
add_scalar(tag, scalar_value, global_step=None, walltime=None)

参数说明

tag (字符串)

  • 标识数据的标签,决定在 TensorBoard 中的图表标题和分组。
  • 支持层级命名(如 "train/loss" 和 "val/loss"),TensorBoard 会自动按斜杠分组显示。

scalar_value (浮点数或标量张量)

  • 要记录的具体数值(如 loss.item() 或 accuracy)。

global_step (整数, 可选)

  • 当前记录的"步骤",通常是迭代次数、epoch 数或自定义的计数器。
  • 作为曲线的横坐标,需确保递增以正确显示变化趋势。

walltime (浮点数, 可选)

  • 覆盖默认的时间戳(记录时刻),一般无需手动设置。

示例代码

python 复制代码
from torch.utils.tensorboard import SummaryWriter

# 创建 SummaryWriter,日志保存到 runs/test 目录
writer = SummaryWriter("runs/test")

for i in range(100):
    # 记录 y=3x 的标量值:tag为"y=3x",值为3*i,步骤为i
    writer.add_scalar("y=3x", 3 * i, i)

writer.close()  # 关闭写入器,确保数据保存
  • tag="y=3x":在 TensorBoard 中生成名为 y=3x 的曲线图。
  • scalar_value=3*i:每次迭代记录的值(模拟 y=3x 函数)。
  • global_step=i:横轴表示迭代步骤,从 0 到 99 递增。

使用建议

命名规范

  • 使用清晰的层级标签(如 "train/loss"、"val/accuracy")以便在 TensorBoard 中分类查看。

步骤连续性

  • global_step 应单调递增(如 0,1,2,...),避免跳跃或重复,否则图表可能混乱。

高效记录

  • 避免高频调用(如每个 batch 都记录),可每隔若干步记录一次以减少开销。

查看结果

运行后,使用以下命令启动 TensorBoard:

python 复制代码
tensorboard --logdir=runs/test

在浏览器中打开提示的链接,即可看到 y=3x 的线性增长曲线。

注意事项

  • 路径引号问题:确保路径字符串使用英文引号(如 "runs/test"),而非中文引号""或''。
  • 及时关闭 Writer:调用 writer.close() 或在 with 块中使用,防止日志未保存。

相关方法

add_scalars:同时记录多个标量(如训练和验证损失对比):

python 复制代码
writer.add_scalars("loss", {"train": train_loss, "val": val_loss}, epoch)

其他记录方法:如 add_image(记录图像)、add_histogram(权重分布)等,用于多维数据可视化。

2.1.2 实例

python 复制代码
from torch.utils.tensorboard import SummaryWriter

# 创建一个 SummaryWriter 对象
writer = SummaryWriter("runs/test")
for i in range(100):
    writer.add_scalar("y=3x", 3*i, i)
writer.close()

运行上述程序之后,会发现程序所在目录中多了个文件夹runs,文件夹runs中有一个文件夹test:

如果想要查看可视化结果,可以在终端执行命令:

python 复制代码
tensorboard --logdir=runs/

在这里,logdir=事件文件所在的文件夹,在本例中为logdir=test(注意不是logdir = runs/test,只能是一级目录):

python 复制代码
tensorboard --logdir=test     

运行之后,返回结果:

python 复制代码
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

点击网址,即可查看可视化结果。

如果打开网址看不到你的数据结果,尝试把相对路径(test)换为绝对路径再重新运行并打开网址即可。

另外,如果想要更改端口号(即上面网址中的6006),可以在终端执行以下命令:

python 复制代码
tensorboard --logdir=test --port=6007 

2.2 读取图像数据

2.2.1 add_jmage函数介绍

SummaryWriter 中的 .add_image 方法用于将图像数据记录到 TensorBoard 中,帮助用户可视化训练过程中的图像变化。

基本语法

python 复制代码
add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数说明:

  • tag (str): 图像的名称标识,用于在 TensorBoard 中分类显示。
  • img_tensor (Tensor): 图像张量,支持多种格式(需通过 dataformats 指定)。
  • global_step (int, 可选): 记录的步骤(如训练步数或 epoch),用于滑动查看图像变化。
  • walltime (float, 可选): 自定义时间戳,默认使用当前时间。
  • dataformats (str, 可选): 图像张量的格式,默认为 'CHW'(通道数 × 高度 × 宽度)。

图像张量格式

常见格式:

  • 单图像:(C, H, W)(默认)或 (H, W, C)(需设置 dataformats='HWC')。
  • 灰度图:(1, H, W) 或 (H, W)(需设置 dataformats='HW')。
  • 批量图像:使用 .add_images 或 torchvision.utils.make_grid 合并为网格后记录。

数据类型与范围:

  • 浮点型:数值范围应为 [0.0, 1.0]。
  • 整型(uint8):数值范围应为 [0, 255]。

使用步骤

  1. 导入库并创建 SummaryWriter:
python 复制代码
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
  1. 准备图像张量:

从随机数据生成:

python 复制代码
img = torch.rand(3, 64, 64)  # 3通道,64x64,范围[0,1]

从文件加载并转换:

python 复制代码
from PIL import Image
import numpy as np
import torch

# 读取图像并转为张量
image = Image.open("example.jpg")
img_array = np.array(image)
if img_array.ndim == 3:
    img_tensor = torch.from_numpy(img_array.transpose(2, 0, 1))  # HWC → CHW
else:
    img_tensor = torch.from_numpy(img_array).unsqueeze(0)       # 灰度图添加通道维度
img_tensor = img_tensor.float() / 255.0  # 归一化到[0,1]
  1. 记录图像:
python 复制代码
writer.add_image('my_image', img_tensor, global_step=0)
  1. 处理多张图像(网格):
python 复制代码
import torchvision

images = torch.randn(16, 3, 64, 64)        # 16张图像
grid = torchvision.utils.make_grid(images, nrow=4)  # 合并为4x4网格
writer.add_image('16_images_grid', grid, 0)

注意事项

  • 数据范围:若张量值超出 [0, 1] 或 [0, 255],需手动归一化:
python 复制代码
img = (img - img.min()) / (img.max() - img.min())  # 归一化到[0,1]
  • 格式匹配:若张量形状与 dataformats 不匹配,TensorBoard 可能无法正确显示。
    例如:(H, W, C) 需设置 dataformats='HWC'。单通道灰度图 (H, W) 需设置 dataformats='HW'。
  • GPU 张量:若张量在 GPU 上,需先移至 CPU:
python 复制代码
img_tensor = img_tensor.cpu()

示例代码

python 复制代码
from torch.utils.tensorboard import SummaryWriter
import torch

writer = SummaryWriter()

# 生成随机图像(3通道,64x64)
img = torch.rand(3, 64, 64)
writer.add_image('random_image', img, 0)

# 记录多张图像的网格
images = torch.randn(16, 3, 64, 64)
grid = torchvision.utils.make_grid(images, nrow=4)
writer.add_image('image_grid', grid, 0)

writer.close()

运行后,在命令行启动 TensorBoard:

python 复制代码
tensorboard --logdir=runs

通过以上步骤,可有效利用 .add_image 监控模型输入、输出或中间特征图,提升训练过程的可解释性。

2.2.2 示例

接下来使用SummaryWriter类中的.add_image()方法来读取图像数据。使用的图像为.jpg格式文件:

由于PIL中Image.open()打开的图片类型为<class 'PIL.JpegImagePlugin.JpegImageFile'>,而方法add_image()中

img_tensor 参数必须为图像张量: (torch.Tensor, numpy.ndarray, or string/blobname)

python 复制代码
        Args:
            tag (str): Data identifier
            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
            dataformats (str): Image data format specification of the form
              CHW, HWC, HW, WH, etc.

因此,使用img_array = np.array(img_PIL)将PIL类型图像转换为numpy.ndarray类型图像。

python 复制代码
>>> import numpy as np
>>> img_array = np.array(img_PIL)
>>> type(img_array)
<class 'numpy.ndarray'>
>>> img_array.shape
(512, 768, 3)

此外,转换为numpy.ndarray类型图像之后,方法add_image()中参数dataformats (str, 可选): 图像张量的格式,默认为 'CHW'(通道数 × 高度 × 宽度)。而转换之后的格式为(H x W x C),因此需要设置该参数为dataformats='HWC'。

完整代码:

python 复制代码
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

# 创建一个 SummaryWriter 对象
writer = SummaryWriter("runs")
image_path = "hymenoptera_data/train/ants/0013035.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape)

writer.add_image("train", img_array, 1, dataformats='HWC')
for i in range(100):
    writer.add_scalar("y=3x", 3*i, i)
writer.close()

在终端执行命令:

python 复制代码
tensorboard --logdir=E:\my_pycharm_projects\project1\runs  

运行:

python 复制代码
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

点击网址,即可查看可视化结果:

相关推荐
郝YH是人间理想9 分钟前
Python面向对象
开发语言·python·面向对象
藍海琴泉10 分钟前
蓝桥杯算法精讲:二分查找实战与变种解析
python·算法
默 语10 分钟前
10分钟打造专属AI助手!ToDesk云电脑/顺网云/海马云操作DeepSeek哪家强?
人工智能·电脑·todesk
Donvink2 小时前
【Dive Into Stable Diffusion v3.5】2:Stable Diffusion v3.5原理介绍
人工智能·深度学习·语言模型·stable diffusion·aigc·transformer
宇灵梦2 小时前
大模型金融企业场景落地应用
人工智能
lsrsyx2 小时前
中信银行太原长治路支行赴老年活动服务中心开展专题金融知识宣讲
大数据·人工智能
烟锁池塘柳03 小时前
【深度学习】Self-Attention机制详解:Transformer的核心引擎
人工智能·深度学习·transformer
Matrix_113 小时前
论文阅读:Self-Supervised Video Defocus Deblurring with Atlas Learning
人工智能·计算摄影
云上艺旅4 小时前
K8S学习之基础四十七:k8s中部署fluentd
学习·云原生·容器·kubernetes
mqwguardain5 小时前
python常见反爬思路详解
开发语言·python