PyTorch Lightning Callback 完全指南
📑 目录
- 背景与动机
- 核心概念与架构
- [内置 Callback 详解](#内置 Callback 详解)
- [Callback 生命周期钩子方法](#Callback 生命周期钩子方法)
- [自定义 Callback 开发](#自定义 Callback 开发)
- [Callback 搭配使用策略](#Callback 搭配使用策略)
- 高级应用与最佳实践
- 常见问题与调试技巧
- 扩展阅读与进阶方向
1. 背景与动机
1.1 为什么需要 Callback?
在深度学习训练过程中,我们经常需要在特定时刻执行特定操作:
训练过程中的常见需求:
- ✅ 每个 epoch 结束后保存最佳模型
- ✅ 当验证损失不再下降时提前停止训练
- ✅ 记录学习率变化曲线
- ✅ 在训练开始前初始化某些参数
- ✅ 定期验证模型在特定数据集上的表现
- ✅ 动态调整训练策略(如梯度累积)
传统做法的问题:
python
# ❌ 不使用 Callback 的代码(耦合度高、难以维护)
for epoch in range(max_epochs):
# 训练逻辑
train_loss = train_epoch(model, train_loader)
# 验证逻辑
val_loss = validate(model, val_loader)
# 手动保存最佳模型
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), 'best_model.pth')
# 手动早停逻辑
if val_loss > best_loss:
patience_counter += 1
if patience_counter >= patience:
print("Early stopping!")
break
# 手动记录日志
log_metrics(epoch, train_loss, val_loss)
# ... 更多逻辑混杂在一起
使用 Callback 的优势:
python
# ✅ 使用 Callback 的代码(清晰、模块化、可复用)
trainer = pl.Trainer(
max_epochs=100,
callbacks=[
ModelCheckpoint(monitor='val_loss', mode='min'),
EarlyStopping(monitor='val_loss', patience=10),
LearningRateMonitor(logging_interval='epoch'),
]
)
trainer.fit(model, train_loader, val_loader)
1.2 Callback 的核心价值
| 优势 | 说明 |
|---|---|
| 解耦合 | 训练逻辑与辅助功能分离 |
| 模块化 | 每个 Callback 专注单一职责 |
| 可复用 | 同一个 Callback 可用于多个项目 |
| 可组合 | 多个 Callback 自由组合 |
| 易测试 | 独立的 Callback 易于单元测试 |
| 可扩展 | 轻松添加自定义功能 |
2. 核心概念与架构
2.1 什么是 Callback?
定义:Callback 是一个可以在训练循环的特定阶段被调用的对象,用于执行自定义操作。
核心特点:
- 继承自
pytorch_lightning.callbacks.Callback基类 - 通过重写钩子方法(hook methods)来插入自定义逻辑
- 在
Trainer的特定时刻自动被调用
2.2 Callback 的工作原理
训练流程 Callback 钩子触发时机
│
├─ Trainer.fit()
│ │
│ ├─ on_fit_start() ← 训练开始前
│ │
│ ├─ Epoch Loop
│ │ │
│ │ ├─ on_train_epoch_start() ← 每个训练 epoch 开始
│ │ │
│ │ ├─ Training Batch Loop
│ │ │ ├─ on_train_batch_start() ← 每个训练 batch 前
│ │ │ ├─ training_step()
│ │ │ └─ on_train_batch_end() ← 每个训练 batch 后
│ │ │
│ │ ├─ on_train_epoch_end() ← 每个训练 epoch 结束
│ │ │
│ │ ├─ Validation Loop
│ │ │ ├─ on_validation_epoch_start()
│ │ │ ├─ validation_step()
│ │ │ └─ on_validation_epoch_end()
│ │ │
│ │ └─ on_epoch_end() ← 每个完整 epoch 结束
│ │
│ └─ on_fit_end() ← 训练完全结束
│
└─ Trainer.test()
├─ on_test_start()
├─ test_step()
└─ on_test_end()
2.3 Callback 的分类
PyTorch Lightning 的 Callback 可以分为以下几类:
| 类别 | 典型 Callback | 用途 |
|---|---|---|
| 模型管理 | ModelCheckpoint | 保存/加载模型 |
| 训练控制 | EarlyStopping, GradientAccumulationScheduler | 控制训练流程 |
| 监控与日志 | LearningRateMonitor, DeviceStatsMonitor | 记录训练指标 |
| 用户界面 | RichProgressBar, TQDMProgressBar | 显示训练进度 |
| 优化策略 | StochasticWeightAveraging | 高级优化技巧 |
| 调试工具 | ModelSummary, Timer | 辅助调试 |
| 自定义 | 用户自定义 Callback | 特定需求 |
3. 内置 Callback 详解
3.1 ModelCheckpoint - 模型检查点
作用:在训练过程中自动保存模型,支持保存最佳模型或多个检查点。
基础用法
python
from pytorch_lightning.callbacks import ModelCheckpoint
# 示例1:保存验证损失最低的模型
checkpoint = ModelCheckpoint(
monitor='val_loss', # 监控的指标
dirpath='checkpoints/', # 保存目录
filename='best-{epoch:02d}-{val_loss:.4f}', # 文件名模板
save_top_k=1, # 保存最好的 1 个模型
mode='min', # 'min' 表示越小越好,'max' 表示越大越好
save_last=True, # 额外保存最后一个 epoch 的模型
verbose=True, # 打印日志
)
trainer = pl.Trainer(callbacks=[checkpoint])
完整参数说明
python
ModelCheckpoint(
# 核心参数
monitor='val_loss', # 监控的指标名称(必须在 self.log() 中记录)
mode='min', # 'min'/'max'/'auto'
# 保存策略
save_top_k=3, # 保存最好的 k 个模型(-1 表示全部保存)
save_last=True, # 是否额外保存最后一个模型(last.ckpt)
save_weights_only=False, # True: 仅保存权重,False: 保存完整状态
# 文件命名
dirpath='checkpoints/', # 保存目录
filename='epoch={epoch:02d}-val_loss={val_loss:.4f}', # 文件名模板
auto_insert_metric_name=True, # 自动在文件名中插入 monitor 名称
# 触发条件
every_n_epochs=1, # 每 n 个 epoch 检查一次
every_n_train_steps=None, # 每 n 个训练步检查一次
train_time_interval=None, # 按时间间隔检查(如 timedelta(minutes=30))
# 其他
verbose=True, # 是否打印保存信息
save_on_train_epoch_end=None, # 在训练 epoch 结束时保存(默认验证后)
)
高级用法
1. 同时保存多个指标的最佳模型
python
# 保存 val_loss 最低的模型
checkpoint_loss = ModelCheckpoint(
monitor='val_loss',
dirpath='checkpoints/loss/',
filename='best-loss-{epoch:02d}-{val_loss:.4f}',
mode='min',
save_top_k=1,
)
# 保存 val_acc 最高的模型
checkpoint_acc = ModelCheckpoint(
monitor='val_acc',
dirpath='checkpoints/acc/',
filename='best-acc-{epoch:02d}-{val_acc:.4f}',
mode='max',
save_top_k=1,
)
trainer = pl.Trainer(callbacks=[checkpoint_loss, checkpoint_acc])
2. 定期保存检查点(无论性能如何)
python
# 每 5 个 epoch 保存一次
checkpoint_periodic = ModelCheckpoint(
dirpath='checkpoints/periodic/',
filename='epoch={epoch:02d}',
every_n_epochs=5,
save_top_k=-1, # 保存所有
)
3. 按训练步数保存
python
checkpoint_steps = ModelCheckpoint(
dirpath='checkpoints/steps/',
filename='step={step}',
every_n_train_steps=1000, # 每 1000 步保存
save_top_k=-1,
)
4. 按时间间隔保存
python
from datetime import timedelta
checkpoint_time = ModelCheckpoint(
dirpath='checkpoints/timed/',
train_time_interval=timedelta(minutes=30), # 每 30 分钟保存
save_top_k=-1,
)
访问最佳模型路径
python
trainer.fit(model, train_loader, val_loader)
# 获取最佳模型路径
best_model_path = checkpoint.best_model_path
print(f"Best model: {best_model_path}")
# 获取最佳分数
best_score = checkpoint.best_model_score
print(f"Best score: {best_score}")
# 加载最佳模型
best_model = MyModel.load_from_checkpoint(best_model_path)
3.2 EarlyStopping - 早停
作用:当监控指标在一定时间内不再改善时,自动停止训练,防止过拟合。
基础用法
python
from pytorch_lightning.callbacks import EarlyStopping
early_stop = EarlyStopping(
monitor='val_loss', # 监控的指标
patience=10, # 容忍多少个 epoch 不改善
mode='min', # 'min' 或 'max'
verbose=True, # 打印停止信息
min_delta=0.001, # 最小改善量(小于此值不算改善)
)
trainer = pl.Trainer(callbacks=[early_stop])
完整参数说明
python
EarlyStopping(
# 核心参数
monitor='val_loss', # 监控的指标
mode='min', # 'min'/'max'/'auto'
patience=3, # 容忍的 epoch 数
# 判断标准
min_delta=0.0, # 最小改善阈值(绝对值)
strict=True, # 是否严格要求改善(False 允许相等)
# 停止行为
stopping_threshold=None, # 达到此值立即停止(如 val_loss < 0.01)
divergence_threshold=None, # 超过此值立即停止(如 val_loss > 10.0)
check_finite=True, # 检查指标是否为有限值
check_on_train_epoch_end=None, # 在训练 epoch 结束时检查(默认验证后)
# 日志
verbose=True,
log_rank_zero_only=False, # 仅在主进程打印
)
实用场景
1. 基础早停(验证损失不下降)
python
early_stop = EarlyStopping(
monitor='val_loss',
patience=15,
mode='min',
verbose=True,
)
2. 准确率不提升时停止
python
early_stop = EarlyStopping(
monitor='val_acc',
patience=10,
mode='max',
min_delta=0.005, # 提升小于 0.5% 不算改善
)
3. 达到目标后立即停止
python
early_stop = EarlyStopping(
monitor='val_acc',
stopping_threshold=0.95, # 准确率达到 95% 立即停止
mode='max',
)
4. 检测发散(loss 爆炸)
python
early_stop = EarlyStopping(
monitor='train_loss',
divergence_threshold=10.0, # 训练损失超过 10 立即停止
mode='min',
)
3.3 LearningRateMonitor - 学习率监控
作用:自动记录学习率变化,用于可视化学习率调度策略。
基础用法
python
from pytorch_lightning.callbacks import LearningRateMonitor
lr_monitor = LearningRateMonitor(
logging_interval='epoch', # 'step' 或 'epoch'
log_momentum=False, # 是否记录 momentum(SGD 优化器)
)
trainer = pl.Trainer(callbacks=[lr_monitor])
使用场景
1. 监控学习率调度器
python
class MyModel(pl.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5
)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss',
}
}
# 在 TensorBoard 中自动记录 lr 曲线
trainer = pl.Trainer(
callbacks=[LearningRateMonitor(logging_interval='epoch')],
logger=TensorBoardLogger('logs/')
)
2. 每步记录(用于 OneCycleLR 等)
python
lr_monitor = LearningRateMonitor(logging_interval='step')
3.4 RichProgressBar / TQDMProgressBar - 进度条
作用:显示训练进度和实时指标。
RichProgressBar(推荐)
python
from pytorch_lightning.callbacks import RichProgressBar
# 默认配置
progress_bar = RichProgressBar()
# 自定义配置
progress_bar = RichProgressBar(
refresh_rate=1, # 刷新频率(步数)
leave=True, # 训练结束后保留进度条
theme=RichProgressBarTheme( # 自定义主题
description="green_yellow",
progress_bar="green1",
progress_bar_finished="green1",
batch_progress="green_yellow",
time="grey82",
processing_speed="grey82",
metrics="grey82",
),
)
自定义进度条显示
python
class CustomProgressBar(RichProgressBar):
def get_metrics(self, trainer, model):
# 获取父类的指标
items = super().get_metrics(trainer, model)
# 自定义显示格式(如显示更多小数位)
items = {
k: f"{v:.6f}" if isinstance(v, (int, float)) else v
for k, v in items.items()
}
return items
3.5 GradientAccumulationScheduler - 梯度累积调度
作用:动态调整梯度累积步数,实现变 batch size 训练。
基础用法
python
from pytorch_lightning.callbacks import GradientAccumulationScheduler
# 在不同 epoch 使用不同的累积步数
accumulator = GradientAccumulationScheduler(
scheduling={
0: 8, # epoch 0-4: 累积 8 步
5: 4, # epoch 5-9: 累积 4 步
10: 2, # epoch 10+: 累积 2 步
}
)
trainer = pl.Trainer(callbacks=[accumulator])
实用场景
场景:GPU 显存有限,初期用小 batch,后期逐步增大。
python
# 等效 batch size 变化:
# epoch 0-4: batch_size=16 × accumulate=8 = 128
# epoch 5-9: batch_size=16 × accumulate=4 = 64
# epoch 10+: batch_size=16 × accumulate=2 = 32
accumulator = GradientAccumulationScheduler(
scheduling={0: 8, 5: 4, 10: 2}
)
3.6 StochasticWeightAveraging (SWA) - 随机权重平均
作用:对训练后期的模型权重进行平均,提升泛化性能。
基础用法
python
from pytorch_lightning.callbacks import StochasticWeightAveraging
swa = StochasticWeightAveraging(
swa_lrs=1e-2, # SWA 阶段的学习率
swa_epoch_start=0.8, # 从 80% epoch 开始 SWA(0.8 × max_epochs)
annealing_epochs=10, # 退火 epoch 数
annealing_strategy='cos', # 'cos' 或 'linear'
)
trainer = pl.Trainer(
max_epochs=100,
callbacks=[swa]
)
原理与效果
正常训练: 模型权重在最优点附近震荡
SWA: 对后期权重求平均,得到更平滑的模型
训练曲线:
╱╲ ╱╲ ╱╲
Loss ╱ ╲╱ ╲╱ ╲ ← 正常训练
╱____________╲ ← SWA 平均后(更稳定)
↑
SWA Start
3.7 ModelSummary - 模型摘要
作用:在训练开始前打印模型结构和参数统计。
python
from pytorch_lightning.callbacks import ModelSummary
summary = ModelSummary(
max_depth=2, # 显示的最大层级深度(-1 表示全部)
)
trainer = pl.Trainer(callbacks=[summary])
输出示例:
| Name | Type | Params
------------------------------------
0 | layer1 | Linear | 320
1 | layer2 | Linear | 640
2 | layer3 | Linear | 10
------------------------------------
970 Trainable params
0 Non-trainable params
970 Total params
3.8 Timer - 训练时间监控
作用:监控训练耗时,可设置最大训练时间。
python
from pytorch_lightning.callbacks import Timer
from datetime import timedelta
timer = Timer(
duration=timedelta(hours=2), # 最大训练时间 2 小时
interval='epoch', # 检查间隔('step' 或 'epoch')
verbose=True,
)
trainer = pl.Trainer(callbacks=[timer])
3.9 DeviceStatsMonitor - 设备状态监控
作用:监控 GPU/CPU 使用情况。
python
from pytorch_lightning.callbacks import DeviceStatsMonitor
device_stats = DeviceStatsMonitor()
trainer = pl.Trainer(callbacks=[device_stats])
记录的指标:
- GPU 利用率
- GPU 内存使用
- CPU 内存使用
3.10 BaseFinetuning - 微调辅助
作用:辅助实现冻结-解冻训练策略。
python
from pytorch_lightning.callbacks import BaseFinetuning
class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
def __init__(self, unfreeze_at_epoch=10):
super().__init__()
self._unfreeze_at_epoch = unfreeze_at_epoch
def freeze_before_training(self, pl_module):
# 初始冻结骨干网络
self.freeze(pl_module.feature_extractor)
def finetune_function(self, pl_module, current_epoch, optimizer):
# 在指定 epoch 解冻
if current_epoch == self._unfreeze_at_epoch:
self.unfreeze_and_add_param_group(
modules=pl_module.feature_extractor,
optimizer=optimizer,
lr=1e-5, # 使用更小的学习率
)
trainer = pl.Trainer(callbacks=[FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=10)])
4. Callback 生命周期钩子方法
4.1 完整钩子方法列表
PyTorch Lightning 提供了丰富的钩子方法,覆盖训练的各个阶段:
训练流程钩子
| 钩子方法 | 触发时机 | 常用场景 |
|---|---|---|
on_fit_start(trainer, pl_module) |
fit() 开始前 |
初始化全局状态 |
on_fit_end(trainer, pl_module) |
fit() 结束后 |
清理资源、保存最终结果 |
on_train_start(trainer, pl_module) |
训练开始前 | 打印训练配置 |
on_train_end(trainer, pl_module) |
训练结束后 | 生成训练报告 |
on_train_epoch_start(trainer, pl_module) |
每个训练 epoch 开始前 | 重置 epoch 级别的统计 |
on_train_epoch_end(trainer, pl_module) |
每个训练 epoch 结束后 | 计算 epoch 级别的指标 |
on_train_batch_start(trainer, pl_module, batch, batch_idx) |
每个训练 batch 前 | 数据预处理 |
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) |
每个训练 batch 后 | 记录 batch 级别的指标 |
验证流程钩子
| 钩子方法 | 触发时机 | 常用场景 |
|---|---|---|
on_validation_start(trainer, pl_module) |
验证开始前 | 切换到评估模式 |
on_validation_end(trainer, pl_module) |
验证结束后 | 计算验证集总体指标 |
on_validation_epoch_start(trainer, pl_module) |
验证 epoch 开始前 | 重置验证统计 |
on_validation_epoch_end(trainer, pl_module) |
验证 epoch 结束后 | 计算混淆矩阵等 |
on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) |
每个验证 batch 前 | - |
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) |
每个验证 batch 后 | - |
测试流程钩子
| 钩子方法 | 触发时机 |
|---|---|
on_test_start(trainer, pl_module) |
测试开始前 |
on_test_end(trainer, pl_module) |
测试结束后 |
on_test_epoch_start(trainer, pl_module) |
测试 epoch 开始前 |
on_test_epoch_end(trainer, pl_module) |
测试 epoch 结束后 |
on_test_batch_start(...) |
每个测试 batch 前 |
on_test_batch_end(...) |
每个测试 batch 后 |
预测流程钩子
| 钩子方法 | 触发时机 |
|---|---|
on_predict_start(trainer, pl_module) |
预测开始前 |
on_predict_end(trainer, pl_module) |
预测结束后 |
on_predict_epoch_start(trainer, pl_module) |
预测 epoch 开始前 |
on_predict_epoch_end(trainer, pl_module) |
预测 epoch 结束后 |
on_predict_batch_start(...) |
每个预测 batch 前 |
on_predict_batch_end(...) |
每个预测 batch 后 |
其他重要钩子
| 钩子方法 | 触发时机 | 常用场景 |
|---|---|---|
on_epoch_start(trainer, pl_module) |
每个完整 epoch 开始前(训练+验证) | - |
on_epoch_end(trainer, pl_module) |
每个完整 epoch 结束后 | 保存中间结果 |
on_save_checkpoint(trainer, pl_module, checkpoint) |
保存检查点时 | 添加自定义数据到检查点 |
on_load_checkpoint(trainer, pl_module, checkpoint) |
加载检查点时 | 恢复自定义状态 |
on_before_backward(trainer, pl_module, loss) |
反向传播前 | 梯度预处理 |
on_after_backward(trainer, pl_module) |
反向传播后 | 梯度裁剪、检查 |
on_before_optimizer_step(trainer, pl_module, optimizer) |
优化器更新前 | - |
on_before_zero_grad(trainer, pl_module, optimizer) |
梯度清零前 | - |
4.2 钩子方法参数说明
通用参数:
trainer:pl.Trainer实例,可访问训练器的状态pl_module:pl.LightningModule实例,即你的模型batch: 当前批次的数据batch_idx: 批次索引dataloader_idx: 数据加载器索引(多数据集时)outputs: 模型输出(如training_step的返回值)
访问训练状态:
python
def on_train_epoch_end(self, trainer, pl_module):
# 访问当前 epoch
current_epoch = trainer.current_epoch
# 访问全局步数
global_step = trainer.global_step
# 访问日志记录的指标
logged_metrics = trainer.logged_metrics
# 访问回调指标(用于 ModelCheckpoint 等)
callback_metrics = trainer.callback_metrics
# 访问模型参数
for name, param in pl_module.named_parameters():
print(f"{name}: {param.shape}")
4.3 钩子方法调用顺序示例
python
# 完整训练流程的钩子调用顺序
trainer.fit(model, train_loader, val_loader)
│
├─ on_fit_start()
│ ├─ on_train_start()
│ │
│ ├─ Epoch 0
│ │ ├─ on_epoch_start()
│ │ ├─ on_train_epoch_start()
│ │ │
│ │ ├─ Training Batches
│ │ │ ├─ on_train_batch_start(batch_idx=0)
│ │ │ ├─ on_before_backward()
│ │ │ ├─ on_after_backward()
│ │ │ ├─ on_before_optimizer_step()
│ │ │ ├─ on_before_zero_grad()
│ │ │ ├─ on_train_batch_end(batch_idx=0)
│ │ │ │
│ │ │ ├─ on_train_batch_start(batch_idx=1)
│ │ │ └─ ...
│ │ │
│ │ ├─ on_train_epoch_end()
│ │ │
│ │ ├─ Validation (如果启用)
│ │ │ ├─ on_validation_epoch_start()
│ │ │ ├─ on_validation_batch_start(batch_idx=0)
│ │ │ ├─ on_validation_batch_end(batch_idx=0)
│ │ │ └─ on_validation_epoch_end()
│ │ │
│ │ └─ on_epoch_end()
│ │
│ ├─ Epoch 1
│ │ └─ ... (同上)
│ │
│ └─ on_train_end()
│
└─ on_fit_end()
5. 自定义 Callback 开发
5.1 基础模板
python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
class MyCustomCallback(Callback):
"""自定义 Callback 模板"""
def __init__(self, custom_param):
super().__init__()
self.custom_param = custom_param
# 初始化自定义状态
self.state = {}
def on_train_start(self, trainer, pl_module):
"""训练开始时调用"""
print(f"训练开始,参数: {self.custom_param}")
def on_train_epoch_end(self, trainer, pl_module):
"""每个训练 epoch 结束时调用"""
# 访问训练指标
metrics = trainer.callback_metrics
print(f"Epoch {trainer.current_epoch} 结束")
def on_validation_epoch_end(self, trainer, pl_module):
"""每个验证 epoch 结束时调用"""
pass
5.2 实用自定义 Callback 示例
示例1:打印训练进度报告
python
class TrainingReportCallback(Callback):
"""每个 epoch 结束后打印详细报告"""
def on_train_epoch_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
print("\n" + "="*60)
print(f"Epoch {trainer.current_epoch} 训练报告")
print("="*60)
for key, value in metrics.items():
if isinstance(value, torch.Tensor):
value = value.item()
print(f"{key:30s}: {value:.6f}")
print("="*60 + "\n")
示例2:保存验证集预测结果
python
class SaveValidationPredictionsCallback(Callback):
"""保存每个 epoch 的验证集预测结果"""
def __init__(self, save_dir='predictions/'):
super().__init__()
self.save_dir = save_dir
self.predictions = []
self.targets = []
def on_validation_epoch_start(self, trainer, pl_module):
# 重置存储
self.predictions = []
self.targets = []
def on_validation_batch_end(self, trainer, pl_module, outputs,
batch, batch_idx, dataloader_idx=0):
# 收集预测结果
if isinstance(outputs, dict) and 'preds' in outputs:
self.predictions.append(outputs['preds'].cpu())
self.targets.append(outputs['targets'].cpu())
def on_validation_epoch_end(self, trainer, pl_module):
# 合并并保存
if self.predictions:
all_preds = torch.cat(self.predictions)
all_targets = torch.cat(self.targets)
save_path = f"{self.save_dir}/epoch_{trainer.current_epoch}.pt"
torch.save({
'predictions': all_preds,
'targets': all_targets,
'epoch': trainer.current_epoch
}, save_path)
print(f"验证集预测已保存: {save_path}")
示例3:动态学习率调整
python
class CustomLRScheduler(Callback):
"""基于验证损失的自定义学习率调整"""
def __init__(self, patience=5, factor=0.5, min_lr=1e-6):
super().__init__()
self.patience = patience
self.factor = factor
self.min_lr = min_lr
self.best_loss = float('inf')
self.wait = 0
def on_validation_epoch_end(self, trainer, pl_module):
# 获取当前验证损失
val_loss = trainer.callback_metrics.get('val_loss')
if val_loss is None:
return
val_loss = val_loss.item()
# 检查是否改善
if val_loss < self.best_loss:
self.best_loss = val_loss
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
# 降低学习率
for optimizer in trainer.optimizers:
for param_group in optimizer.param_groups:
old_lr = param_group['lr']
new_lr = max(old_lr * self.factor, self.min_lr)
param_group['lr'] = new_lr
print(f"\n学习率调整: {old_lr:.6f} → {new_lr:.6f}")
self.wait = 0
示例4:梯度监控
python
class GradientLoggingCallback(Callback):
"""记录梯度统计信息"""
def __init__(self, log_every_n_steps=100):
super().__init__()
self.log_every_n_steps = log_every_n_steps
def on_after_backward(self, trainer, pl_module):
if trainer.global_step % self.log_every_n_steps != 0:
return
# 计算梯度统计
grad_norms = []
for name, param in pl_module.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
grad_norms.append(grad_norm)
# 记录每层梯度
pl_module.log(f'grad_norm/{name}', grad_norm)
# 记录平均梯度范数
if grad_norms:
avg_grad_norm = sum(grad_norms) / len(grad_norms)
pl_module.log('grad_norm/average', avg_grad_norm)
示例5:检查点管理(清理旧文件)
python
import os
import glob
class CheckpointCleanupCallback(Callback):
"""自动清理旧的检查点文件,仅保留最新的 N 个"""
def __init__(self, checkpoint_dir='checkpoints/', keep_last_n=3):
super().__init__()
self.checkpoint_dir = checkpoint_dir
self.keep_last_n = keep_last_n
def on_train_epoch_end(self, trainer, pl_module):
# 获取所有检查点文件
ckpt_files = glob.glob(f"{self.checkpoint_dir}/*.ckpt")
# 按修改时间排序
ckpt_files.sort(key=os.path.getmtime, reverse=True)
# 删除旧文件
for ckpt_file in ckpt_files[self.keep_last_n:]:
try:
os.remove(ckpt_file)
print(f"删除旧检查点: {ckpt_file}")
except Exception as e:
print(f"删除失败: {e}")
示例6:邮件通知
python
import smtplib
from email.mime.text import MIMEText
class EmailNotificationCallback(Callback):
"""训练完成或异常时发送邮件通知"""
def __init__(self, recipient_email, smtp_config):
super().__init__()
self.recipient_email = recipient_email
self.smtp_config = smtp_config
def send_email(self, subject, message):
"""发送邮件"""
msg = MIMEText(message)
msg['Subject'] = subject
msg['From'] = self.smtp_config['from']
msg['To'] = self.recipient_email
try:
with smtplib.SMTP(self.smtp_config['server'],
self.smtp_config['port']) as server:
server.login(self.smtp_config['username'],
self.smtp_config['password'])
server.send_message(msg)
except Exception as e:
print(f"邮件发送失败: {e}")
def on_train_end(self, trainer, pl_module):
"""训练结束时发送通知"""
metrics = trainer.callback_metrics
message = f"""
训练已完成!
最终指标:
{metrics}
总 Epoch: {trainer.current_epoch}
总步数: {trainer.global_step}
"""
self.send_email("训练完成通知", message)
def on_exception(self, trainer, pl_module, exception):
"""发生异常时发送通知"""
message = f"训练发生异常: {exception}"
self.send_email("训练异常通知", message)
示例7:实时可视化(Matplotlib)
python
import matplotlib.pyplot as plt
class RealTimePlotCallback(Callback):
"""实时绘制训练曲线"""
def __init__(self):
super().__init__()
self.train_losses = []
self.val_losses = []
self.epochs = []
# 创建图形
plt.ion() # 交互模式
self.fig, self.ax = plt.subplots()
def on_train_epoch_end(self, trainer, pl_module):
# 记录数据
metrics = trainer.callback_metrics
self.epochs.append(trainer.current_epoch)
if 'train_loss' in metrics:
self.train_losses.append(metrics['train_loss'].item())
def on_validation_epoch_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
if 'val_loss' in metrics:
self.val_losses.append(metrics['val_loss'].item())
# 更新图形
self.ax.clear()
self.ax.plot(self.epochs, self.train_losses, label='Train Loss')
if len(self.val_losses) > 0:
self.ax.plot(self.epochs, self.val_losses, label='Val Loss')
self.ax.legend()
self.ax.set_xlabel('Epoch')
self.ax.set_ylabel('Loss')
self.fig.canvas.draw()
self.fig.canvas.flush_events()
def on_train_end(self, trainer, pl_module):
# 保存最终图形
plt.ioff()
self.fig.savefig('training_curve.png')
print("训练曲线已保存: training_curve.png")
5.3 访问模型和数据
在自定义 Callback 中,可以访问:
python
class DataInspectionCallback(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# 访问模型
model = pl_module
# 访问批次数据
x, y = batch # 根据实际数据结构解包
# 访问模型输出
predictions = outputs['preds'] # 根据 training_step 返回值
# 访问优化器
optimizer = trainer.optimizers[0]
current_lr = optimizer.param_groups[0]['lr']
# 访问日志器
logger = trainer.logger
logger.log_metrics({'custom_metric': 1.0}, step=trainer.global_step)
6. Callback 搭配使用策略
6.1 基础训练配置
场景:标准的分类/回归任务
python
callbacks = [
# 保存最佳模型
ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=1,
filename='best-{epoch:02d}-{val_loss:.4f}',
),
# 早停
EarlyStopping(
monitor='val_loss',
patience=15,
mode='min',
),
# 学习率监控
LearningRateMonitor(logging_interval='epoch'),
# 进度条
RichProgressBar(),
]
trainer = pl.Trainer(
max_epochs=100,
callbacks=callbacks,
logger=TensorBoardLogger('logs/'),
)
6.2 高性能训练配置
场景:大模型、长时间训练,需要多重保护
python
callbacks = [
# 1. 多重模型保存策略
ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=3,
filename='best-loss-{epoch:02d}-{val_loss:.4f}',
),
ModelCheckpoint(
monitor='val_acc',
mode='max',
save_top_k=1,
filename='best-acc-{epoch:02d}-{val_acc:.4f}',
),
ModelCheckpoint(
every_n_epochs=10,
filename='periodic-{epoch:02d}',
save_top_k=-1, # 保存所有
),
# 2. 早停 + 发散检测
EarlyStopping(
monitor='val_loss',
patience=20,
mode='min',
min_delta=0.001,
),
EarlyStopping(
monitor='train_loss',
divergence_threshold=10.0, # 检测 loss 爆炸
mode='min',
),
# 3. 学习率监控
LearningRateMonitor(logging_interval='step'),
# 4. 设备状态监控
DeviceStatsMonitor(),
# 5. 时间限制(如云服务器按时计费)
Timer(duration=timedelta(hours=10)),
# 6. 自定义训练报告
TrainingReportCallback(),
]
6.3 研究实验配置
场景:科研项目,需要详细记录和复现
python
callbacks = [
# 1. 模型保存
ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=5,
save_last=True,
),
# 2. 早停
EarlyStopping(monitor='val_loss', patience=30),
# 3. 学习率监控
LearningRateMonitor(logging_interval='step'),
# 4. 梯度监控(检测梯度消失/爆炸)
GradientLoggingCallback(log_every_n_steps=50),
# 5. 保存验证集预测(用于后续分析)
SaveValidationPredictionsCallback(save_dir='predictions/'),
# 6. 实时可视化
RealTimePlotCallback(),
# 7. 模型摘要
ModelSummary(max_depth=3),
]
# 同时使用 TensorBoard 和 WandB
trainer = pl.Trainer(
callbacks=callbacks,
logger=[
TensorBoardLogger('logs/tensorboard/'),
WandbLogger(project='my_research', name='exp_001'),
],
)
6.4 生产部署配置
场景:模型训练后需要部署到生产环境
python
callbacks = [
# 1. 仅保存权重(减小文件体积)
ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=1,
save_weights_only=True, # 仅保存权重
filename='production-best',
),
# 2. 早停
EarlyStopping(
monitor='val_loss',
patience=10,
stopping_threshold=0.05, # 达到目标即停止
),
# 3. SWA 提升泛化性能
StochasticWeightAveraging(swa_lrs=1e-2),
# 4. 检查点清理(节省存储)
CheckpointCleanupCallback(keep_last_n=2),
# 5. 训练完成通知
EmailNotificationCallback(
recipient_email='team@company.com',
smtp_config={...}
),
]
6.5 调试配置
场景:快速调试代码,检测 Bug
python
# 使用 Trainer 的快速开发标志
trainer = pl.Trainer(
max_epochs=2, # 少量 epoch
limit_train_batches=10, # 仅训练 10 个 batch
limit_val_batches=5, # 仅验证 5 个 batch
callbacks=[
RichProgressBar(),
ModelSummary(max_depth=-1), # 查看完整模型结构
],
logger=False, # 不记录日志
enable_checkpointing=False, # 不保存检查点
)
6.6 超参数搜索配置
场景:使用 Ray Tune / Optuna 进行超参数优化
python
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
def train_func(config):
model = MyModel(
lr=config['lr'],
hidden_dim=config['hidden_dim'],
)
trainer = pl.Trainer(
max_epochs=20,
callbacks=[
# Ray Tune 回调(报告指标)
TuneReportCallback(
metrics={'val_loss': 'val_loss'},
on='validation_end',
),
EarlyStopping(monitor='val_loss', patience=5),
],
enable_progress_bar=False, # 禁用进度条(避免输出混乱)
enable_model_summary=False,
)
trainer.fit(model, train_loader, val_loader)
# 启动超参数搜索
analysis = tune.run(
train_func,
config={
'lr': tune.loguniform(1e-4, 1e-1),
'hidden_dim': tune.choice([64, 128, 256]),
},
num_samples=20,
)
7. 高级应用与最佳实践
7.1 Callback 之间的通信
场景:不同 Callback 需要共享状态
python
class SharedStateCallback(Callback):
"""使用 Trainer 的自定义属性共享状态"""
def on_train_start(self, trainer, pl_module):
# 初始化共享状态
trainer.my_shared_state = {'counter': 0}
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# 更新共享状态
trainer.my_shared_state['counter'] += 1
class AnotherCallback(Callback):
def on_validation_epoch_end(self, trainer, pl_module):
# 读取共享状态
counter = trainer.my_shared_state['counter']
print(f"已训练 {counter} 个 batch")
7.2 条件执行 Callback
python
class ConditionalCallback(Callback):
"""仅在特定条件下执行"""
def __init__(self, execute_after_epoch=10):
super().__init__()
self.execute_after_epoch = execute_after_epoch
def on_validation_epoch_end(self, trainer, pl_module):
# 仅在第 10 个 epoch 后执行
if trainer.current_epoch >= self.execute_after_epoch:
print("执行特殊操作...")
7.3 Callback 优先级
Callback 的执行顺序由添加顺序决定:
python
callbacks = [
CallbackA(), # 第一个执行
CallbackB(), # 第二个执行
CallbackC(), # 第三个执行
]
# 注意:ModelCheckpoint 和 EarlyStopping 的顺序很重要!
callbacks = [
ModelCheckpoint(...), # 先保存模型
EarlyStopping(...), # 再判断是否停止
]
7.4 在 Callback 中使用日志器
python
class CustomLoggingCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
# 方式1:通过 pl_module 记录
pl_module.log('custom_metric', 1.0)
# 方式2:直接使用 logger
if trainer.logger:
trainer.logger.log_metrics(
{'another_metric': 2.0},
step=trainer.global_step
)
# 如果是 TensorBoard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_scalar(
'special_metric', 3.0, trainer.global_step
)
7.5 处理分布式训练
python
class DistributedAwareCallback(Callback):
"""在分布式训练中正确处理"""
def on_validation_epoch_end(self, trainer, pl_module):
# 仅在主进程执行(避免重复)
if trainer.is_global_zero:
print("这只在主进程打印一次")
# 所有进程都执行
local_rank = trainer.local_rank
print(f"进程 {local_rank} 执行")
7.6 Callback 的测试
python
import unittest
class TestMyCallback(unittest.TestCase):
def test_callback_logic(self):
# 创建模拟的 trainer 和 model
trainer = MockTrainer()
model = MockModel()
# 测试 callback
callback = MyCustomCallback()
callback.on_train_start(trainer, model)
# 验证行为
self.assertEqual(callback.state['initialized'], True)
8. 常见问题与调试技巧
8.1 常见错误
错误1:在 on_train_epoch_end 中访问不存在的指标
python
# ❌ 错误示例
def on_train_epoch_end(self, trainer, pl_module):
val_loss = trainer.callback_metrics['val_loss'] # KeyError!
原因 :on_train_epoch_end 在验证之前调用,此时 val_loss 还未计算。
解决:
python
# ✅ 正确示例
def on_validation_epoch_end(self, trainer, pl_module):
# 在验证后访问
val_loss = trainer.callback_metrics.get('val_loss')
if val_loss is not None:
print(f"验证损失: {val_loss}")
错误2:Callback 修改了模型状态但未恢复
python
# ❌ 错误示例
def on_validation_start(self, trainer, pl_module):
pl_module.train() # 错误地切换到训练模式
解决:
python
# ✅ 正确示例
def on_validation_start(self, trainer, pl_module):
# Lightning 会自动处理模式切换,无需手动干预
pass
错误3:在错误的钩子中执行耗时操作
python
# ❌ 错误示例(会严重拖慢训练)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# 每个 batch 都执行复杂计算
expensive_operation()
解决:
python
# ✅ 正确示例
def on_train_epoch_end(self, trainer, pl_module):
# 每个 epoch 执行一次
expensive_operation()
8.2 调试技巧
技巧1:打印所有可用指标
python
class DebugCallback(Callback):
def on_validation_epoch_end(self, trainer, pl_module):
print("\n可用指标:")
for key, value in trainer.callback_metrics.items():
print(f" {key}: {value}")
技巧2:检查 Callback 是否被调用
python
class TestCallback(Callback):
def __init__(self):
super().__init__()
self.call_count = {}
def _log_call(self, method_name):
self.call_count[method_name] = self.call_count.get(method_name, 0) + 1
print(f"[{method_name}] 被调用 {self.call_count[method_name]} 次")
def on_train_start(self, trainer, pl_module):
self._log_call('on_train_start')
def on_train_epoch_end(self, trainer, pl_module):
self._log_call('on_train_epoch_end')
技巧3:使用断点调试
python
class DebugCallback(Callback):
def on_validation_epoch_end(self, trainer, pl_module):
# 在特定条件下触发断点
if trainer.current_epoch == 5:
import pdb; pdb.set_trace()
8.3 性能优化
优化1:避免频繁的 I/O 操作
python
# ❌ 低效
class BadCallback(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# 每个 batch 都写文件
with open('log.txt', 'a') as f:
f.write(f"Batch {batch_idx} done\n")
# ✅ 高效
class GoodCallback(Callback):
def __init__(self):
super().__init__()
self.buffer = []
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.buffer.append(f"Batch {batch_idx} done\n")
def on_train_epoch_end(self, trainer, pl_module):
# 每个 epoch 写一次
with open('log.txt', 'a') as f:
f.writelines(self.buffer)
self.buffer = []
优化2:使用条件判断减少计算
python
class OptimizedCallback(Callback):
def __init__(self, log_every_n_epochs=5):
super().__init__()
self.log_every_n_epochs = log_every_n_epochs
def on_validation_epoch_end(self, trainer, pl_module):
# 仅每 5 个 epoch 执行一次
if trainer.current_epoch % self.log_every_n_epochs == 0:
expensive_visualization()
9. 扩展阅读与进阶方向
9.1 官方文档
-
PyTorch Lightning Callbacks 文档 :
https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
-
内置 Callback API 参考 :
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.html
9.2 高级主题
9.2.1 与其他框架集成
- Ray Tune 集成:分布式超参数优化
- Optuna 集成:贝叶斯超参数优化
- MLflow 集成:实验追踪与模型管理
9.2.2 自定义训练循环
python
class CustomTrainLoop(Callback):
"""完全自定义训练循环"""
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
# 自定义数据预处理
pass
def on_before_backward(self, trainer, pl_module, loss):
# 自定义损失缩放
pass
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
# 自定义梯度处理
pass
9.2.3 高级模型管理
- 模型版本控制:使用 DVC 或 Git LFS
- A/B 测试:保存多个候选模型进行对比
- 模型蒸馏:在 Callback 中实现教师-学生训练
9.3 实战案例学习
推荐阅读以下开源项目的 Callback 实现:
-
Transformers (Hugging Face) :
查看
transformers.TrainerCallback的设计 -
Lightning-Hydra-Template :
完整的 PyTorch Lightning 项目模板
-
PyTorch Lightning Bolts :
高级 Callback 示例集合
9.4 社区资源
-
PyTorch Lightning GitHub Discussions :
-
PyTorch Lightning Slack :
加入社区讨论
📌 总结
核心要点回顾
-
Callback 是什么:
- 在训练循环特定阶段执行的可插拔模块
- 通过钩子方法(hook)实现自定义逻辑
-
常用内置 Callback:
ModelCheckpoint:保存模型EarlyStopping:早停LearningRateMonitor:学习率监控RichProgressBar:进度条StochasticWeightAveraging:SWA 优化
-
生命周期钩子:
- 训练阶段:
on_train_start,on_train_epoch_end,on_train_batch_end - 验证阶段:
on_validation_epoch_end - 其他:
on_save_checkpoint,on_load_checkpoint
- 训练阶段:
-
自定义 Callback:
- 继承
Callback基类 - 重写所需的钩子方法
- 在
Trainer中注册使用
- 继承
-
搭配使用策略:
- 基础训练:
ModelCheckpoint+EarlyStopping+LearningRateMonitor - 研究实验:增加梯度监控、预测保存等
- 生产部署:增加 SWA、检查点清理等
- 基础训练:
最佳实践建议
- ✅ 模块化:每个 Callback 专注单一职责
- ✅ 可配置:通过参数控制行为
- ✅ 高效:避免在高频钩子中执行耗时操作
- ✅ 鲁棒:处理边界情况(如指标不存在)
- ✅ 可测试:编写单元测试验证逻辑
- ✅ 文档化:为自定义 Callback 添加详细注释
Callback 使用清单
训练前检查:
- 确认监控的指标在
self.log()中记录 - 检查
mode参数('min' 或 'max') - 验证文件保存路径存在且有写权限
调试阶段:
- 使用
verbose=True查看详细日志 - 添加
DebugCallback检查调用顺序 - 使用小数据集快速验证
生产环境:
- 启用
ModelCheckpoint和EarlyStopping - 配置合理的
patience和save_top_k - 添加异常处理和通知机制
附录:快速参考
Callback 常用参数速查
| Callback | 关键参数 | 说明 |
|---|---|---|
| ModelCheckpoint | monitor, mode, save_top_k |
保存最佳模型 |
| EarlyStopping | monitor, patience, mode |
防止过拟合 |
| LearningRateMonitor | logging_interval |
记录学习率 |
| GradientAccumulationScheduler | scheduling |
动态调整累积步数 |
| StochasticWeightAveraging | swa_lrs, swa_epoch_start |
权重平均优化 |
钩子方法速查
| 钩子 | 触发时机 | 常用场景 |
|---|---|---|
on_train_start |
训练开始前 | 初始化状态 |
on_train_epoch_end |
训练 epoch 结束 | 计算 epoch 指标 |
on_validation_epoch_end |
验证 epoch 结束 | 保存验证结果 |
on_save_checkpoint |
保存检查点时 | 添加自定义数据 |
on_after_backward |
反向传播后 | 梯度监控 |
常用代码片段
基础配置:
python
callbacks = [
ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1),
EarlyStopping(monitor='val_loss', patience=10),
LearningRateMonitor(),
]
自定义 Callback 模板:
python
class MyCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
# 自定义逻辑