PyTorch Lightning Callback介绍

PyTorch Lightning Callback 介绍

在 PyTorch 中,callbacks(回调函数)不是原生支持的核心功能,但在深度学习中非常常见,尤其是用来监控训练过程、调整超参数或执行特定的任务。许多高级深度学习框架(如 PyTorch Lightning 和 FastAI)都基于 PyTorch,并内置了 callback 支持。

PyTorch Lightning 提供了一个易于扩展的回调机制,允许用户在训练过程中插入自定义逻辑。回调类继承自 pytorch_lightning.callbacks.Callback,可以覆盖以下方法:

常用方法
  • on_fit_start: 在训练(fit)开始时调用。
  • on_fit_end: 在训练(fit)结束时调用。
  • on_train_epoch_start: 在每个训练 epoch 开始时调用。
  • on_train_epoch_end: 在每个训练 epoch 结束时调用。
  • on_validation_epoch_start: 在每个验证 epoch 开始时调用。
  • on_validation_epoch_end: 在每个验证 epoch 结束时调用。
  • on_test_epoch_start: 在测试 epoch 开始时调用。
  • on_test_epoch_end: 在测试 epoch 结束时调用。
  • on_train_batch_end: 在每个训练 batch 结束时调用。
  • on_validation_batch_end: 在每个验证 batch 结束时调用。
  • on_test_batch_end: 在每个测试 batch 结束时调用。

示例: 自定义 Callback

以下示例实现了一个打印日志的回调:

from pytorch_lightning.callbacks import Callback

class PrintCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Training ended!")

    def on_validation_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Validation ended!")

使用时将回调传递给 Trainer

from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[PrintCallback()])

基于 Hydra 配置实例化 Callback

Hydra 是一个灵活的配置管理工具,常用于深度学习项目中动态管理超参数。通过结合 Hydra 和 PyTorch Lightning,可以动态配置并实例化 Callback。

步骤:

1. 安装 Hydra

pip install hydra-core --upgrade

2. 定义 Hydra 配置文件 : 创建一个 YAML 配置文件(如 config.yaml)来管理 Callback 的配置:

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val_loss"
    save_top_k: 1
    mode: "min"

  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: "val_loss"
    patience: 5
    mode: "min"

3. 在代码中动态实例化 : 使用 hydra.utils.instantiate 方法实例化回调对象:

import hydra
from hydra.utils import instantiate
from pytorch_lightning import Trainer
from omegaconf import OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(cfg):
    # Instantiate callbacks from config
    callbacks = [instantiate(cfg.callbacks[key]) for key in cfg.callbacks]

    # Example: Define a simple PyTorch Lightning model
    from pytorch_lightning import LightningModule
    import torch.nn.functional as F

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.mse_loss(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

    # Instantiate trainer
    trainer = Trainer(callbacks=callbacks, max_epochs=10)

    # Simulated data loader
    from torch.utils.data import DataLoader, TensorDataset
    import torch

    x = torch.rand(100, 10)
    y = torch.rand(100, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    model = SimpleModel()
    trainer.fit(model, train_loader)

if __name__ == "__main__":
    main()
解释:如何通过配置文件动态管理 Callback
  1. 配置文件中,_target_ 指定回调类的完整路径。
  2. 使用 hydra.utils.instantiate 根据配置动态实例化对象。
  3. 将实例化后的回调传递给 Trainer
优势
  1. 动态配置:通过 YAML 文件可以快速更改回调逻辑而无需修改代码。
  2. 模块化管理:方便管理多个回调类,清晰直观。
  3. 灵活性:支持自定义 Callback 和 Lightning 内置回调的结合使用。

此方法适用于多种场景,比如动态调整模型保存路径、早停策略等。

相关推荐
百家方案32 分钟前
「下载」京东数科-数字孪生智慧园区解决方案:打通园区数据、融合园区业务、集成园区服务、共建园区生态,实现真实与数字孪生园区
人工智能·云计算·智慧园区·数智化园区
MUTA️33 分钟前
专业版pycharm与服务器连接
人工智能·python·深度学习·计算机视觉·pycharm
xuanfengwuxiang1 小时前
安卓帧率获取
android·python·测试工具·adb·性能优化·pycharm
m0_748240441 小时前
《通义千问AI落地—中》:前端实现
前端·人工智能·状态模式
cooldream20091 小时前
RDFS—RDF模型属性扩展解析
人工智能·知识图谱·知识表示
觅远2 小时前
python+PyMuPDF库:(一)创建pdf文件及内容读取和写入
开发语言·python·pdf
MinIO官方账号2 小时前
使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器实现可迭代式数据集
人工智能·pytorch·python
四口鲸鱼爱吃盐2 小时前
Pytorch | 利用IE-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·计算机视觉
四口鲸鱼爱吃盐2 小时前
Pytorch | 利用EMI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
章章小鱼3 小时前
LLM预训练recipe — 摘要版
人工智能