PyTorch Lightning Callback介绍

PyTorch Lightning Callback 介绍

在 PyTorch 中,callbacks(回调函数)不是原生支持的核心功能,但在深度学习中非常常见,尤其是用来监控训练过程、调整超参数或执行特定的任务。许多高级深度学习框架(如 PyTorch Lightning 和 FastAI)都基于 PyTorch,并内置了 callback 支持。

PyTorch Lightning 提供了一个易于扩展的回调机制,允许用户在训练过程中插入自定义逻辑。回调类继承自 pytorch_lightning.callbacks.Callback,可以覆盖以下方法:

常用方法
  • on_fit_start: 在训练(fit)开始时调用。
  • on_fit_end: 在训练(fit)结束时调用。
  • on_train_epoch_start: 在每个训练 epoch 开始时调用。
  • on_train_epoch_end: 在每个训练 epoch 结束时调用。
  • on_validation_epoch_start: 在每个验证 epoch 开始时调用。
  • on_validation_epoch_end: 在每个验证 epoch 结束时调用。
  • on_test_epoch_start: 在测试 epoch 开始时调用。
  • on_test_epoch_end: 在测试 epoch 结束时调用。
  • on_train_batch_end: 在每个训练 batch 结束时调用。
  • on_validation_batch_end: 在每个验证 batch 结束时调用。
  • on_test_batch_end: 在每个测试 batch 结束时调用。

示例: 自定义 Callback

以下示例实现了一个打印日志的回调:

复制代码
from pytorch_lightning.callbacks import Callback

class PrintCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Training ended!")

    def on_validation_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Validation ended!")

使用时将回调传递给 Trainer

复制代码
from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[PrintCallback()])

基于 Hydra 配置实例化 Callback

Hydra 是一个灵活的配置管理工具,常用于深度学习项目中动态管理超参数。通过结合 Hydra 和 PyTorch Lightning,可以动态配置并实例化 Callback。

步骤:

1. 安装 Hydra

复制代码
pip install hydra-core --upgrade

2. 定义 Hydra 配置文件 : 创建一个 YAML 配置文件(如 config.yaml)来管理 Callback 的配置:

复制代码
callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val_loss"
    save_top_k: 1
    mode: "min"

  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: "val_loss"
    patience: 5
    mode: "min"

3. 在代码中动态实例化 : 使用 hydra.utils.instantiate 方法实例化回调对象:

复制代码
import hydra
from hydra.utils import instantiate
from pytorch_lightning import Trainer
from omegaconf import OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(cfg):
    # Instantiate callbacks from config
    callbacks = [instantiate(cfg.callbacks[key]) for key in cfg.callbacks]

    # Example: Define a simple PyTorch Lightning model
    from pytorch_lightning import LightningModule
    import torch.nn.functional as F

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.mse_loss(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

    # Instantiate trainer
    trainer = Trainer(callbacks=callbacks, max_epochs=10)

    # Simulated data loader
    from torch.utils.data import DataLoader, TensorDataset
    import torch

    x = torch.rand(100, 10)
    y = torch.rand(100, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    model = SimpleModel()
    trainer.fit(model, train_loader)

if __name__ == "__main__":
    main()
解释:如何通过配置文件动态管理 Callback
  1. 配置文件中,_target_ 指定回调类的完整路径。
  2. 使用 hydra.utils.instantiate 根据配置动态实例化对象。
  3. 将实例化后的回调传递给 Trainer
优势
  1. 动态配置:通过 YAML 文件可以快速更改回调逻辑而无需修改代码。
  2. 模块化管理:方便管理多个回调类,清晰直观。
  3. 灵活性:支持自定义 Callback 和 Lightning 内置回调的结合使用。

此方法适用于多种场景,比如动态调整模型保存路径、早停策略等。

相关推荐
爱学习的小鱼gogo11 分钟前
pyhton 螺旋矩阵(指针-矩阵-中等)含源码(二十六)
python·算法·矩阵·指针·经验·二维数组·逆序
文火冰糖的硅基工坊15 分钟前
[嵌入式系统-146]:五次工业革命对应的机器人形态的演进、主要功能的演进以及操作系统的演进
前端·网络·人工智能·嵌入式硬件·机器人
猫头虎21 分钟前
openAI发布的AI浏览器:什么是Atlas?(含 ChatGPT 浏览功能)macOS 离线下载安装Atlas完整教程
人工智能·macos·chatgpt·langchain·prompt·aigc·agi
老六哥_AI助理指南26 分钟前
为什么AI会改变单片机的未来?
人工智能·单片机·嵌入式硬件
SEO_juper37 分钟前
2026 AI可见性:构建未来-proof策略的顶级工具
人工智能·搜索引擎·百度·工具·数字营销
sivdead40 分钟前
当前智能体的几种形式
人工智能·后端·agent
AIGC_北苏41 分钟前
大语言模型,一个巨大的矩阵
人工智能·语言模型·矩阵
算家计算1 小时前
DeepSeek-OCR本地部署教程:DeepSeek突破性开创上下文光学压缩,10倍效率重构文本处理范式
人工智能·开源·deepseek
言之。1 小时前
Andrej Karpathy 演讲【PyTorch at Tesla】
人工智能·pytorch·python
算家计算1 小时前
快手推出“工具+模型+平台”AI编程生态!大厂挤占AI赛道,中小企业如何突围?
人工智能·ai编程·资讯