使用 Pytorch Lightning 时追踪指标和可视化指标

【PL 基础】追踪指标和可视化指标

  • 摘要
  • [1. 跟踪指标](#1. 跟踪指标)
  • [2. 在命令行中查看](#2. 在命令行中查看)
  • [3. 在浏览器中查看](#3. 在浏览器中查看)
  • [4. 配置保存目录](#4. 配置保存目录)

摘要

本文介绍了PyTorch Lightning中指标追踪和可视化的方法。主要内容包括:1)使用self.log和self.log_dict记录训练指标;2)通过设置prog_bar=True在命令行进度条显示指标;3)使用TensorBoard可视化指标变化曲线;4)验证集和测试集指标的自动聚合功能,支持mean、min、max等聚合方式;5)通过default_root_dir参数自定义日志保存路径。这些功能为模型训练过程提供了便捷的性能监控和分析手段。

1. 跟踪指标

指标可视化是了解模型在整个模型开发过程中表现的最基本但最强大的方法。

要跟踪指标,只需使用 LightningModule 中提供的 self.log 方法

python 复制代码
class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        value = ...
        self.log("some_value", value)

要一次记录多个指标,请使用 self.log_dict

python 复制代码
values = {"loss": loss, "acc": acc, "metric_n": metric_n}  # add more items if needed
self.log_dict(values)

2. 在命令行中查看

要在命令行进度条中查看指标,请将 prog_bar 参数设置为 True

python 复制代码
self.log(..., prog_bar=True)
python 复制代码
Epoch 3:  33%|███▉        | 307/938 [00:01<00:02, 289.04it/s, loss=0.198, v_num=51, acc=0.211, metric_n=0.937]

3. 在浏览器中查看

要在浏览器中查看指标,您需要使用具有这些功能的实验管理器。

默认情况下,Lightning 使用 Tensorboard(如果可用)和简单的 CSV 记录器。

python 复制代码
# every trainer already has tensorboard enabled by default (if the dependency is available)
trainer = Trainer()

要启动 tensorboard 控制面板,请在命令行上运行以下命令。

python 复制代码
tensorboard --logdir=lightning_logs/

如果您使用的是 colab、kaggle 或 jupyter 等笔记本环境,请使用此命令启动 Tensorboard

python 复制代码
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

training_step 中调用 self.log 时,它会生成一个时间序列,显示指标随时间的变化。

但是,对于验证集和测试集,通常对绘制每批数据的量度值不感兴趣。相反,希望计算整个数据拆分的汇总统计数据(例如 averageminmax)。

当您在 validation_steptest_step 中调用 self.log 时,Lightning 会自动累积指标,并在它经历整个拆分(epoch)后对其进行平均。

python 复制代码
def validation_step(self, batch, batch_idx):
    value = batch_idx + 1
    self.log("average_value", value)

如果您不想平均,也可以通过传递 reduce_fx 参数{min,max,sum}来选择。

python 复制代码
# default function
self.log(..., reduce_fx="mean")

4. 配置保存目录

默认情况下,记录的任何内容都将保存到当前工作目录中。要使用其他目录,请在 Trainer 中设置 default_root_dir 参数。

python 复制代码
Trainer(default_root_dir="/your/custom/path")
相关推荐
嗝o゚10 小时前
昇腾CANN ge 仓的图优化 Pass:哪些 Pass 真正影响推理性能
pytorch·python·深度学习·cann·ge-pass
松☆13 小时前
昇腾NPU上的张量操作库,和PyTorch的张量操作有啥不一样?
人工智能·pytorch·python
weixin_5500831513 小时前
PyTorch 实战:从零搭建手写数字识别系统(CNN 卷积神经网络)
人工智能·pytorch·cnn
5201-14 小时前
Cube MatMul:为什么矩阵乘法选了 Cube 而不是 Vector
pytorch·python·矩阵
MediaTea16 小时前
DL:Transformer 的基本原理与 PyTorch 实现
人工智能·pytorch·python·深度学习·transformer
心中有国也有家17 小时前
MindSpore 适配 NPU 的全链路解析——从算子注册到端到端性能调优
人工智能·pytorch·python·学习·numpy
小糖学代码20 小时前
LLM系列:1.python入门:19.Requests(网络库)
人工智能·pytorch·深度学习·自然语言处理
心中有国也有家20 小时前
PyTorch 适配 NPU:从 torch_npu 到 CANN 算子的全链路技术解析
人工智能·pytorch·python
盼小辉丶20 小时前
PyTorch强化学习实战(10)——强化学习高级组件
人工智能·pytorch·python·强化学习
MediaTea20 小时前
DL:生成对抗网络的基本原理与 PyTorch 实现
人工智能·pytorch·深度学习·神经网络·生成对抗网络