PyTorch Lightning Callback 介绍
在 PyTorch 中,callbacks(回调函数)不是原生支持的核心功能,但在深度学习中非常常见,尤其是用来监控训练过程、调整超参数或执行特定的任务。许多高级深度学习框架(如 PyTorch Lightning 和 FastAI)都基于 PyTorch,并内置了 callback 支持。
PyTorch Lightning 提供了一个易于扩展的回调机制,允许用户在训练过程中插入自定义逻辑。回调类继承自 pytorch_lightning.callbacks.Callback
,可以覆盖以下方法:
常用方法
on_fit_start
: 在训练(fit)开始时调用。on_fit_end
: 在训练(fit)结束时调用。on_train_epoch_start
: 在每个训练 epoch 开始时调用。on_train_epoch_end
: 在每个训练 epoch 结束时调用。on_validation_epoch_start
: 在每个验证 epoch 开始时调用。on_validation_epoch_end
: 在每个验证 epoch 结束时调用。on_test_epoch_start
: 在测试 epoch 开始时调用。on_test_epoch_end
: 在测试 epoch 结束时调用。on_train_batch_end
: 在每个训练 batch 结束时调用。on_validation_batch_end
: 在每个验证 batch 结束时调用。on_test_batch_end
: 在每个测试 batch 结束时调用。
示例: 自定义 Callback
以下示例实现了一个打印日志的回调:
from pytorch_lightning.callbacks import Callback
class PrintCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
print(f"Epoch {trainer.current_epoch}: Training ended!")
def on_validation_epoch_end(self, trainer, pl_module):
print(f"Epoch {trainer.current_epoch}: Validation ended!")
使用时将回调传递给 Trainer
:
from pytorch_lightning import Trainer
trainer = Trainer(callbacks=[PrintCallback()])
基于 Hydra 配置实例化 Callback
Hydra 是一个灵活的配置管理工具,常用于深度学习项目中动态管理超参数。通过结合 Hydra 和 PyTorch Lightning,可以动态配置并实例化 Callback。
步骤:
1. 安装 Hydra:
pip install hydra-core --upgrade
2. 定义 Hydra 配置文件 : 创建一个 YAML 配置文件(如 config.yaml
)来管理 Callback 的配置:
callbacks:
model_checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
monitor: "val_loss"
save_top_k: 1
mode: "min"
early_stopping:
_target_: pytorch_lightning.callbacks.EarlyStopping
monitor: "val_loss"
patience: 5
mode: "min"
3. 在代码中动态实例化 : 使用 hydra.utils.instantiate
方法实例化回调对象:
import hydra
from hydra.utils import instantiate
from pytorch_lightning import Trainer
from omegaconf import OmegaConf
@hydra.main(config_path=".", config_name="config")
def main(cfg):
# Instantiate callbacks from config
callbacks = [instantiate(cfg.callbacks[key]) for key in cfg.callbacks]
# Example: Define a simple PyTorch Lightning model
from pytorch_lightning import LightningModule
import torch.nn.functional as F
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 1)
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
# Instantiate trainer
trainer = Trainer(callbacks=callbacks, max_epochs=10)
# Simulated data loader
from torch.utils.data import DataLoader, TensorDataset
import torch
x = torch.rand(100, 10)
y = torch.rand(100, 1)
train_loader = DataLoader(TensorDataset(x, y), batch_size=32)
model = SimpleModel()
trainer.fit(model, train_loader)
if __name__ == "__main__":
main()
解释:如何通过配置文件动态管理 Callback
- 配置文件中,
_target_
指定回调类的完整路径。 - 使用
hydra.utils.instantiate
根据配置动态实例化对象。 - 将实例化后的回调传递给
Trainer
。
优势
- 动态配置:通过 YAML 文件可以快速更改回调逻辑而无需修改代码。
- 模块化管理:方便管理多个回调类,清晰直观。
- 灵活性:支持自定义 Callback 和 Lightning 内置回调的结合使用。
此方法适用于多种场景,比如动态调整模型保存路径、早停策略等。