使用 Pytorch Lightning 时追踪指标和可视化指标

【PL 基础】追踪指标和可视化指标

  • 摘要
  • [1. 跟踪指标](#1. 跟踪指标)
  • [2. 在命令行中查看](#2. 在命令行中查看)
  • [3. 在浏览器中查看](#3. 在浏览器中查看)
  • [4. 配置保存目录](#4. 配置保存目录)

摘要

本文介绍了PyTorch Lightning中指标追踪和可视化的方法。主要内容包括:1)使用self.log和self.log_dict记录训练指标;2)通过设置prog_bar=True在命令行进度条显示指标;3)使用TensorBoard可视化指标变化曲线;4)验证集和测试集指标的自动聚合功能,支持mean、min、max等聚合方式;5)通过default_root_dir参数自定义日志保存路径。这些功能为模型训练过程提供了便捷的性能监控和分析手段。

1. 跟踪指标

指标可视化是了解模型在整个模型开发过程中表现的最基本但最强大的方法。

要跟踪指标,只需使用 LightningModule 中提供的 self.log 方法

python 复制代码
class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        value = ...
        self.log("some_value", value)

要一次记录多个指标,请使用 self.log_dict

python 复制代码
values = {"loss": loss, "acc": acc, "metric_n": metric_n}  # add more items if needed
self.log_dict(values)

2. 在命令行中查看

要在命令行进度条中查看指标,请将 prog_bar 参数设置为 True

python 复制代码
self.log(..., prog_bar=True)
python 复制代码
Epoch 3:  33%|███▉        | 307/938 [00:01<00:02, 289.04it/s, loss=0.198, v_num=51, acc=0.211, metric_n=0.937]

3. 在浏览器中查看

要在浏览器中查看指标,您需要使用具有这些功能的实验管理器。

默认情况下,Lightning 使用 Tensorboard(如果可用)和简单的 CSV 记录器。

python 复制代码
# every trainer already has tensorboard enabled by default (if the dependency is available)
trainer = Trainer()

要启动 tensorboard 控制面板,请在命令行上运行以下命令。

python 复制代码
tensorboard --logdir=lightning_logs/

如果您使用的是 colab、kaggle 或 jupyter 等笔记本环境,请使用此命令启动 Tensorboard

python 复制代码
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

training_step 中调用 self.log 时,它会生成一个时间序列,显示指标随时间的变化。

但是,对于验证集和测试集,通常对绘制每批数据的量度值不感兴趣。相反,希望计算整个数据拆分的汇总统计数据(例如 averageminmax)。

当您在 validation_steptest_step 中调用 self.log 时,Lightning 会自动累积指标,并在它经历整个拆分(epoch)后对其进行平均。

python 复制代码
def validation_step(self, batch, batch_idx):
    value = batch_idx + 1
    self.log("average_value", value)

如果您不想平均,也可以通过传递 reduce_fx 参数{min,max,sum}来选择。

python 复制代码
# default function
self.log(..., reduce_fx="mean")

4. 配置保存目录

默认情况下,记录的任何内容都将保存到当前工作目录中。要使用其他目录,请在 Trainer 中设置 default_root_dir 参数。

python 复制代码
Trainer(default_root_dir="/your/custom/path")
相关推荐
William.csj16 小时前
Pytorch/CUDA——flash-attn 库编译的 gcc 版本问题
pytorch·cuda
Green1Leaves1 天前
pytorch学习-9.多分类问题
人工智能·pytorch·学习
摸爬滚打李上进2 天前
重生学AI第十六集:线性层nn.Linear
人工智能·pytorch·python·神经网络·机器学习
HuashuiMu花水木2 天前
PyTorch笔记1----------Tensor(张量):基本概念、创建、属性、算数运算
人工智能·pytorch·笔记
喝过期的拉菲2 天前
如何使用 Pytorch Lightning 启用早停机制
pytorch·lightning·早停机制
kk爱闹2 天前
【挑战14天学完python和pytorch】- day01
android·pytorch·python
Yo_Becky2 天前
【PyTorch】PyTorch预训练模型缓存位置迁移,也可拓展应用于其他文件的迁移
人工智能·pytorch·经验分享·笔记·python·程序人生·其他
xinxiangwangzhi_2 天前
pytorch底层原理学习--PyTorch 架构梳理
人工智能·pytorch·架构
FF-Studio3 天前
【硬核数学 · LLM篇】3.1 Transformer之心:自注意力机制的线性代数解构《从零构建机器学习、深度学习到LLM的数学认知》
人工智能·pytorch·深度学习·线性代数·机器学习·数学建模·transformer
盼小辉丶3 天前
PyTorch实战(14)——条件生成对抗网络(conditional GAN,cGAN)
人工智能·pytorch·生成对抗网络