PyTorch Lightning Callback介绍

PyTorch Lightning Callback 介绍

在 PyTorch 中,callbacks(回调函数)不是原生支持的核心功能,但在深度学习中非常常见,尤其是用来监控训练过程、调整超参数或执行特定的任务。许多高级深度学习框架(如 PyTorch Lightning 和 FastAI)都基于 PyTorch,并内置了 callback 支持。

PyTorch Lightning 提供了一个易于扩展的回调机制,允许用户在训练过程中插入自定义逻辑。回调类继承自 pytorch_lightning.callbacks.Callback,可以覆盖以下方法:

常用方法
  • on_fit_start: 在训练(fit)开始时调用。
  • on_fit_end: 在训练(fit)结束时调用。
  • on_train_epoch_start: 在每个训练 epoch 开始时调用。
  • on_train_epoch_end: 在每个训练 epoch 结束时调用。
  • on_validation_epoch_start: 在每个验证 epoch 开始时调用。
  • on_validation_epoch_end: 在每个验证 epoch 结束时调用。
  • on_test_epoch_start: 在测试 epoch 开始时调用。
  • on_test_epoch_end: 在测试 epoch 结束时调用。
  • on_train_batch_end: 在每个训练 batch 结束时调用。
  • on_validation_batch_end: 在每个验证 batch 结束时调用。
  • on_test_batch_end: 在每个测试 batch 结束时调用。

示例: 自定义 Callback

以下示例实现了一个打印日志的回调:

复制代码
from pytorch_lightning.callbacks import Callback

class PrintCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Training ended!")

    def on_validation_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Validation ended!")

使用时将回调传递给 Trainer

复制代码
from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[PrintCallback()])

基于 Hydra 配置实例化 Callback

Hydra 是一个灵活的配置管理工具,常用于深度学习项目中动态管理超参数。通过结合 Hydra 和 PyTorch Lightning,可以动态配置并实例化 Callback。

步骤:

1. 安装 Hydra

复制代码
pip install hydra-core --upgrade

2. 定义 Hydra 配置文件 : 创建一个 YAML 配置文件(如 config.yaml)来管理 Callback 的配置:

复制代码
callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val_loss"
    save_top_k: 1
    mode: "min"

  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: "val_loss"
    patience: 5
    mode: "min"

3. 在代码中动态实例化 : 使用 hydra.utils.instantiate 方法实例化回调对象:

复制代码
import hydra
from hydra.utils import instantiate
from pytorch_lightning import Trainer
from omegaconf import OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(cfg):
    # Instantiate callbacks from config
    callbacks = [instantiate(cfg.callbacks[key]) for key in cfg.callbacks]

    # Example: Define a simple PyTorch Lightning model
    from pytorch_lightning import LightningModule
    import torch.nn.functional as F

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.mse_loss(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

    # Instantiate trainer
    trainer = Trainer(callbacks=callbacks, max_epochs=10)

    # Simulated data loader
    from torch.utils.data import DataLoader, TensorDataset
    import torch

    x = torch.rand(100, 10)
    y = torch.rand(100, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    model = SimpleModel()
    trainer.fit(model, train_loader)

if __name__ == "__main__":
    main()
解释:如何通过配置文件动态管理 Callback
  1. 配置文件中,_target_ 指定回调类的完整路径。
  2. 使用 hydra.utils.instantiate 根据配置动态实例化对象。
  3. 将实例化后的回调传递给 Trainer
优势
  1. 动态配置:通过 YAML 文件可以快速更改回调逻辑而无需修改代码。
  2. 模块化管理:方便管理多个回调类,清晰直观。
  3. 灵活性:支持自定义 Callback 和 Lightning 内置回调的结合使用。

此方法适用于多种场景,比如动态调整模型保存路径、早停策略等。

相关推荐
Raink老师5 小时前
【AI面试临阵磨枪-79】实时数据 RAG:订单、商家、物流、天气、动态库存
人工智能·面试·职场和发展
脑极体5 小时前
点亮星河AI+鸿蒙,一座艺术场馆的日神觉醒
人工智能·华为·harmonyos
Cosolar5 小时前
Chroma向量库面试学习指南
数据库·人工智能·面试·职场和发展·数据库架构
BUG指挥官5 小时前
Claude Code的自动化编程
人工智能
意图共鸣6 小时前
意图共鸣科技《认知智能白皮书》——感知与执行分离:认知架构(CA)如何重塑大模型底层结构
人工智能·架构
等一个人的@6 小时前
让数据自己开口:数睿通智库新增智能问数模块
人工智能·自然语言处理
ZGi.ai6 小时前
人工审查节点:让自动化工作流多一步人工把关
运维·人工智能·自动化·人机协同·智能体工作流·人工审查
风吹夏回6 小时前
Python 全局异常处理:从“满屏 try-except”到优雅兜底
开发语言·python
王莎莎-MinerU6 小时前
MinerU 深度技术解析:从架构原理到生产部署的全面指南
css·人工智能·自然语言处理·架构·ocr·个人开发
盘古信息IMS6 小时前
盘古信息IMS V6 8.0重磅发布:以薪火AI数智平台点燃离散制造数智化引擎
大数据·人工智能·制造