使用 Pytorch Lightning 时追踪指标和可视化指标

【PL 基础】追踪指标和可视化指标

  • 摘要
  • [1. 跟踪指标](#1. 跟踪指标)
  • [2. 在命令行中查看](#2. 在命令行中查看)
  • [3. 在浏览器中查看](#3. 在浏览器中查看)
  • [4. 配置保存目录](#4. 配置保存目录)

摘要

本文介绍了PyTorch Lightning中指标追踪和可视化的方法。主要内容包括:1)使用self.log和self.log_dict记录训练指标;2)通过设置prog_bar=True在命令行进度条显示指标;3)使用TensorBoard可视化指标变化曲线;4)验证集和测试集指标的自动聚合功能,支持mean、min、max等聚合方式;5)通过default_root_dir参数自定义日志保存路径。这些功能为模型训练过程提供了便捷的性能监控和分析手段。

1. 跟踪指标

指标可视化是了解模型在整个模型开发过程中表现的最基本但最强大的方法。

要跟踪指标,只需使用 LightningModule 中提供的 self.log 方法

python 复制代码
class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        value = ...
        self.log("some_value", value)

要一次记录多个指标,请使用 self.log_dict

python 复制代码
values = {"loss": loss, "acc": acc, "metric_n": metric_n}  # add more items if needed
self.log_dict(values)

2. 在命令行中查看

要在命令行进度条中查看指标,请将 prog_bar 参数设置为 True

python 复制代码
self.log(..., prog_bar=True)
python 复制代码
Epoch 3:  33%|███▉        | 307/938 [00:01<00:02, 289.04it/s, loss=0.198, v_num=51, acc=0.211, metric_n=0.937]

3. 在浏览器中查看

要在浏览器中查看指标,您需要使用具有这些功能的实验管理器。

默认情况下,Lightning 使用 Tensorboard(如果可用)和简单的 CSV 记录器。

python 复制代码
# every trainer already has tensorboard enabled by default (if the dependency is available)
trainer = Trainer()

要启动 tensorboard 控制面板,请在命令行上运行以下命令。

python 复制代码
tensorboard --logdir=lightning_logs/

如果您使用的是 colab、kaggle 或 jupyter 等笔记本环境,请使用此命令启动 Tensorboard

python 复制代码
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

training_step 中调用 self.log 时,它会生成一个时间序列,显示指标随时间的变化。

但是,对于验证集和测试集,通常对绘制每批数据的量度值不感兴趣。相反,希望计算整个数据拆分的汇总统计数据(例如 averageminmax)。

当您在 validation_steptest_step 中调用 self.log 时,Lightning 会自动累积指标,并在它经历整个拆分(epoch)后对其进行平均。

python 复制代码
def validation_step(self, batch, batch_idx):
    value = batch_idx + 1
    self.log("average_value", value)

如果您不想平均,也可以通过传递 reduce_fx 参数{min,max,sum}来选择。

python 复制代码
# default function
self.log(..., reduce_fx="mean")

4. 配置保存目录

默认情况下,记录的任何内容都将保存到当前工作目录中。要使用其他目录,请在 Trainer 中设置 default_root_dir 参数。

python 复制代码
Trainer(default_root_dir="/your/custom/path")
相关推荐
AndrewHZ3 小时前
【图像处理基石】如何对遥感图像进行目标检测?
图像处理·人工智能·pytorch·目标检测·遥感图像·小目标检测·旋转目标检测
墨染点香4 小时前
第七章 Pytorch构建模型详解【构建CIFAR10模型结构】
人工智能·pytorch·python
兮℡檬,7 小时前
房价预测|Pytorch
人工智能·pytorch·python
贝塔西塔19 小时前
PytorchLightning最佳实践基础篇
pytorch·深度学习·lightning·编程框架
小猪和纸箱20 小时前
通过Python交互式控制台理解Conv1d的输入输出
pytorch
墨染枫1 天前
pytorch学习笔记-使用DataLoader加载固有Datasets(CIFAR10),使用tensorboard进行可视化
pytorch·笔记·学习
九章云极AladdinEdu2 天前
GitHub新手生存指南:AI项目版本控制与协作实战
人工智能·pytorch·opencv·机器学习·github·gpu算力
z are2 天前
PyTorch 模型开发全栈指南:从定义、修改到保存的完整闭环
人工智能·pytorch·python
点云SLAM2 天前
Pytorch中cuda相关操作详见和代码示例
人工智能·pytorch·python·深度学习·3d·cuda·多gpu训练
cwn_2 天前
Sequential 损失函数 反向传播 优化器 模型的使用修改保存加载
人工智能·pytorch·python·深度学习·机器学习