PyTorch Lightning

PyTorch Lightning 指南

本文档是 PyTorch Lightning 的完整使用指南,涵盖 LightningModule、Trainer、DataModule 三大核心组件的详细说明。


目录

  • [1. PyTorch Lightning 简介](#1. PyTorch Lightning 简介)
  • [2. LightningModule 核心组件](#2. LightningModule 核心组件)
    • [2.1 模板结构](#2.1 模板结构)
    • [2.2 初始化方法](#2.2 初始化方法)
    • [2.3 训练相关方法](#2.3 训练相关方法)
    • [2.4 验证与测试方法](#2.4 验证与测试方法)
    • [2.5 优化器配置](#2.5 优化器配置)
  • [3. Trainer 训练器](#3. Trainer 训练器)
    • [3.1 检查点管理](#3.1 检查点管理)
    • [3.2 回调机制](#3.2 回调机制)
    • [3.3 日志与可视化](#3.3 日志与可视化)
    • [3.4 命令行参数](#3.4 命令行参数)
    • [3.5 模型预测](#3.5 模型预测)
    • [3.6 GPU 训练](#3.6 GPU 训练)
    • [3.7 模型调试](#3.7 模型调试)
    • [3.8 性能优化技巧](#3.8 性能优化技巧)
  • [4. DataModule 数据模块](#4. DataModule 数据模块)
    • [4.1 DataModule 介绍](#4.1 DataModule 介绍)
    • [4.2 核心方法](#4.2 核心方法)
    • [4.3 使用方式](#4.3 使用方式)

1. PyTorch Lightning 简介

PyTorch Lightning 是一个轻量级的 PyTorch 封装框架,旨在组织 PyTorch 代码,使研究代码更具可读性和可复现性。它将研究代码从工程代码中分离,让你专注于模型开发。

核心优势:

  • 自动化训练流程(梯度计算、优化器步骤、日志记录等)
  • 代码组织结构化,易于维护和复用
  • 支持多 GPU、TPU 等分布式训练
  • 内置丰富的回调和日志系统

2. LightningModule 核心组件

2.1 模板结构

LightningModule 是 PyTorch Lightning 的核心,它将 PyTorch 的 nn.Module 进行了扩展,提供了标准化的训练流程。

基本组成部分:

  • 初始化(__init__setup()
  • 训练循环(training_step()
  • 验证循环(validation_step()
  • 测试循环(test_step()
  • 预测循环(predict_step()
  • 优化器配置(configure_optimizers()

完整模板代码:

python 复制代码
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LitModel(pl.LightningModule):
    """PyTorch Lightning 模型模板"""
    
    def __init__(self, input_dim, hidden_dim, output_dim, learning_rate=1e-3):
        super().__init__()
        # 保存超参数
        self.save_hyperparameters()
        
        # 定义模型层
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        """前向传播,用于推理"""
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x
    
    def training_step(self, batch, batch_idx):
        """训练步骤"""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        """验证步骤"""
        x, y = batch
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', val_loss)
    
    def test_step(self, batch, batch_idx):
        """测试步骤"""
        x, y = batch
        y_hat = self(x)
        test_loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', test_loss)
    
    def predict_step(self, batch, batch_idx):
        """预测步骤"""
        x, y = batch
        return self(x)
    
    def configure_optimizers(self):
        """配置优化器"""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

2.2 初始化方法

__init__() 方法

__init__ 方法中进行模型初始化,包括定义网络层、损失函数等。

超参数管理:

(1) 保存超参数

使用 self.save_hyperparameters() 自动保存所有传入 __init__ 的参数:

python 复制代码
class MyLightningModule(pl.LightningModule):
    def __init__(self, learning_rate, layer_1_dim, dropout_rate):
        super().__init__()
        # 自动保存所有参数到 self.hparams
        self.save_hyperparameters()
        
        # 定义网络
        self.net = nn.Sequential(
            nn.Linear(784, self.hparams.layer_1_dim),
            nn.Dropout(self.hparams.dropout_rate),
            nn.ReLU(),
            nn.Linear(self.hparams.layer_1_dim, 10)
        )
(2) 访问超参数

超参数被保存后,可以通过 self.hparams 访问:

python 复制代码
# 在任何方法中访问超参数
def configure_optimizers(self):
    optimizer = torch.optim.Adam(
        self.parameters(), 
        lr=self.hparams.learning_rate
    )
    return optimizer

超参数也会自动保存到检查点:

python 复制代码
# 加载检查点时,超参数会自动恢复
checkpoint = torch.load("checkpoint.ckpt")
print(checkpoint["hyper_parameters"])
# 输出: {"learning_rate": 0.001, "layer_1_dim": 128, "dropout_rate": 0.5}

# 直接从检查点加载模型(包括超参数)
model = MyLightningModule.load_from_checkpoint("checkpoint.ckpt")
print(model.hparams.learning_rate)  # 0.001
(3) 使用不同参数初始化

加载检查点时可以覆盖原有超参数:

python 复制代码
# 使用原始超参数
model = LitModel.load_from_checkpoint("best_model.ckpt")

# 覆盖部分超参数
model = LitModel.load_from_checkpoint(
    "best_model.ckpt", 
    learning_rate=0.0001,  # 新的学习率
    layer_1_dim=256        # 新的隐藏层维度
)

补充超参数(针对未保存的参数):

如果初始化时某些参数没有通过 save_hyperparameters() 保存,需要在加载时手动传入:

python 复制代码
class LitAutoencoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        # 没有调用 save_hyperparameters()
        self.encoder = encoder
        self.decoder = decoder

# 加载时必须提供缺失的参数
encoder = MyEncoder()
decoder = MyDecoder()
model = LitAutoencoder.load_from_checkpoint(
    "checkpoint.ckpt",
    encoder=encoder,
    decoder=decoder
)

2.3 训练相关方法

forward() 方法

forward() 方法定义了模型的前向传播逻辑,主要用于推理

python 复制代码
def forward(self, x):
    """
    前向传播方法
    
    Args:
        x: 输入张量
        
    Returns:
        输出张量
    """
    return self.model(x)

# 在代码中调用
output = model(input_data)  # 自动调用 forward()
training_step() 方法

定义单个训练批次的逻辑:

python 复制代码
def training_step(self, batch, batch_idx):
    """
    训练步骤
    
    Args:
        batch: 当前批次数据
        batch_idx: 批次索引
        
    Returns:
        loss: 训练损失(必须返回)
    """
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    
    # 记录指标到日志
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
    self.log('train_acc', self.accuracy(y_hat, y))
    
    return loss

日志记录说明:

  • on_step=True: 每个 step 记录一次
  • on_epoch=True: 每个 epoch 结束时记录平均值
  • prog_bar=True: 在进度条中显示
training_step_end() 方法(高级)

仅在多 GPU 训练且需要对所有 GPU 的输出进行联合计算时使用:

python 复制代码
def training_step_end(self, batch_parts):
    """
    在分布式训练中,对所有设备的输出进行汇总
    
    Args:
        batch_parts: 所有设备返回的 training_step 输出列表
        
    Returns:
        汇总后的结果
    """
    # 例如:对所有 GPU 的 logits 进行 softmax
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']
    
    # 合并预测
    all_predictions = torch.cat([gpu_0_prediction, gpu_1_prediction])
    return {'loss': loss}
training_epoch_end() 方法

在每个训练 epoch 结束时调用:

python 复制代码
def training_epoch_end(self, outputs):
    """
    训练 epoch 结束时的处理
    
    Args:
        outputs: 包含所有 training_step 返回值的列表
    """
    # 计算整个 epoch 的平均损失
    avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
    self.log('train_epoch_loss', avg_loss)
    
    # 可以在这里进行学习率调整、模型保存等操作
    print(f"Epoch {self.current_epoch} finished with avg loss: {avg_loss:.4f}")

2.4 验证与测试方法

validation_step() 方法

定义验证步骤,用于在训练过程中评估模型:

python 复制代码
def validation_step(self, batch, batch_idx):
    """
    验证步骤
    
    Args:
        batch: 验证批次数据
        batch_idx: 批次索引
        
    Returns:
        可以返回任意内容(通常返回损失或指标)
    """
    x, y = batch
    y_hat = self(x)
    val_loss = F.cross_entropy(y_hat, y)
    
    # 记录验证指标
    self.log('val_loss', val_loss, prog_bar=True)
    self.log('val_acc', self.accuracy(y_hat, y))
    
    return val_loss

验证集划分:

通常使用训练集的 20% 作为验证集:

python 复制代码
from torch.utils.data import random_split

# 划分训练集和验证集
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(
    dataset, 
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

使用验证集训练:

python 复制代码
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

trainer = pl.Trainer()
trainer.fit(model, train_loader, val_loader)
validation_step_end()validation_epoch_end() 方法

用法与训练阶段的对应方法相同:

python 复制代码
def validation_epoch_end(self, outputs):
    """验证 epoch 结束时的处理"""
    avg_val_loss = torch.stack([x for x in outputs]).mean()
    self.log('val_epoch_loss', avg_val_loss)
test_step() 方法

测试步骤用于评估模型的最终性能:

python 复制代码
def test_step(self, batch, batch_idx):
    """
    测试步骤
    
    Args:
        batch: 测试批次数据
        batch_idx: 批次索引
    """
    x, y = batch
    y_hat = self(x)
    test_loss = F.cross_entropy(y_hat, y)
    
    # 记录测试指标
    self.log('test_loss', test_loss)
    self.log('test_acc', self.accuracy(y_hat, y))

测试集使用:

python 复制代码
# 加载测试数据
test_dataset = MNIST(root="data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32)

# 训练完成后进行测试
trainer = pl.Trainer()
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

注意事项:

  • 测试集不应在训练过程中使用
  • 测试集仅用于评估训练完成后的模型性能
  • 确保测试集与训练集完全独立

2.5 优化器配置

configure_optimizers() 方法

配置模型的优化器和学习率调度器:

python 复制代码
def configure_optimizers(self):
    """
    配置优化器和学习率调度器
    
    Returns:
        优化器或包含优化器和调度器的字典
    """
    # 基础用法:仅返回优化器
    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    return optimizer

配置学习率调度器:

python 复制代码
def configure_optimizers(self):
    """配置优化器和学习率调度器"""
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    
    return {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'interval': 'epoch',  # 'epoch' 或 'step'
            'frequency': 1,       # 每多少个 interval 调用一次
            'monitor': 'val_loss' # 用于 ReduceLROnPlateau
        }
    }

多优化器配置(高级):

python 复制代码
def configure_optimizers(self):
    """配置多个优化器(如 GAN 训练)"""
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
    return [opt_g, opt_d], []

3. Trainer 训练器

Trainer 是 PyTorch Lightning 的训练控制中心,负责管理整个训练流程。

3.1 检查点管理

检查点的组成

PyTorch Lightning 的检查点包含以下内容:

  • 16-bit 精度缩放因子(如使用混合精度训练)
  • 当前 epoch
  • 全局 step
  • LightningModule 的 state_dict
  • 所有优化器的状态
  • 所有学习率调度器的状态
  • 所有回调的状态
  • DataModule 的状态(如果使用)
  • 超参数(模型和 DataModule 的初始化参数)
  • 训练循环的状态
自动保存检查点

Lightning 会自动保存最后一个 epoch 的检查点:

python 复制代码
# 默认保存到当前工作目录
trainer = pl.Trainer()

# 指定保存路径
trainer = pl.Trainer(default_root_dir="my_checkpoints/")
从检查点加载模型

方法 1:仅加载权重进行推理

python 复制代码
# 加载模型权重和超参数
model = MyLightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")

# 设置为评估模式
model.eval()

# 进行预测
with torch.no_grad():
    y_hat = model(x)

方法 2:恢复完整训练状态

python 复制代码
# 从检查点恢复并继续训练
model = LitModel()
trainer = pl.Trainer()

# 自动恢复模型、epoch、step、优化器等所有状态
trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")
禁用检查点
python 复制代码
# 完全禁用自动检查点保存
trainer = pl.Trainer(enable_checkpointing=False)
自定义检查点行为(ModelCheckpoint 回调)

使用 ModelCheckpoint 回调实现更精细的控制:

python 复制代码
from pytorch_lightning.callbacks import ModelCheckpoint

# 保存验证损失最好的前 3 个模型
checkpoint_callback = ModelCheckpoint(
    dirpath='my/path/',              # 保存目录
    filename='model-{epoch:02d}-{val_loss:.2f}',  # 文件名模板
    save_top_k=3,                    # 保存最好的 k 个模型
    monitor='val_loss',              # 监控的指标
    mode='min',                      # 'min' 或 'max'
    save_last=True,                  # 额外保存最后一个 epoch
    every_n_epochs=1,                # 每 n 个 epoch 检查一次
    save_weights_only=False,         # 是否仅保存权重
)

trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)

# 获取最佳模型路径
best_model_path = checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")

高级配置:

python 复制代码
checkpoint_callback = ModelCheckpoint(
    # 何时保存(When)
    every_n_train_steps=100,         # 每 N 个训练步骤保存
    every_n_epochs=5,                # 每 N 个 epoch 保存
    train_time_interval=timedelta(minutes=10),  # 每隔一定时间保存
    
    # 保存哪些(Which)
    save_top_k=5,                    # 保存最好的 k 个
    save_last=True,                  # 保存最后一个
    monitor='val_accuracy',          # 监控的指标
    mode='max',                      # 'min' 表示越小越好,'max' 表示越大越好
    
    # 保存什么(What)
    save_weights_only=True,          # 仅保存权重(节省空间)
    
    # 保存到哪里(Where)
    dirpath='checkpoints/',
    filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
)

监控自定义指标:

python 复制代码
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        # 记录自定义指标
        self.log('my_custom_metric', some_value)

# 监控自定义指标
checkpoint_callback = ModelCheckpoint(monitor='my_custom_metric', mode='max')
根据条件保存检查点
python 复制代码
# 保存最后 K 个检查点(基于 global_step)
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="global_step",
    mode="max",
    dirpath="my/path/",
    filename="model-{epoch:02d}-{global_step}",
)

注意事项:

  • 在文件名中包含监控指标,避免文件名冲突
  • 不要依赖自动版本号来检索 top-k 模型
  • 文件名示例:model-epoch=02-val_loss=0.32.ckpt
手动保存检查点
python 复制代码
# 训练过程中手动保存
trainer.save_checkpoint("manual_checkpoint.ckpt")

# 稍后加载
model = MyLightningModule.load_from_checkpoint("manual_checkpoint.ckpt")

3.2 回调机制

回调(Callback)是 Lightning 提供的一种扩展机制,可以在训练的特定阶段执行自定义操作。

回调的作用
  • 将辅助功能从核心研究代码中分离
  • 提供数十个可插拔的钩子函数
  • 可重用且易于组合
  • 不污染主要的模型逻辑
常用回调
1. EarlyStopping(早停)

在验证指标不再改善时提前停止训练:

python 复制代码
from pytorch_lightning.callbacks import EarlyStopping

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)

# 配置早停回调
early_stop_callback = EarlyStopping(
    monitor='val_loss',      # 监控的指标
    min_delta=0.00,          # 最小改善幅度
    patience=3,              # 容忍的 epoch 数
    verbose=False,           # 是否打印信息
    mode='min',              # 'min' 或 'max'
    strict=True,             # 是否在找不到指标时报错
    stopping_threshold=0.01, # 达到此阈值立即停止
    divergence_threshold=5.0,# 超过此阈值立即停止(防止发散)
    check_finite=True,       # 检查 NaN 或 Inf
    check_on_train_epoch_end=False,  # 是否在训练 epoch 结束时检查
)

trainer = pl.Trainer(callbacks=[early_stop_callback])
trainer.fit(model)

自定义早停逻辑:

python 复制代码
class MyEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        # 禁用验证结束时的早停
        pass

    def on_train_end(self, trainer, pl_module):
        # 在训练结束时执行早停检查
        self._run_early_stopping_check(trainer)

注意事项:

  • patience 计数的是验证检查次数,而非 epoch 数
  • 如果 check_val_every_n_epoch=10patience=3,则至少需要 40 个 epoch
2. ModelCheckpoint(检查点)

详见 [3.1 检查点管理](#3.1 检查点管理) 部分。

3. LearningRateMonitor(学习率监控)

自动记录学习率变化:

python 复制代码
from pytorch_lightning.callbacks import LearningRateMonitor

# 创建学习率监控器
lr_monitor = LearningRateMonitor(
    logging_interval='step',  # 'step' 或 'epoch',None 表示按调度器的 interval
    log_momentum=False,       # 是否记录动量
    log_weight_decay=False,   # 是否记录权重衰减
)

trainer = pl.Trainer(callbacks=[lr_monitor])

自定义日志名称:

python 复制代码
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
    scheduler = {
        'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
        'name': 'my_lr_scheduler_name'  # 自定义名称
    }
    return [optimizer], [scheduler]
4. ModelSummary(模型摘要)

打印模型结构信息:

python 复制代码
from pytorch_lightning.callbacks import ModelSummary

# 打印完整的模型层次结构
trainer = pl.Trainer(callbacks=[ModelSummary(max_depth=-1)])

输出示例:

复制代码
  | Name  | Type        | Params | In sizes  | Out sizes
----------------------------------------------------------------
0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K  | [10, 512] | [10, 512]
5. DeviceStatsMonitor(设备统计)

监控 GPU/CPU 使用情况:

python 复制代码
from pytorch_lightning.callbacks import DeviceStatsMonitor

trainer = pl.Trainer(callbacks=[DeviceStatsMonitor(cpu_stats=True)])
6. GradientAccumulationScheduler(梯度累积调度)

详见 [3.8 性能优化技巧](#3.8 性能优化技巧) 部分。

7. StochasticWeightAveraging(随机权重平均)

详见 [3.8 性能优化技巧](#3.8 性能优化技巧) 部分。


3.3 日志与可视化

基础日志记录

使用 self.log() 记录指标:

python 复制代码
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = ...
        # 记录单个指标
        self.log("train_loss", loss)
        
        # 记录多个指标
        self.log_dict({
            "loss": loss,
            "acc": accuracy,
            "metric_n": some_metric
        })
        
        return loss
在命令行中显示
python 复制代码
# 在进度条中显示指标
self.log("train_loss", loss, prog_bar=True)

输出示例:

复制代码
Epoch 3: 33%|███▉ | 307/938 [00:01<00:02, 289.04it/s, loss=0.198, acc=0.211]
使用 TensorBoard

启动 TensorBoard:

python 复制代码
# Lightning 默认使用 TensorBoard(如果已安装)
trainer = pl.Trainer()

# 显式指定 TensorBoard
from pytorch_lightning.loggers import TensorBoardLogger

tensorboard = TensorBoardLogger(save_dir="logs/")
trainer = pl.Trainer(logger=tensorboard)

在命令行中启动 TensorBoard:

bash 复制代码
tensorboard --logdir=lightning_logs/

在 Jupyter Notebook 中:

python 复制代码
%load_ext tensorboard
%tensorboard --logdir=lightning_logs/
记录非标量内容

记录图像、直方图、文本等:

python 复制代码
def training_step(self, batch, batch_idx):
    # 获取 TensorBoard 实验对象
    tensorboard = self.logger.experiment
    
    # 记录图像
    tensorboard.add_image('input_images', images, self.current_epoch)
    
    # 记录直方图
    tensorboard.add_histogram('layer1_weights', self.layer1.weight, self.current_epoch)
    
    # 记录梯度直方图
    for name, param in self.named_parameters():
        if param.grad is not None:
            tensorboard.add_histogram(
                f"gradients/{name}", 
                param.grad.detach().cpu(), 
                self.current_epoch
            )
    
    return loss
自定义日志行为

配置日志频率:

python 复制代码
# 每 10 步记录一次
trainer = pl.Trainer(log_every_n_steps=10)

配置 TensorBoard 刷新频率:

python 复制代码
logger = TensorBoardLogger(..., max_queue=100, flush_secs=120)
self.log 的详细配置
python 复制代码
self.log(
    name="metric_name",
    value=metric_value,
    
    # 时间维度
    on_step=True,          # 是否在每个 step 记录
    on_epoch=True,         # 是否在 epoch 结束时记录
    
    # 聚合方式
    reduce_fx=torch.mean,  # 聚合函数:mean, max, min, sum
    
    # 显示位置
    prog_bar=True,         # 是否显示在进度条
    logger=True,           # 是否发送到 logger
    
    # 分布式训练
    sync_dist=False,       # 是否在设备间同步
    sync_dist_group=None,  # DDP 同步组
    rank_zero_only=False,  # 是否仅在 rank 0 记录
    
    # 其他
    batch_size=32,         # 批次大小(用于正确累积)
    enable_graph=True,     # 是否保持计算图
)

默认值(根据调用位置不同):

python 复制代码
def training_step(self, batch, batch_idx):
    # 默认: on_step=True, on_epoch=False
    self.log("train_loss", loss)

def validation_step(self, batch, batch_idx):
    # 默认: on_step=False, on_epoch=True
    self.log("val_loss", loss)

def test_step(self, batch, batch_idx):
    # 默认: on_step=False, on_epoch=True
    self.log("test_loss", loss)
使用多个 Logger
python 复制代码
from pytorch_lightning import loggers as pl_loggers

# 同时使用 TensorBoard 和 CSV logger
tensorboard = pl_loggers.TensorBoardLogger('logs/')
csv_logger = pl_loggers.CSVLogger('logs/')

trainer = pl.Trainer(logger=[tensorboard, csv_logger])
累积指标

在验证和测试阶段,Lightning 自动累积指标并计算平均值:

python 复制代码
def validation_step(self, batch, batch_idx):
    value = ...
    # 自动累积并在 epoch 结束时计算平均值
    self.log("average_value", value)

如需其他聚合方式:

python 复制代码
self.log("max_value", value, reduce_fx='max')
self.log("min_value", value, reduce_fx='min')
self.log("sum_value", value, reduce_fx='sum')

3.4 命令行参数

使用 ArgumentParser 或 Lightning CLI 管理超参数。

使用 ArgumentParser
python 复制代码
from argparse import ArgumentParser

# 创建参数解析器
parser = ArgumentParser()

# 添加 Trainer 参数
parser.add_argument("--devices", type=int, default=2)
parser.add_argument("--max_epochs", type=int, default=10)

# 添加模型超参数
parser.add_argument("--layer_1_dim", type=int, default=128)
parser.add_argument("--learning_rate", type=float, default=1e-3)

# 解析参数
args = parser.parse_args()

# 使用解析的参数
model = MyModel(
    layer_1_dim=args.layer_1_dim,
    learning_rate=args.learning_rate
)
trainer = pl.Trainer(devices=args.devices, max_epochs=args.max_epochs)
trainer.fit(model)

命令行调用:

bash 复制代码
python train.py --layer_1_dim 256 --learning_rate 0.001 --devices 4
使用 Lightning CLI(推荐)

Lightning CLI 提供了更强大的命令行配置功能:

python 复制代码
from pytorch_lightning.cli import LightningCLI

# 简单用法
cli = LightningCLI(MyModel, MyDataModule)

命令行调用:

bash 复制代码
# 查看所有可用参数
python train.py --help

# 使用命令行参数
python train.py --model.learning_rate 0.001 --model.layer_1_dim 256 --trainer.max_epochs 50

# 使用配置文件
python train.py --config config.yaml

配置文件示例(config.yaml):

yaml 复制代码
model:
  learning_rate: 0.001
  layer_1_dim: 256
  dropout: 0.5

trainer:
  max_epochs: 50
  devices: 4
  accelerator: gpu

data:
  batch_size: 32
  num_workers: 4

3.5 模型预测

从检查点加载并预测
python 复制代码
# 加载训练好的模型
model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()

# 准备输入数据
x = torch.randn(1, 64)

# 进行预测
with torch.no_grad():
    y_hat = model(x)
使用 predict_step

定义 predict_step 方法来处理预测逻辑:

python 复制代码
class MyModel(pl.LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        return self(x)

使用 Trainer 进行预测:

python 复制代码
# 准备数据
predict_loader = DataLoader(predict_dataset, batch_size=32)

# 加载模型
model = MyModel.load_from_checkpoint("checkpoint.ckpt")

# 进行预测
trainer = pl.Trainer()
predictions = trainer.predict(model, predict_loader)
复杂预测逻辑(Monte Carlo Dropout)
python 复制代码
class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration=10):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # 启用 Dropout(Monte Carlo Dropout)
        self.dropout.train()
        
        # 进行多次预测并取平均
        predictions = []
        for _ in range(self.mc_iteration):
            pred = self.dropout(self.model(batch))
            predictions.append(pred.unsqueeze(0))
        
        # 计算平均预测
        pred = torch.vstack(predictions).mean(dim=0)
        return pred
分布式预测并保存结果

使用 BasePredictionWriter 在分布式环境中保存预测结果:

python 复制代码
from pytorch_lightning.callbacks import BasePredictionWriter

class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir, write_interval='epoch'):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # 每个进程保存自己的预测结果
        torch.save(
            predictions, 
            os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt")
        )
        
        # 可选:保存批次索引
        torch.save(
            batch_indices, 
            os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt")
        )

# 使用自定义写入器
pred_writer = CustomWriter(output_dir="predictions/", write_interval="epoch")
trainer = pl.Trainer(
    accelerator="gpu",
    strategy="ddp",
    devices=8,
    callbacks=[pred_writer]
)
model = MyModel()
trainer.predict(model, return_predictions=False)

3.6 GPU 训练

基础 GPU 训练
python 复制代码
# 自动使用所有可用 GPU
trainer = pl.Trainer(accelerator="auto", devices="auto")

# 等价于
trainer = pl.Trainer()

# 使用单个 GPU
trainer = pl.Trainer(accelerator="gpu", devices=1)

# 使用多个 GPU
trainer = pl.Trainer(accelerator="gpu", devices=8)

注意:

  • accelerator="gpu" 也会自动选择 Apple Silicon 的 MPS 设备
  • 如要避免使用 MPS,可设置 accelerator="cuda"
选择特定 GPU
python 复制代码
# 使用前 k 个 GPU
trainer = pl.Trainer(accelerator="gpu", devices=k)

# 等价于
trainer = pl.Trainer(accelerator="gpu", devices=list(range(k)))

# 指定特定 GPU(不推荐在集群上使用)
trainer = pl.Trainer(accelerator="gpu", devices=[0, 1])

# 使用字符串形式
trainer = pl.Trainer(accelerator="gpu", devices="0, 1")

# 使用所有 GPU
trainer = pl.Trainer(accelerator="gpu", devices=-1)

devices 参数解释:

devices 值 类型 解析结果 含义
3 int [0, 1, 2] 前 3 个 GPU
-1 int [0, 1, 2, ...] 所有可用 GPU
[0] list [0] GPU 0
[1, 3] list [1, 3] GPU 索引 1 和 3
"3" str [0, 1, 2] 前 3 个 GPU
"1, 3" str [1, 3] GPU 索引 1 和 3
"-1" str [0, 1, 2, ...] 所有可用 GPU
自动查找可用 GPU

在多任务场景下(如超参数搜索),自动查找未被占用的 GPU:

python 复制代码
from pytorch_lightning.accelerators import find_usable_cuda_devices

# 查找 2 个未被占用的 GPU
trainer = pl.Trainer(
    accelerator="cuda",
    devices=find_usable_cuda_devices(2)
)

这在 GPU 设置为"独占计算模式"时特别有用。


3.7 模型调试

设置断点
python 复制代码
def function_to_debug():
    x = 2
    
    # 设置断点
    import pdb
    pdb.set_trace()
    
    y = x ** 2  # 代码将在此处暂停
快速运行模式(fast_dev_run)

快速运行少量批次以检查代码是否有错误:

python 复制代码
# 运行 5 个批次的训练、验证、测试
trainer = pl.Trainer(fast_dev_run=True)

# 自定义批次数量
trainer = pl.Trainer(fast_dev_run=7)

注意:

  • 该模式会禁用 checkpoint、early stopping、logger 等功能
  • 适合快速验证代码逻辑
缩短 epoch 长度

使用部分数据进行调试:

python 复制代码
# 使用 10% 的训练数据和 1% 的验证数据
trainer = pl.Trainer(limit_train_batches=0.1, limit_val_batches=0.01)

# 使用固定批次数
trainer = pl.Trainer(limit_train_batches=10, limit_val_batches=5)
Sanity Check(健全性检查)

训练开始前运行少量验证步骤,避免深度训练后才发现验证错误:

python 复制代码
# 默认运行 2 步验证
trainer = pl.Trainer(num_sanity_val_steps=2)

# 禁用健全性检查
trainer = pl.Trainer(num_sanity_val_steps=0)
打印模型摘要
python 复制代码
# 训练时自动打印模型摘要
trainer.fit(model)

# 输出示例:
#   | Name  | Type        | Params | Mode
# -------------------------------------------
# 0 | net   | Sequential  | 132 K  | train
# 1 | net.0 | Linear      | 131 K  | train
# 2 | net.1 | BatchNorm1d | 1.0 K  | train

打印更详细的摘要:

python 复制代码
from pytorch_lightning.callbacks import ModelSummary

trainer = pl.Trainer(callbacks=[ModelSummary(max_depth=-1)])

手动打印摘要:

python 复制代码
from pytorch_lightning.utilities.model_summary import ModelSummary

model = LitModel()
summary = ModelSummary(model, max_depth=-1)
print(summary)

禁用自动摘要:

python 复制代码
trainer = pl.Trainer(enable_model_summary=False)
显示输入输出维度

设置 example_input_array 以显示层的输入输出维度:

python 复制代码
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.example_input_array = torch.Tensor(32, 1, 28, 28)
        # ...

输出示例:

复制代码
  | Name  | Type        | Params | In sizes  | Out sizes
----------------------------------------------------------------------
0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K  | [10, 512] | [10, 512]

3.8 性能优化技巧

混合精度训练(N-bit Precision)

使用低精度浮点数可以加速训练并减少内存使用:

python 复制代码
# 使用 16 位混合精度
trainer = pl.Trainer(precision='16-mixed')

# 使用 bf16(bfloat16)混合精度
trainer = pl.Trainer(precision='bf16-mixed')

# 使用 64 位精度(更高精度)
trainer = pl.Trainer(precision=64)

优势:

  • 减少内存占用,可训练更大的模型
  • 加快训练速度
  • 降低对硬件的要求
梯度累积(Accumulate Gradients)

通过累积多个小批次的梯度来模拟大批次训练:

python 复制代码
# 默认不累积(每个批次更新一次)
trainer = pl.Trainer(accumulate_grad_batches=1)

# 累积 7 个批次的梯度
trainer = pl.Trainer(accumulate_grad_batches=7)

效果:

  • 有效批次大小 = batch_size × accumulate_grad_batches
  • 例如:batch_size=32accumulate_grad_batches=7,有效批次大小 = 224

动态调整累积批次:

python 复制代码
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# 前 5 个 epoch 累积 8 个批次
# 第 5-9 epoch 累积 4 个批次
# 第 9 epoch 后不累积
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = pl.Trainer(callbacks=[accumulator])

注意事项:

  • 在 DDP 下,每个设备独立累积梯度
  • 有效批次大小 = num_devices × batch_size × accumulate_grad_batches
梯度裁剪(Gradient Clipping)

防止梯度爆炸:

python 复制代码
# 默认不裁剪
trainer = pl.Trainer(gradient_clip_val=0)

# 裁剪梯度范数到 <= 0.5
trainer = pl.Trainer(gradient_clip_val=0.5)

# 裁剪梯度值到 <= 0.5(而非范数)
trainer = pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")

注意:

  • 使用混合精度时,梯度会在裁剪前自动缩放回 fp32
随机权重平均(Stochastic Weight Averaging)

通过平均多个训练步骤的权重来提高模型泛化能力:

python 复制代码
from pytorch_lightning.callbacks import StochasticWeightAveraging

trainer = pl.Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])

优势:

  • 几乎无额外成本
  • 提高模型泛化性能
  • 平滑损失函数
自动批次大小查找

自动找到能放入内存的最大批次大小:

python 复制代码
from pytorch_lightning.tuner import Tuner

trainer = pl.Trainer()
tuner = Tuner(trainer)

# 指数增长搜索(默认)
tuner.scale_batch_size(model, mode="power")

# 二分搜索
tuner.scale_batch_size(model, mode="binsearch")

# 然后正常训练
trainer.fit(model)

前提条件:

  • 模型需要有 batch_size 属性或在 hparams 中定义
  • train_dataloader() 方法需要使用该属性
python 复制代码
class LitModel(pl.LightningModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.save_hyperparameters()
    
    def train_dataloader(self):
        return DataLoader(dataset, batch_size=self.hparams.batch_size)
自动学习率查找

自动找到最优初始学习率:

python 复制代码
from pytorch_lightning.tuner import Tuner

model = MyModel()
trainer = pl.Trainer()
tuner = Tuner(trainer)

# 运行学习率查找
lr_finder = tuner.lr_find(model)

# 查看结果
print(lr_finder.results)

# 绘制曲线
fig = lr_finder.plot(suggest=True)
fig.show()

# 获取建议的学习率
new_lr = lr_finder.suggestion()

# 更新模型
model.hparams.lr = new_lr

# 开始训练
trainer.fit(model)

学习率查找原理:

  • 从小学习率开始,逐步增加
  • 记录每个学习率对应的损失
  • 找到损失下降最快的区域
  • 建议选择该区域的中点(而非最低点)

自定义学习率查找:

python 复制代码
from pytorch_lightning.callbacks import LearningRateFinder

class FineTuneLearningRateFinder(LearningRateFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        return

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.lr_find(trainer, pl_module)

# 在第 0、5、10 epoch 时运行学习率查找
trainer = pl.Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])

4. DataModule 数据模块

4.1 DataModule 介绍

DataModule 是 PyTorch Lightning 提供的数据管理抽象,封装了数据处理的五个步骤:

  1. 下载/标记/处理数据
  2. 清理数据并保存到磁盘
  3. 加载 数据到 Dataset
  4. 应用数据变换(旋转、标记化等)
  5. 包装DataLoader
为什么使用 DataModule?

解决的问题:

  • 数据准备代码通常分散在多个文件中
  • 难以共享和复用数据集的划分和变换
  • 无法确保数据处理的一致性

DataModule 的优势:

  • 将所有数据相关逻辑封装在一起
  • 易于在不同项目间共享
  • 便于切换不同数据集
  • 确保数据处理的可复现性
基础示例对比

传统 PyTorch 代码:

python 复制代码
# 数据准备代码分散
test_data = MNIST(path, train=False, download=True)
predict_data = MNIST(path, train=False, download=True)
train_data = MNIST(path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)

使用 DataModule:

python 复制代码
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "data/", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        # 加载和划分数据
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], 
            generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

使用方式:

python 复制代码
# 创建 DataModule 和模型
mnist = MNISTDataModule(data_dir="data/", batch_size=64)
model = LitClassifier()

# 训练
trainer = pl.Trainer()
trainer.fit(model, datamodule=mnist)

# DataModule 可复用于不同数据集
cifar10 = CIFAR10DataModule()
trainer.fit(model, datamodule=cifar10)

4.2 核心方法

prepare_data()

作用: 下载、标记化等一次性操作(通常只在单个进程中执行)

特点:

  • 仅在单个进程上调用(避免多进程下载冲突)
  • 在 CPU 上运行
  • 不应在此分配状态(self.x = y

使用场景:

  • 下载数据集
  • 标记化文本
  • 生成词汇表
  • 预处理并保存到磁盘
python 复制代码
class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # 仅下载一次,避免多进程重复下载
        MNIST(self.data_dir, train=True, download=True, transform=transforms.ToTensor())
        MNIST(self.data_dir, train=False, download=True, transform=transforms.ToTensor())

多节点训练:

  • 默认情况下,每个节点的 rank 0 进程都会调用 prepare_data()
  • 可通过 prepare_data_per_node 控制行为:
python 复制代码
class LitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        # True: 每个节点的 LOCAL_RANK=0 调用
        # False: 仅 NODE_RANK=0, LOCAL_RANK=0 调用
        self.prepare_data_per_node = True
setup(stage)

作用: 在每个进程上执行的数据准备操作

特点:

  • 在所有进程上调用
  • 可以分配状态(self.x = y
  • 接收 stage 参数区分训练/验证/测试阶段

使用场景:

  • 统计类别数量
  • 构建词汇表
  • 划分训练/验证/测试集
  • 创建 Dataset
  • 应用数据变换
python 复制代码
class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage: str):
        # stage 可以是 'fit', 'validate', 'test', 'predict'
        
        if stage == "fit":
            # 训练和验证阶段
            mnist_full = MNIST(
                self.data_dir, 
                train=True, 
                transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000],
                generator=torch.Generator().manual_seed(42)
            )
            
            # 可以在这里统计信息
            self.num_classes = 10
            self.dims = mnist_full[0][0].shape
        
        if stage == "test":
            # 测试阶段
            self.mnist_test = MNIST(
                self.data_dir, 
                train=False, 
                transform=self.transform
            )
        
        if stage == "predict":
            # 预测阶段
            self.mnist_predict = MNIST(
                self.data_dir, 
                train=False, 
                transform=self.transform
            )

NLP 示例(标记化):

python 复制代码
class TextDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # 下载和标记化(仅一次)
        dataset = load_dataset(...)
        tokenized = tokenize(dataset)
        save_to_disk(tokenized, "processed/")
    
    def setup(self, stage):
        # 每个进程加载预处理的数据
        self.dataset = load_dataset_from_disk("processed/")
        
        if stage == "fit":
            self.train_data = self.dataset['train']
            self.val_data = self.dataset['validation']
train_dataloader()

返回训练数据加载器:

python 复制代码
def train_dataloader(self):
    return DataLoader(
        self.mnist_train,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
val_dataloader()

返回验证数据加载器:

python 复制代码
def val_dataloader(self):
    return DataLoader(
        self.mnist_val,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=4
    )
test_dataloader()

返回测试数据加载器:

python 复制代码
def test_dataloader(self):
    return DataLoader(
        self.mnist_test,
        batch_size=self.batch_size,
        shuffle=False
    )
predict_dataloader()

返回预测数据加载器:

python 复制代码
def predict_dataloader(self):
    return DataLoader(
        self.mnist_predict,
        batch_size=self.batch_size
    )
teardown(stage)

作用: 在训练/测试结束时执行清理工作

python 复制代码
def teardown(self, stage: str):
    # 清理资源
    if stage == "fit":
        # 训练结束时的清理
        del self.train_data
        del self.val_data
    
    if stage == "test":
        # 测试结束时的清理
        del self.test_data
state_dict() 和 load_state_dict()

保存和恢复 DataModule 状态:

python 复制代码
class LitDataModule(pl.LightningDataModule):
    def state_dict(self):
        # 保存需要持久化的状态
        state = {
            "current_train_batch_index": self.current_train_batch_index,
            "random_state": self.rng.getstate()
        }
        return state

    def load_state_dict(self, state_dict):
        # 从检查点恢复状态
        self.current_train_batch_index = state_dict["current_train_batch_index"]
        self.rng.setstate(state_dict["random_state"])
数据传输钩子(高级)
transfer_batch_to_device()

自定义如何将批次数据传输到设备:

python 复制代码
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # 移动自定义数据结构中的张量
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
        return batch
    else:
        return super().transfer_batch_to_device(batch, device, dataloader_idx)
on_before_batch_transfer()

在数据传输到设备之前进行数据增强:

python 复制代码
def on_before_batch_transfer(self, batch, dataloader_idx):
    # 在 CPU 上应用数据增强
    batch['x'] = cpu_transforms(batch['x'])
    return batch
on_after_batch_transfer()

在数据传输到设备之后进行数据增强:

python 复制代码
def on_after_batch_transfer(self, batch, dataloader_idx):
    # 在 GPU 上应用数据增强
    batch['x'] = gpu_transforms(batch['x'])
    return batch

4.3 使用方式

基本使用
python 复制代码
# 创建 DataModule
dm = MNISTDataModule(data_dir="data/", batch_size=32)

# 创建模型
model = MyModel()

# 训练
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

# 测试
trainer.test(datamodule=dm)

# 验证
trainer.validate(datamodule=dm)

# 预测
trainer.predict(datamodule=dm)
手动调用 DataModule 方法

如果需要在构建模型前获取数据集信息:

python 复制代码
# 手动准备数据
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")

# 使用数据集信息构建模型
model = MyModel(
    num_classes=dm.num_classes,
    input_shape=dm.dims
)

# 训练
trainer.fit(model, datamodule=dm)

# 测试
dm.setup(stage="test")
trainer.test(datamodule=dm)
在纯 PyTorch 中使用

DataModule 也可以在纯 PyTorch 代码中使用:

python 复制代码
# 准备数据
dm = MNISTDataModule()
dm.prepare_data()

# 设置训练数据
dm.setup(stage="fit")
for batch in dm.train_dataloader():
    # 训练代码
    pass

for batch in dm.val_dataloader():
    # 验证代码
    pass

dm.teardown(stage="fit")

# 设置测试数据
dm.setup(stage="test")
for batch in dm.test_dataloader():
    # 测试代码
    pass

dm.teardown(stage="test")
保存和恢复超参数

与 LightningModule 类似,DataModule 也支持超参数管理:

python 复制代码
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, num_workers=4):
        super().__init__()
        # 保存所有超参数
        self.save_hyperparameters()
    
    def train_dataloader(self):
        # 使用保存的超参数
        return DataLoader(
            self.train_data,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers
        )
访问当前使用的 DataModule
python 复制代码
# 在训练过程中访问 DataModule
current_dm = trainer.datamodule

# 访问当前的 DataLoader
train_loader = trainer.train_dataloader
val_loaders = trainer.val_dataloaders
test_loaders = trainer.test_dataloaders
predict_loaders = trainer.predict_dataloaders

完整示例

端到端训练示例

python 复制代码
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F

# 1. 定义 DataModule
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="data/", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000],
                generator=torch.Generator().manual_seed(42)
            )

        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

# 2. 定义 LightningModule
class LitMNIST(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.layer_1 = nn.Linear(28 * 28, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', val_loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        test_loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('test_loss', test_loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'
            }
        }

# 3. 配置训练
if __name__ == "__main__":
    # 创建 DataModule 和模型
    dm = MNISTDataModule(batch_size=64)
    model = LitMNIST(hidden_dim=256, learning_rate=1e-3)
    
    # 配置回调
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints/',
        filename='mnist-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        monitor='val_loss',
        mode='min'
    )
    
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=3,
        mode='min'
    )
    
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    
    # 创建 Trainer
    trainer = pl.Trainer(
        max_epochs=50,
        accelerator='gpu',
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
        log_every_n_steps=10,
    )
    
    # 训练
    trainer.fit(model, datamodule=dm)
    
    # 测试
    trainer.test(model, datamodule=dm, ckpt_path='best')

总结

PyTorch Lightning 通过三大核心组件提供了完整的深度学习训练解决方案:

  1. LightningModule:封装模型定义、训练逻辑和优化器配置
  2. Trainer:自动化训练流程,提供丰富的回调和优化选项
  3. DataModule:组织数据处理流程,确保数据处理的可复现性
相关推荐
开开心心_Every2 小时前
安卓做菜APP:家常菜谱详细步骤无广简洁
服务器·前端·python·学习·edge·django·powerpoint
SiYuanFeng2 小时前
pytorch常用张量构造词句表和nn.组件速查表
人工智能·pytorch·python
MistaCloud2 小时前
Pytorch深入浅出(十四)之完整的模型训练测试套路
人工智能·pytorch·python·深度学习
知乎的哥廷根数学学派2 小时前
基于物理信息嵌入与多维度约束的深度学习地基承载力智能预测与可解释性评估算法(以模拟信号为例,Pytorch)
人工智能·pytorch·python·深度学习·算法·机器学习
WLJT1231231232 小时前
电子元器件:智能时代的核心基石
大数据·人工智能·科技·安全·生活
RockHopper20252 小时前
约束的力量:从生物认知到人工智能的跨越
人工智能·具身智能·具身认知
未来之窗软件服务2 小时前
幽冥大陆(九十六)分词服务训练 —东方仙盟练气期
人工智能·仙盟创梦ide·东方仙盟
雪域迷影2 小时前
Python中连接Redis数据库并存储数据
redis·python
rgeshfgreh2 小时前
Python正则与模式匹配实战技巧
大数据·人工智能