lightning的hook顺序

结果

setup: 训练循环开始前设置数据加载器和模型。

configure_optimizers: 设置优化器和学习率调度器。

on_fit_start: 训练过程开始。

on_train_start: 训练开始。

on_train_epoch_start: 每个训练周期开始。

on_train_batch_start: 每个训练批次开始。

on_before_backward: 反向传播之前。

on_after_backward: 反向传播之后。

on_before_zero_grad: 清空梯度之前。

on_after_zero_grad: 清空梯度之后。

on_before_optimizer_step: 优化器步骤之前。

on_train_batch_end: 每个训练批次结束。

on_train_epoch_end: 每个训练周期结束。

on_train_end: 训练结束。

on_fit_end: 训练过程结束。

测试代码

py 复制代码
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.callbacks import Callback

# 定义一个简单的线性回归模型
class LinearRegression(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss
    
    def on_after_backward(self, *args, **kwargs):
        print("After backward is called!", args, kwargs)
        return super().on_after_backward(*args, **kwargs)
    
    def on_before_zero_grad(self, *args, **kwargs):
        print("Before zero grad is called!", args, kwargs)
        return super().on_before_zero_grad(*args, **kwargs)
    
    def on_after_zero_grad(self, *args, **kwargs):
        print("After zero grad is called!", args, kwargs)
        return super().on_after_zero_grad(*args, **kwargs)
    
    def on_before_backward(self, *args, **kwargs):
        print("Before backward is called!", args, kwargs)
        return super().on_before_backward(*args, **kwargs)
    
    def on_before_optimizer_step(self, *args, **kwargs):
        print("Before optimizer step is called!", args, kwargs)
        return super().on_before_optimizer_step(*args, **kwargs)
    
    def on_after_optimizer_step(self, *args, **kwargs):
        print("After optimizer step is called!", args, kwargs)
        return super().on_after_optimizer_step(*args, **kwargs)
    
    def on_fit_start(self, *args, **kwargs):
        print("Fit is starting!", args, kwargs)
        return super().on_fit_start(*args, **kwargs)
    
    def on_fit_end(self, *args, **kwargs):
        print("Fit is ending!", args, kwargs)
        return super().on_fit_end(*args, **kwargs)
    
    def setup(self, *args, **kwargs):
        print("Setup is called!", args, kwargs)
        return super().setup(*args, **kwargs)
    
    def configure_optimizers(self, *args, **kwargs):
        print("Configure Optimizers is called!", args, kwargs)
        return super().configure_optimizers(*args, **kwargs)
    
    def on_train_start(self, *args, **kwargs):
        print("Training is starting!", args, kwargs)
        return super().on_train_start(*args, **kwargs)
    
    def on_train_end(self, *args, **kwargs):
        print("Training is ending!", args, kwargs)
        return super().on_train_end(*args, **kwargs)
    
    def on_train_batch_start(self, *args, **kwargs):
        print(f"Training batch is starting!", args, kwargs)
        return super().on_train_batch_start(*args, **kwargs)
    
    def on_train_batch_end(self, *args, **kwargs):
        print(f"Training batch is ending!", args, kwargs)
        return super().on_train_batch_end(*args, **kwargs)
    
    def on_train_epoch_start(self, *args, **kwargs):
        print(f"Training epoch is starting!", args, kwargs)
        return super().on_train_epoch_start(*args, **kwargs)
    
    def on_train_epoch_end(self, *args, **kwargs):
        print(f"Training epoch is ending!", args, kwargs)
        return super().on_train_epoch_end(*args, **kwargs)
    
    
# 创建数据集
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float)
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float)
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=2)

# 创建模型和训练器
model = LinearRegression()
trainer = Trainer(max_epochs=2)

# 开始训练
trainer.fit(model, train_loader)
相关推荐
zhaotiannuo_19982 分钟前
Python之2.7.9-3.9.1-3.14.2共存
开发语言·python
Keep_Trying_Go7 分钟前
基于GAN的文生图算法详解ControlGAN(Controllable Text-to-Image Generation)
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·文生图
Spey_Events10 分钟前
星箭聚力启盛会,2026第二届商业航天产业发展大会暨商业航天展即将开幕!
大数据·人工智能
JoySSLLian13 分钟前
IP SSL证书:一键解锁IP通信安全,高效抵御网络威胁!
网络·人工智能·网络协议·tcp/ip·ssl
AC赳赳老秦25 分钟前
专利附图说明:DeepSeek生成的专业技术描述与权利要求书细化
大数据·人工智能·kafka·区块链·数据库开发·数据库架构·deepseek
LostSpeed29 分钟前
openpnp - python2.7 script - 中文显示乱码,只能显示英文
python·openpnp
小雨青年36 分钟前
鸿蒙 HarmonyOS 6 | AI Kit 集成 Core Speech Kit 语音服务
人工智能·华为·harmonyos
懒羊羊吃辣条37 分钟前
电力负荷预测怎么做才不翻车
人工智能·深度学习·机器学习·时间序列
hhy_smile43 分钟前
Class in Python
java·前端·python
前进的程序员1 小时前
2026年IT行业技术发展前瞻性见解
人工智能