PyTorch Lightning 指南
本文档是 PyTorch Lightning 的完整使用指南,涵盖 LightningModule、Trainer、DataModule 三大核心组件的详细说明。
目录
- [1. PyTorch Lightning 简介](#1. PyTorch Lightning 简介)
- [2. LightningModule 核心组件](#2. LightningModule 核心组件)
- [2.1 模板结构](#2.1 模板结构)
- [2.2 初始化方法](#2.2 初始化方法)
- [2.3 训练相关方法](#2.3 训练相关方法)
- [2.4 验证与测试方法](#2.4 验证与测试方法)
- [2.5 优化器配置](#2.5 优化器配置)
- [3. Trainer 训练器](#3. Trainer 训练器)
- [3.1 检查点管理](#3.1 检查点管理)
- [3.2 回调机制](#3.2 回调机制)
- [3.3 日志与可视化](#3.3 日志与可视化)
- [3.4 命令行参数](#3.4 命令行参数)
- [3.5 模型预测](#3.5 模型预测)
- [3.6 GPU 训练](#3.6 GPU 训练)
- [3.7 模型调试](#3.7 模型调试)
- [3.8 性能优化技巧](#3.8 性能优化技巧)
- [4. DataModule 数据模块](#4. DataModule 数据模块)
- [4.1 DataModule 介绍](#4.1 DataModule 介绍)
- [4.2 核心方法](#4.2 核心方法)
- [4.3 使用方式](#4.3 使用方式)
1. PyTorch Lightning 简介
PyTorch Lightning 是一个轻量级的 PyTorch 封装框架,旨在组织 PyTorch 代码,使研究代码更具可读性和可复现性。它将研究代码从工程代码中分离,让你专注于模型开发。
核心优势:
- 自动化训练流程(梯度计算、优化器步骤、日志记录等)
- 代码组织结构化,易于维护和复用
- 支持多 GPU、TPU 等分布式训练
- 内置丰富的回调和日志系统
2. LightningModule 核心组件
2.1 模板结构
LightningModule 是 PyTorch Lightning 的核心,它将 PyTorch 的 nn.Module 进行了扩展,提供了标准化的训练流程。
基本组成部分:
- 初始化(
__init__和setup()) - 训练循环(
training_step()) - 验证循环(
validation_step()) - 测试循环(
test_step()) - 预测循环(
predict_step()) - 优化器配置(
configure_optimizers())
完整模板代码:
python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class LitModel(pl.LightningModule):
"""PyTorch Lightning 模型模板"""
def __init__(self, input_dim, hidden_dim, output_dim, learning_rate=1e-3):
super().__init__()
# 保存超参数
self.save_hyperparameters()
# 定义模型层
self.layer_1 = nn.Linear(input_dim, hidden_dim)
self.layer_2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
"""前向传播,用于推理"""
x = F.relu(self.layer_1(x))
x = self.layer_2(x)
return x
def training_step(self, batch, batch_idx):
"""训练步骤"""
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
"""验证步骤"""
x, y = batch
y_hat = self(x)
val_loss = F.cross_entropy(y_hat, y)
self.log('val_loss', val_loss)
def test_step(self, batch, batch_idx):
"""测试步骤"""
x, y = batch
y_hat = self(x)
test_loss = F.cross_entropy(y_hat, y)
self.log('test_loss', test_loss)
def predict_step(self, batch, batch_idx):
"""预测步骤"""
x, y = batch
return self(x)
def configure_optimizers(self):
"""配置优化器"""
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return optimizer
2.2 初始化方法
__init__() 方法
在 __init__ 方法中进行模型初始化,包括定义网络层、损失函数等。
超参数管理:
(1) 保存超参数
使用 self.save_hyperparameters() 自动保存所有传入 __init__ 的参数:
python
class MyLightningModule(pl.LightningModule):
def __init__(self, learning_rate, layer_1_dim, dropout_rate):
super().__init__()
# 自动保存所有参数到 self.hparams
self.save_hyperparameters()
# 定义网络
self.net = nn.Sequential(
nn.Linear(784, self.hparams.layer_1_dim),
nn.Dropout(self.hparams.dropout_rate),
nn.ReLU(),
nn.Linear(self.hparams.layer_1_dim, 10)
)
(2) 访问超参数
超参数被保存后,可以通过 self.hparams 访问:
python
# 在任何方法中访问超参数
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.hparams.learning_rate
)
return optimizer
超参数也会自动保存到检查点:
python
# 加载检查点时,超参数会自动恢复
checkpoint = torch.load("checkpoint.ckpt")
print(checkpoint["hyper_parameters"])
# 输出: {"learning_rate": 0.001, "layer_1_dim": 128, "dropout_rate": 0.5}
# 直接从检查点加载模型(包括超参数)
model = MyLightningModule.load_from_checkpoint("checkpoint.ckpt")
print(model.hparams.learning_rate) # 0.001
(3) 使用不同参数初始化
加载检查点时可以覆盖原有超参数:
python
# 使用原始超参数
model = LitModel.load_from_checkpoint("best_model.ckpt")
# 覆盖部分超参数
model = LitModel.load_from_checkpoint(
"best_model.ckpt",
learning_rate=0.0001, # 新的学习率
layer_1_dim=256 # 新的隐藏层维度
)
补充超参数(针对未保存的参数):
如果初始化时某些参数没有通过 save_hyperparameters() 保存,需要在加载时手动传入:
python
class LitAutoencoder(pl.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
# 没有调用 save_hyperparameters()
self.encoder = encoder
self.decoder = decoder
# 加载时必须提供缺失的参数
encoder = MyEncoder()
decoder = MyDecoder()
model = LitAutoencoder.load_from_checkpoint(
"checkpoint.ckpt",
encoder=encoder,
decoder=decoder
)
2.3 训练相关方法
forward() 方法
forward() 方法定义了模型的前向传播逻辑,主要用于推理:
python
def forward(self, x):
"""
前向传播方法
Args:
x: 输入张量
Returns:
输出张量
"""
return self.model(x)
# 在代码中调用
output = model(input_data) # 自动调用 forward()
training_step() 方法
定义单个训练批次的逻辑:
python
def training_step(self, batch, batch_idx):
"""
训练步骤
Args:
batch: 当前批次数据
batch_idx: 批次索引
Returns:
loss: 训练损失(必须返回)
"""
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
# 记录指标到日志
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('train_acc', self.accuracy(y_hat, y))
return loss
日志记录说明:
on_step=True: 每个 step 记录一次on_epoch=True: 每个 epoch 结束时记录平均值prog_bar=True: 在进度条中显示
training_step_end() 方法(高级)
仅在多 GPU 训练且需要对所有 GPU 的输出进行联合计算时使用:
python
def training_step_end(self, batch_parts):
"""
在分布式训练中,对所有设备的输出进行汇总
Args:
batch_parts: 所有设备返回的 training_step 输出列表
Returns:
汇总后的结果
"""
# 例如:对所有 GPU 的 logits 进行 softmax
gpu_0_prediction = batch_parts[0]['pred']
gpu_1_prediction = batch_parts[1]['pred']
# 合并预测
all_predictions = torch.cat([gpu_0_prediction, gpu_1_prediction])
return {'loss': loss}
training_epoch_end() 方法
在每个训练 epoch 结束时调用:
python
def training_epoch_end(self, outputs):
"""
训练 epoch 结束时的处理
Args:
outputs: 包含所有 training_step 返回值的列表
"""
# 计算整个 epoch 的平均损失
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
self.log('train_epoch_loss', avg_loss)
# 可以在这里进行学习率调整、模型保存等操作
print(f"Epoch {self.current_epoch} finished with avg loss: {avg_loss:.4f}")
2.4 验证与测试方法
validation_step() 方法
定义验证步骤,用于在训练过程中评估模型:
python
def validation_step(self, batch, batch_idx):
"""
验证步骤
Args:
batch: 验证批次数据
batch_idx: 批次索引
Returns:
可以返回任意内容(通常返回损失或指标)
"""
x, y = batch
y_hat = self(x)
val_loss = F.cross_entropy(y_hat, y)
# 记录验证指标
self.log('val_loss', val_loss, prog_bar=True)
self.log('val_acc', self.accuracy(y_hat, y))
return val_loss
验证集划分:
通常使用训练集的 20% 作为验证集:
python
from torch.utils.data import random_split
# 划分训练集和验证集
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
使用验证集训练:
python
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
trainer = pl.Trainer()
trainer.fit(model, train_loader, val_loader)
validation_step_end() 和 validation_epoch_end() 方法
用法与训练阶段的对应方法相同:
python
def validation_epoch_end(self, outputs):
"""验证 epoch 结束时的处理"""
avg_val_loss = torch.stack([x for x in outputs]).mean()
self.log('val_epoch_loss', avg_val_loss)
test_step() 方法
测试步骤用于评估模型的最终性能:
python
def test_step(self, batch, batch_idx):
"""
测试步骤
Args:
batch: 测试批次数据
batch_idx: 批次索引
"""
x, y = batch
y_hat = self(x)
test_loss = F.cross_entropy(y_hat, y)
# 记录测试指标
self.log('test_loss', test_loss)
self.log('test_acc', self.accuracy(y_hat, y))
测试集使用:
python
# 加载测试数据
test_dataset = MNIST(root="data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32)
# 训练完成后进行测试
trainer = pl.Trainer()
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)
注意事项:
- 测试集不应在训练过程中使用
- 测试集仅用于评估训练完成后的模型性能
- 确保测试集与训练集完全独立
2.5 优化器配置
configure_optimizers() 方法
配置模型的优化器和学习率调度器:
python
def configure_optimizers(self):
"""
配置优化器和学习率调度器
Returns:
优化器或包含优化器和调度器的字典
"""
# 基础用法:仅返回优化器
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return optimizer
配置学习率调度器:
python
def configure_optimizers(self):
"""配置优化器和学习率调度器"""
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'interval': 'epoch', # 'epoch' 或 'step'
'frequency': 1, # 每多少个 interval 调用一次
'monitor': 'val_loss' # 用于 ReduceLROnPlateau
}
}
多优化器配置(高级):
python
def configure_optimizers(self):
"""配置多个优化器(如 GAN 训练)"""
opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
return [opt_g, opt_d], []
3. Trainer 训练器
Trainer 是 PyTorch Lightning 的训练控制中心,负责管理整个训练流程。
3.1 检查点管理
检查点的组成
PyTorch Lightning 的检查点包含以下内容:
- 16-bit 精度缩放因子(如使用混合精度训练)
- 当前 epoch
- 全局 step
- LightningModule 的 state_dict
- 所有优化器的状态
- 所有学习率调度器的状态
- 所有回调的状态
- DataModule 的状态(如果使用)
- 超参数(模型和 DataModule 的初始化参数)
- 训练循环的状态
自动保存检查点
Lightning 会自动保存最后一个 epoch 的检查点:
python
# 默认保存到当前工作目录
trainer = pl.Trainer()
# 指定保存路径
trainer = pl.Trainer(default_root_dir="my_checkpoints/")
从检查点加载模型
方法 1:仅加载权重进行推理
python
# 加载模型权重和超参数
model = MyLightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")
# 设置为评估模式
model.eval()
# 进行预测
with torch.no_grad():
y_hat = model(x)
方法 2:恢复完整训练状态
python
# 从检查点恢复并继续训练
model = LitModel()
trainer = pl.Trainer()
# 自动恢复模型、epoch、step、优化器等所有状态
trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")
禁用检查点
python
# 完全禁用自动检查点保存
trainer = pl.Trainer(enable_checkpointing=False)
自定义检查点行为(ModelCheckpoint 回调)
使用 ModelCheckpoint 回调实现更精细的控制:
python
from pytorch_lightning.callbacks import ModelCheckpoint
# 保存验证损失最好的前 3 个模型
checkpoint_callback = ModelCheckpoint(
dirpath='my/path/', # 保存目录
filename='model-{epoch:02d}-{val_loss:.2f}', # 文件名模板
save_top_k=3, # 保存最好的 k 个模型
monitor='val_loss', # 监控的指标
mode='min', # 'min' 或 'max'
save_last=True, # 额外保存最后一个 epoch
every_n_epochs=1, # 每 n 个 epoch 检查一次
save_weights_only=False, # 是否仅保存权重
)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)
# 获取最佳模型路径
best_model_path = checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")
高级配置:
python
checkpoint_callback = ModelCheckpoint(
# 何时保存(When)
every_n_train_steps=100, # 每 N 个训练步骤保存
every_n_epochs=5, # 每 N 个 epoch 保存
train_time_interval=timedelta(minutes=10), # 每隔一定时间保存
# 保存哪些(Which)
save_top_k=5, # 保存最好的 k 个
save_last=True, # 保存最后一个
monitor='val_accuracy', # 监控的指标
mode='max', # 'min' 表示越小越好,'max' 表示越大越好
# 保存什么(What)
save_weights_only=True, # 仅保存权重(节省空间)
# 保存到哪里(Where)
dirpath='checkpoints/',
filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
)
监控自定义指标:
python
class LitModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
# 记录自定义指标
self.log('my_custom_metric', some_value)
# 监控自定义指标
checkpoint_callback = ModelCheckpoint(monitor='my_custom_metric', mode='max')
根据条件保存检查点
python
# 保存最后 K 个检查点(基于 global_step)
checkpoint_callback = ModelCheckpoint(
save_top_k=10,
monitor="global_step",
mode="max",
dirpath="my/path/",
filename="model-{epoch:02d}-{global_step}",
)
注意事项:
- 在文件名中包含监控指标,避免文件名冲突
- 不要依赖自动版本号来检索 top-k 模型
- 文件名示例:
model-epoch=02-val_loss=0.32.ckpt
手动保存检查点
python
# 训练过程中手动保存
trainer.save_checkpoint("manual_checkpoint.ckpt")
# 稍后加载
model = MyLightningModule.load_from_checkpoint("manual_checkpoint.ckpt")
3.2 回调机制
回调(Callback)是 Lightning 提供的一种扩展机制,可以在训练的特定阶段执行自定义操作。
回调的作用
- 将辅助功能从核心研究代码中分离
- 提供数十个可插拔的钩子函数
- 可重用且易于组合
- 不污染主要的模型逻辑
常用回调
1. EarlyStopping(早停)
在验证指标不再改善时提前停止训练:
python
from pytorch_lightning.callbacks import EarlyStopping
class LitModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
loss = ...
self.log("val_loss", loss)
# 配置早停回调
early_stop_callback = EarlyStopping(
monitor='val_loss', # 监控的指标
min_delta=0.00, # 最小改善幅度
patience=3, # 容忍的 epoch 数
verbose=False, # 是否打印信息
mode='min', # 'min' 或 'max'
strict=True, # 是否在找不到指标时报错
stopping_threshold=0.01, # 达到此阈值立即停止
divergence_threshold=5.0,# 超过此阈值立即停止(防止发散)
check_finite=True, # 检查 NaN 或 Inf
check_on_train_epoch_end=False, # 是否在训练 epoch 结束时检查
)
trainer = pl.Trainer(callbacks=[early_stop_callback])
trainer.fit(model)
自定义早停逻辑:
python
class MyEarlyStopping(EarlyStopping):
def on_validation_end(self, trainer, pl_module):
# 禁用验证结束时的早停
pass
def on_train_end(self, trainer, pl_module):
# 在训练结束时执行早停检查
self._run_early_stopping_check(trainer)
注意事项:
patience计数的是验证检查次数,而非 epoch 数- 如果
check_val_every_n_epoch=10且patience=3,则至少需要 40 个 epoch
2. ModelCheckpoint(检查点)
详见 [3.1 检查点管理](#3.1 检查点管理) 部分。
3. LearningRateMonitor(学习率监控)
自动记录学习率变化:
python
from pytorch_lightning.callbacks import LearningRateMonitor
# 创建学习率监控器
lr_monitor = LearningRateMonitor(
logging_interval='step', # 'step' 或 'epoch',None 表示按调度器的 interval
log_momentum=False, # 是否记录动量
log_weight_decay=False, # 是否记录权重衰减
)
trainer = pl.Trainer(callbacks=[lr_monitor])
自定义日志名称:
python
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = {
'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
'name': 'my_lr_scheduler_name' # 自定义名称
}
return [optimizer], [scheduler]
4. ModelSummary(模型摘要)
打印模型结构信息:
python
from pytorch_lightning.callbacks import ModelSummary
# 打印完整的模型层次结构
trainer = pl.Trainer(callbacks=[ModelSummary(max_depth=-1)])
输出示例:
| Name | Type | Params | In sizes | Out sizes
----------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512]
5. DeviceStatsMonitor(设备统计)
监控 GPU/CPU 使用情况:
python
from pytorch_lightning.callbacks import DeviceStatsMonitor
trainer = pl.Trainer(callbacks=[DeviceStatsMonitor(cpu_stats=True)])
6. GradientAccumulationScheduler(梯度累积调度)
详见 [3.8 性能优化技巧](#3.8 性能优化技巧) 部分。
7. StochasticWeightAveraging(随机权重平均)
详见 [3.8 性能优化技巧](#3.8 性能优化技巧) 部分。
3.3 日志与可视化
基础日志记录
使用 self.log() 记录指标:
python
class LitModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
loss = ...
# 记录单个指标
self.log("train_loss", loss)
# 记录多个指标
self.log_dict({
"loss": loss,
"acc": accuracy,
"metric_n": some_metric
})
return loss
在命令行中显示
python
# 在进度条中显示指标
self.log("train_loss", loss, prog_bar=True)
输出示例:
Epoch 3: 33%|███▉ | 307/938 [00:01<00:02, 289.04it/s, loss=0.198, acc=0.211]
使用 TensorBoard
启动 TensorBoard:
python
# Lightning 默认使用 TensorBoard(如果已安装)
trainer = pl.Trainer()
# 显式指定 TensorBoard
from pytorch_lightning.loggers import TensorBoardLogger
tensorboard = TensorBoardLogger(save_dir="logs/")
trainer = pl.Trainer(logger=tensorboard)
在命令行中启动 TensorBoard:
bash
tensorboard --logdir=lightning_logs/
在 Jupyter Notebook 中:
python
%load_ext tensorboard
%tensorboard --logdir=lightning_logs/
记录非标量内容
记录图像、直方图、文本等:
python
def training_step(self, batch, batch_idx):
# 获取 TensorBoard 实验对象
tensorboard = self.logger.experiment
# 记录图像
tensorboard.add_image('input_images', images, self.current_epoch)
# 记录直方图
tensorboard.add_histogram('layer1_weights', self.layer1.weight, self.current_epoch)
# 记录梯度直方图
for name, param in self.named_parameters():
if param.grad is not None:
tensorboard.add_histogram(
f"gradients/{name}",
param.grad.detach().cpu(),
self.current_epoch
)
return loss
自定义日志行为
配置日志频率:
python
# 每 10 步记录一次
trainer = pl.Trainer(log_every_n_steps=10)
配置 TensorBoard 刷新频率:
python
logger = TensorBoardLogger(..., max_queue=100, flush_secs=120)
self.log 的详细配置
python
self.log(
name="metric_name",
value=metric_value,
# 时间维度
on_step=True, # 是否在每个 step 记录
on_epoch=True, # 是否在 epoch 结束时记录
# 聚合方式
reduce_fx=torch.mean, # 聚合函数:mean, max, min, sum
# 显示位置
prog_bar=True, # 是否显示在进度条
logger=True, # 是否发送到 logger
# 分布式训练
sync_dist=False, # 是否在设备间同步
sync_dist_group=None, # DDP 同步组
rank_zero_only=False, # 是否仅在 rank 0 记录
# 其他
batch_size=32, # 批次大小(用于正确累积)
enable_graph=True, # 是否保持计算图
)
默认值(根据调用位置不同):
python
def training_step(self, batch, batch_idx):
# 默认: on_step=True, on_epoch=False
self.log("train_loss", loss)
def validation_step(self, batch, batch_idx):
# 默认: on_step=False, on_epoch=True
self.log("val_loss", loss)
def test_step(self, batch, batch_idx):
# 默认: on_step=False, on_epoch=True
self.log("test_loss", loss)
使用多个 Logger
python
from pytorch_lightning import loggers as pl_loggers
# 同时使用 TensorBoard 和 CSV logger
tensorboard = pl_loggers.TensorBoardLogger('logs/')
csv_logger = pl_loggers.CSVLogger('logs/')
trainer = pl.Trainer(logger=[tensorboard, csv_logger])
累积指标
在验证和测试阶段,Lightning 自动累积指标并计算平均值:
python
def validation_step(self, batch, batch_idx):
value = ...
# 自动累积并在 epoch 结束时计算平均值
self.log("average_value", value)
如需其他聚合方式:
python
self.log("max_value", value, reduce_fx='max')
self.log("min_value", value, reduce_fx='min')
self.log("sum_value", value, reduce_fx='sum')
3.4 命令行参数
使用 ArgumentParser 或 Lightning CLI 管理超参数。
使用 ArgumentParser
python
from argparse import ArgumentParser
# 创建参数解析器
parser = ArgumentParser()
# 添加 Trainer 参数
parser.add_argument("--devices", type=int, default=2)
parser.add_argument("--max_epochs", type=int, default=10)
# 添加模型超参数
parser.add_argument("--layer_1_dim", type=int, default=128)
parser.add_argument("--learning_rate", type=float, default=1e-3)
# 解析参数
args = parser.parse_args()
# 使用解析的参数
model = MyModel(
layer_1_dim=args.layer_1_dim,
learning_rate=args.learning_rate
)
trainer = pl.Trainer(devices=args.devices, max_epochs=args.max_epochs)
trainer.fit(model)
命令行调用:
bash
python train.py --layer_1_dim 256 --learning_rate 0.001 --devices 4
使用 Lightning CLI(推荐)
Lightning CLI 提供了更强大的命令行配置功能:
python
from pytorch_lightning.cli import LightningCLI
# 简单用法
cli = LightningCLI(MyModel, MyDataModule)
命令行调用:
bash
# 查看所有可用参数
python train.py --help
# 使用命令行参数
python train.py --model.learning_rate 0.001 --model.layer_1_dim 256 --trainer.max_epochs 50
# 使用配置文件
python train.py --config config.yaml
配置文件示例(config.yaml):
yaml
model:
learning_rate: 0.001
layer_1_dim: 256
dropout: 0.5
trainer:
max_epochs: 50
devices: 4
accelerator: gpu
data:
batch_size: 32
num_workers: 4
3.5 模型预测
从检查点加载并预测
python
# 加载训练好的模型
model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()
# 准备输入数据
x = torch.randn(1, 64)
# 进行预测
with torch.no_grad():
y_hat = model(x)
使用 predict_step
定义 predict_step 方法来处理预测逻辑:
python
class MyModel(pl.LightningModule):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
x, _ = batch
return self(x)
使用 Trainer 进行预测:
python
# 准备数据
predict_loader = DataLoader(predict_dataset, batch_size=32)
# 加载模型
model = MyModel.load_from_checkpoint("checkpoint.ckpt")
# 进行预测
trainer = pl.Trainer()
predictions = trainer.predict(model, predict_loader)
复杂预测逻辑(Monte Carlo Dropout)
python
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration=10):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
def predict_step(self, batch, batch_idx):
# 启用 Dropout(Monte Carlo Dropout)
self.dropout.train()
# 进行多次预测并取平均
predictions = []
for _ in range(self.mc_iteration):
pred = self.dropout(self.model(batch))
predictions.append(pred.unsqueeze(0))
# 计算平均预测
pred = torch.vstack(predictions).mean(dim=0)
return pred
分布式预测并保存结果
使用 BasePredictionWriter 在分布式环境中保存预测结果:
python
from pytorch_lightning.callbacks import BasePredictionWriter
class CustomWriter(BasePredictionWriter):
def __init__(self, output_dir, write_interval='epoch'):
super().__init__(write_interval)
self.output_dir = output_dir
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
# 每个进程保存自己的预测结果
torch.save(
predictions,
os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt")
)
# 可选:保存批次索引
torch.save(
batch_indices,
os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt")
)
# 使用自定义写入器
pred_writer = CustomWriter(output_dir="predictions/", write_interval="epoch")
trainer = pl.Trainer(
accelerator="gpu",
strategy="ddp",
devices=8,
callbacks=[pred_writer]
)
model = MyModel()
trainer.predict(model, return_predictions=False)
3.6 GPU 训练
基础 GPU 训练
python
# 自动使用所有可用 GPU
trainer = pl.Trainer(accelerator="auto", devices="auto")
# 等价于
trainer = pl.Trainer()
# 使用单个 GPU
trainer = pl.Trainer(accelerator="gpu", devices=1)
# 使用多个 GPU
trainer = pl.Trainer(accelerator="gpu", devices=8)
注意:
accelerator="gpu"也会自动选择 Apple Silicon 的 MPS 设备- 如要避免使用 MPS,可设置
accelerator="cuda"
选择特定 GPU
python
# 使用前 k 个 GPU
trainer = pl.Trainer(accelerator="gpu", devices=k)
# 等价于
trainer = pl.Trainer(accelerator="gpu", devices=list(range(k)))
# 指定特定 GPU(不推荐在集群上使用)
trainer = pl.Trainer(accelerator="gpu", devices=[0, 1])
# 使用字符串形式
trainer = pl.Trainer(accelerator="gpu", devices="0, 1")
# 使用所有 GPU
trainer = pl.Trainer(accelerator="gpu", devices=-1)
devices 参数解释:
| devices 值 | 类型 | 解析结果 | 含义 |
|---|---|---|---|
| 3 | int | [0, 1, 2] | 前 3 个 GPU |
| -1 | int | [0, 1, 2, ...] | 所有可用 GPU |
| [0] | list | [0] | GPU 0 |
| [1, 3] | list | [1, 3] | GPU 索引 1 和 3 |
| "3" | str | [0, 1, 2] | 前 3 个 GPU |
| "1, 3" | str | [1, 3] | GPU 索引 1 和 3 |
| "-1" | str | [0, 1, 2, ...] | 所有可用 GPU |
自动查找可用 GPU
在多任务场景下(如超参数搜索),自动查找未被占用的 GPU:
python
from pytorch_lightning.accelerators import find_usable_cuda_devices
# 查找 2 个未被占用的 GPU
trainer = pl.Trainer(
accelerator="cuda",
devices=find_usable_cuda_devices(2)
)
这在 GPU 设置为"独占计算模式"时特别有用。
3.7 模型调试
设置断点
python
def function_to_debug():
x = 2
# 设置断点
import pdb
pdb.set_trace()
y = x ** 2 # 代码将在此处暂停
快速运行模式(fast_dev_run)
快速运行少量批次以检查代码是否有错误:
python
# 运行 5 个批次的训练、验证、测试
trainer = pl.Trainer(fast_dev_run=True)
# 自定义批次数量
trainer = pl.Trainer(fast_dev_run=7)
注意:
- 该模式会禁用 checkpoint、early stopping、logger 等功能
- 适合快速验证代码逻辑
缩短 epoch 长度
使用部分数据进行调试:
python
# 使用 10% 的训练数据和 1% 的验证数据
trainer = pl.Trainer(limit_train_batches=0.1, limit_val_batches=0.01)
# 使用固定批次数
trainer = pl.Trainer(limit_train_batches=10, limit_val_batches=5)
Sanity Check(健全性检查)
训练开始前运行少量验证步骤,避免深度训练后才发现验证错误:
python
# 默认运行 2 步验证
trainer = pl.Trainer(num_sanity_val_steps=2)
# 禁用健全性检查
trainer = pl.Trainer(num_sanity_val_steps=0)
打印模型摘要
python
# 训练时自动打印模型摘要
trainer.fit(model)
# 输出示例:
# | Name | Type | Params | Mode
# -------------------------------------------
# 0 | net | Sequential | 132 K | train
# 1 | net.0 | Linear | 131 K | train
# 2 | net.1 | BatchNorm1d | 1.0 K | train
打印更详细的摘要:
python
from pytorch_lightning.callbacks import ModelSummary
trainer = pl.Trainer(callbacks=[ModelSummary(max_depth=-1)])
手动打印摘要:
python
from pytorch_lightning.utilities.model_summary import ModelSummary
model = LitModel()
summary = ModelSummary(model, max_depth=-1)
print(summary)
禁用自动摘要:
python
trainer = pl.Trainer(enable_model_summary=False)
显示输入输出维度
设置 example_input_array 以显示层的输入输出维度:
python
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.example_input_array = torch.Tensor(32, 1, 28, 28)
# ...
输出示例:
| Name | Type | Params | In sizes | Out sizes
----------------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512]
3.8 性能优化技巧
混合精度训练(N-bit Precision)
使用低精度浮点数可以加速训练并减少内存使用:
python
# 使用 16 位混合精度
trainer = pl.Trainer(precision='16-mixed')
# 使用 bf16(bfloat16)混合精度
trainer = pl.Trainer(precision='bf16-mixed')
# 使用 64 位精度(更高精度)
trainer = pl.Trainer(precision=64)
优势:
- 减少内存占用,可训练更大的模型
- 加快训练速度
- 降低对硬件的要求
梯度累积(Accumulate Gradients)
通过累积多个小批次的梯度来模拟大批次训练:
python
# 默认不累积(每个批次更新一次)
trainer = pl.Trainer(accumulate_grad_batches=1)
# 累积 7 个批次的梯度
trainer = pl.Trainer(accumulate_grad_batches=7)
效果:
- 有效批次大小 =
batch_size × accumulate_grad_batches - 例如:
batch_size=32,accumulate_grad_batches=7,有效批次大小 = 224
动态调整累积批次:
python
from pytorch_lightning.callbacks import GradientAccumulationScheduler
# 前 5 个 epoch 累积 8 个批次
# 第 5-9 epoch 累积 4 个批次
# 第 9 epoch 后不累积
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = pl.Trainer(callbacks=[accumulator])
注意事项:
- 在 DDP 下,每个设备独立累积梯度
- 有效批次大小 =
num_devices × batch_size × accumulate_grad_batches
梯度裁剪(Gradient Clipping)
防止梯度爆炸:
python
# 默认不裁剪
trainer = pl.Trainer(gradient_clip_val=0)
# 裁剪梯度范数到 <= 0.5
trainer = pl.Trainer(gradient_clip_val=0.5)
# 裁剪梯度值到 <= 0.5(而非范数)
trainer = pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")
注意:
- 使用混合精度时,梯度会在裁剪前自动缩放回 fp32
随机权重平均(Stochastic Weight Averaging)
通过平均多个训练步骤的权重来提高模型泛化能力:
python
from pytorch_lightning.callbacks import StochasticWeightAveraging
trainer = pl.Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
优势:
- 几乎无额外成本
- 提高模型泛化性能
- 平滑损失函数
自动批次大小查找
自动找到能放入内存的最大批次大小:
python
from pytorch_lightning.tuner import Tuner
trainer = pl.Trainer()
tuner = Tuner(trainer)
# 指数增长搜索(默认)
tuner.scale_batch_size(model, mode="power")
# 二分搜索
tuner.scale_batch_size(model, mode="binsearch")
# 然后正常训练
trainer.fit(model)
前提条件:
- 模型需要有
batch_size属性或在hparams中定义 train_dataloader()方法需要使用该属性
python
class LitModel(pl.LightningModule):
def __init__(self, batch_size=32):
super().__init__()
self.save_hyperparameters()
def train_dataloader(self):
return DataLoader(dataset, batch_size=self.hparams.batch_size)
自动学习率查找
自动找到最优初始学习率:
python
from pytorch_lightning.tuner import Tuner
model = MyModel()
trainer = pl.Trainer()
tuner = Tuner(trainer)
# 运行学习率查找
lr_finder = tuner.lr_find(model)
# 查看结果
print(lr_finder.results)
# 绘制曲线
fig = lr_finder.plot(suggest=True)
fig.show()
# 获取建议的学习率
new_lr = lr_finder.suggestion()
# 更新模型
model.hparams.lr = new_lr
# 开始训练
trainer.fit(model)
学习率查找原理:
- 从小学习率开始,逐步增加
- 记录每个学习率对应的损失
- 找到损失下降最快的区域
- 建议选择该区域的中点(而非最低点)
自定义学习率查找:
python
from pytorch_lightning.callbacks import LearningRateFinder
class FineTuneLearningRateFinder(LearningRateFinder):
def __init__(self, milestones, *args, **kwargs):
super().__init__(*args, **kwargs)
self.milestones = milestones
def on_fit_start(self, *args, **kwargs):
return
def on_train_epoch_start(self, trainer, pl_module):
if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
self.lr_find(trainer, pl_module)
# 在第 0、5、10 epoch 时运行学习率查找
trainer = pl.Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
4. DataModule 数据模块
4.1 DataModule 介绍
DataModule 是 PyTorch Lightning 提供的数据管理抽象,封装了数据处理的五个步骤:
- 下载/标记/处理数据
- 清理数据并保存到磁盘
- 加载 数据到
Dataset - 应用数据变换(旋转、标记化等)
- 包装 成
DataLoader
为什么使用 DataModule?
解决的问题:
- 数据准备代码通常分散在多个文件中
- 难以共享和复用数据集的划分和变换
- 无法确保数据处理的一致性
DataModule 的优势:
- 将所有数据相关逻辑封装在一起
- 易于在不同项目间共享
- 便于切换不同数据集
- 确保数据处理的可复现性
基础示例对比
传统 PyTorch 代码:
python
# 数据准备代码分散
test_data = MNIST(path, train=False, download=True)
predict_data = MNIST(path, train=False, download=True)
train_data = MNIST(path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])
train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
使用 DataModule:
python
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "data/", batch_size: int = 32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def setup(self, stage: str):
# 加载和划分数据
self.mnist_test = MNIST(self.data_dir, train=False)
self.mnist_predict = MNIST(self.data_dir, train=False)
mnist_full = MNIST(self.data_dir, train=True)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000],
generator=torch.Generator().manual_seed(42)
)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=self.batch_size)
使用方式:
python
# 创建 DataModule 和模型
mnist = MNISTDataModule(data_dir="data/", batch_size=64)
model = LitClassifier()
# 训练
trainer = pl.Trainer()
trainer.fit(model, datamodule=mnist)
# DataModule 可复用于不同数据集
cifar10 = CIFAR10DataModule()
trainer.fit(model, datamodule=cifar10)
4.2 核心方法
prepare_data()
作用: 下载、标记化等一次性操作(通常只在单个进程中执行)
特点:
- 仅在单个进程上调用(避免多进程下载冲突)
- 在 CPU 上运行
- 不应在此分配状态(
self.x = y)
使用场景:
- 下载数据集
- 标记化文本
- 生成词汇表
- 预处理并保存到磁盘
python
class MNISTDataModule(pl.LightningDataModule):
def prepare_data(self):
# 仅下载一次,避免多进程重复下载
MNIST(self.data_dir, train=True, download=True, transform=transforms.ToTensor())
MNIST(self.data_dir, train=False, download=True, transform=transforms.ToTensor())
多节点训练:
- 默认情况下,每个节点的 rank 0 进程都会调用
prepare_data() - 可通过
prepare_data_per_node控制行为:
python
class LitDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
# True: 每个节点的 LOCAL_RANK=0 调用
# False: 仅 NODE_RANK=0, LOCAL_RANK=0 调用
self.prepare_data_per_node = True
setup(stage)
作用: 在每个进程上执行的数据准备操作
特点:
- 在所有进程上调用
- 可以分配状态(
self.x = y) - 接收
stage参数区分训练/验证/测试阶段
使用场景:
- 统计类别数量
- 构建词汇表
- 划分训练/验证/测试集
- 创建 Dataset
- 应用数据变换
python
class MNISTDataModule(pl.LightningDataModule):
def setup(self, stage: str):
# stage 可以是 'fit', 'validate', 'test', 'predict'
if stage == "fit":
# 训练和验证阶段
mnist_full = MNIST(
self.data_dir,
train=True,
transform=self.transform
)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000],
generator=torch.Generator().manual_seed(42)
)
# 可以在这里统计信息
self.num_classes = 10
self.dims = mnist_full[0][0].shape
if stage == "test":
# 测试阶段
self.mnist_test = MNIST(
self.data_dir,
train=False,
transform=self.transform
)
if stage == "predict":
# 预测阶段
self.mnist_predict = MNIST(
self.data_dir,
train=False,
transform=self.transform
)
NLP 示例(标记化):
python
class TextDataModule(pl.LightningDataModule):
def prepare_data(self):
# 下载和标记化(仅一次)
dataset = load_dataset(...)
tokenized = tokenize(dataset)
save_to_disk(tokenized, "processed/")
def setup(self, stage):
# 每个进程加载预处理的数据
self.dataset = load_dataset_from_disk("processed/")
if stage == "fit":
self.train_data = self.dataset['train']
self.val_data = self.dataset['validation']
train_dataloader()
返回训练数据加载器:
python
def train_dataloader(self):
return DataLoader(
self.mnist_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
val_dataloader()
返回验证数据加载器:
python
def val_dataloader(self):
return DataLoader(
self.mnist_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=4
)
test_dataloader()
返回测试数据加载器:
python
def test_dataloader(self):
return DataLoader(
self.mnist_test,
batch_size=self.batch_size,
shuffle=False
)
predict_dataloader()
返回预测数据加载器:
python
def predict_dataloader(self):
return DataLoader(
self.mnist_predict,
batch_size=self.batch_size
)
teardown(stage)
作用: 在训练/测试结束时执行清理工作
python
def teardown(self, stage: str):
# 清理资源
if stage == "fit":
# 训练结束时的清理
del self.train_data
del self.val_data
if stage == "test":
# 测试结束时的清理
del self.test_data
state_dict() 和 load_state_dict()
保存和恢复 DataModule 状态:
python
class LitDataModule(pl.LightningDataModule):
def state_dict(self):
# 保存需要持久化的状态
state = {
"current_train_batch_index": self.current_train_batch_index,
"random_state": self.rng.getstate()
}
return state
def load_state_dict(self, state_dict):
# 从检查点恢复状态
self.current_train_batch_index = state_dict["current_train_batch_index"]
self.rng.setstate(state_dict["random_state"])
数据传输钩子(高级)
transfer_batch_to_device()
自定义如何将批次数据传输到设备:
python
def transfer_batch_to_device(self, batch, device, dataloader_idx):
if isinstance(batch, CustomBatch):
# 移动自定义数据结构中的张量
batch.samples = batch.samples.to(device)
batch.targets = batch.targets.to(device)
return batch
else:
return super().transfer_batch_to_device(batch, device, dataloader_idx)
on_before_batch_transfer()
在数据传输到设备之前进行数据增强:
python
def on_before_batch_transfer(self, batch, dataloader_idx):
# 在 CPU 上应用数据增强
batch['x'] = cpu_transforms(batch['x'])
return batch
on_after_batch_transfer()
在数据传输到设备之后进行数据增强:
python
def on_after_batch_transfer(self, batch, dataloader_idx):
# 在 GPU 上应用数据增强
batch['x'] = gpu_transforms(batch['x'])
return batch
4.3 使用方式
基本使用
python
# 创建 DataModule
dm = MNISTDataModule(data_dir="data/", batch_size=32)
# 创建模型
model = MyModel()
# 训练
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
# 测试
trainer.test(datamodule=dm)
# 验证
trainer.validate(datamodule=dm)
# 预测
trainer.predict(datamodule=dm)
手动调用 DataModule 方法
如果需要在构建模型前获取数据集信息:
python
# 手动准备数据
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")
# 使用数据集信息构建模型
model = MyModel(
num_classes=dm.num_classes,
input_shape=dm.dims
)
# 训练
trainer.fit(model, datamodule=dm)
# 测试
dm.setup(stage="test")
trainer.test(datamodule=dm)
在纯 PyTorch 中使用
DataModule 也可以在纯 PyTorch 代码中使用:
python
# 准备数据
dm = MNISTDataModule()
dm.prepare_data()
# 设置训练数据
dm.setup(stage="fit")
for batch in dm.train_dataloader():
# 训练代码
pass
for batch in dm.val_dataloader():
# 验证代码
pass
dm.teardown(stage="fit")
# 设置测试数据
dm.setup(stage="test")
for batch in dm.test_dataloader():
# 测试代码
pass
dm.teardown(stage="test")
保存和恢复超参数
与 LightningModule 类似,DataModule 也支持超参数管理:
python
class CustomDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32, num_workers=4):
super().__init__()
# 保存所有超参数
self.save_hyperparameters()
def train_dataloader(self):
# 使用保存的超参数
return DataLoader(
self.train_data,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers
)
访问当前使用的 DataModule
python
# 在训练过程中访问 DataModule
current_dm = trainer.datamodule
# 访问当前的 DataLoader
train_loader = trainer.train_dataloader
val_loaders = trainer.val_dataloaders
test_loaders = trainer.test_dataloaders
predict_loaders = trainer.predict_dataloaders
完整示例
端到端训练示例
python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
# 1. 定义 DataModule
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir="data/", batch_size=32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
def prepare_data(self):
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str):
if stage == "fit":
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000],
generator=torch.Generator().manual_seed(42)
)
if stage == "test":
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=4)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)
# 2. 定义 LightningModule
class LitMNIST(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
self.layer_1 = nn.Linear(28 * 28, hidden_dim)
self.layer_2 = nn.Linear(hidden_dim, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = F.relu(self.layer_1(x))
x = self.layer_2(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
val_loss = F.cross_entropy(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log('val_loss', val_loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
test_loss = F.cross_entropy(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log('test_loss', test_loss)
self.log('test_acc', acc)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'interval': 'epoch'
}
}
# 3. 配置训练
if __name__ == "__main__":
# 创建 DataModule 和模型
dm = MNISTDataModule(batch_size=64)
model = LitMNIST(hidden_dim=256, learning_rate=1e-3)
# 配置回调
checkpoint_callback = ModelCheckpoint(
dirpath='checkpoints/',
filename='mnist-{epoch:02d}-{val_loss:.2f}',
save_top_k=3,
monitor='val_loss',
mode='min'
)
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
mode='min'
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# 创建 Trainer
trainer = pl.Trainer(
max_epochs=50,
accelerator='gpu',
devices=1,
callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
log_every_n_steps=10,
)
# 训练
trainer.fit(model, datamodule=dm)
# 测试
trainer.test(model, datamodule=dm, ckpt_path='best')
总结
PyTorch Lightning 通过三大核心组件提供了完整的深度学习训练解决方案:
- LightningModule:封装模型定义、训练逻辑和优化器配置
- Trainer:自动化训练流程,提供丰富的回调和优化选项
- DataModule:组织数据处理流程,确保数据处理的可复现性