【PyTorch Lightning】.ckpt 是什么?里面有什么?

  1. 什么是检查点(checkpoint, ckpt)?

当模型在训练过程中时,随着其不断接收更多数据,其性能也会发生变化。在训练过程中保存模型的状态是一种最佳实践。这样可以在开发模型的过程中,在每个关键点上获得模型的一个版本,即一个检查点。一旦训练完成,您可以使用在训练过程中找到的性能最佳的检查点。

检查点还使得训练在中断的情况下可以从中断的地方恢复。

PyTorch Lightning 检查点在普通的 PyTorch 中完全可用。

  1. .ckpt 检查点文件里面有什么?

一个 Lightning 检查点包含了模型的整个内部状态的转储。与普通的 PyTorch 不同,Lightning 保存了你在最复杂的分布式训练环境中恢复模型所需的一切。

在 Lightning 检查点中,您会找到:

  • 16 位精度训练的缩放因子(如果使用 16 位精度训练)
  • 当前的 epoch
  • 全局步数
  • LightningModule 的 state_dict
  • 所有优化器的状态
  • 所有学习率调度器的状态
  • 所有回调函数的状态(用于有状态回调函数)
  • 数据模块的状态(用于有状态数据模块)
  • 用于创建模型的超参数(初始参数)
  • 用于创建数据模块的超参数(初始参数)
  • 循环的状态
  1. state_dict 是什么?

nn.Module 的模型权重,具体使用方法如下。

Lightning checkpoints 完全兼容普通的 torch nn.Modules。

python 复制代码
checkpoint = torch.load(CKPT_PATH)
print(checkpoint.keys())

例如,假设像下面这样创建了一个 LightningModule:

python 复制代码
class Encoder(nn.Module):
    ...


class Decoder(nn.Module):
    ...


class Autoencoder(L.LightningModule):
    def __init__(self, encoder, decoder, *args, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder


autoencoder = Autoencoder(Encoder(), Decoder())

一旦autoencoder训练完成,就可以提取出与 torch nn.Module 相关的权重。

python 复制代码
checkpoint = torch.load(CKPT_PATH)
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}

官方文档:https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html

相关推荐
Q_Q511008285几秒前
python+uniapp基于微信小程序的旅游信息系统
spring boot·python·微信小程序·django·flask·uni-app·node.js
伏小白白白2 分钟前
【论文精度-2】求解车辆路径问题的神经组合优化算法:综合展望(Yubin Xiao,2025)
人工智能·算法·机器学习
鄃鳕3 分钟前
python迭代器解包【python】
开发语言·python
应用市场5 分钟前
OpenCV编程入门:从零开始的计算机视觉之旅
人工智能·opencv·计算机视觉
星域智链23 分钟前
宠物智能用品:当毛孩子遇上 AI,是便利还是过度?
人工智能·科技·学习·宠物
taxunjishu40 分钟前
DeviceNet 转 MODBUS TCP罗克韦尔 ControlLogix PLC 与上位机在汽车零部件涂装生产线漆膜厚度精准控制的通讯配置案例
人工智能·区块链·工业物联网·工业自动化·总线协议
懷淰メ43 分钟前
python3GUI--模仿百度网盘的本地文件管理器 By:PyQt5(详细分享)
开发语言·python·pyqt·文件管理·百度云·百度网盘·ui设计
Q_Q51100828544 分钟前
python基于web的汽车班车车票管理系统/火车票预订系统/高铁预定系统 可在线选座
spring boot·python·django·flask·node.js·汽车·php
新子y1 小时前
【小白笔记】普通二叉树(General Binary Tree)和二叉搜索树的最近公共祖先(LCA)
开发语言·笔记·python
说私域1 小时前
基于多模态AI技术的传统行业智能化升级路径研究——以开源AI大模型、AI智能名片与S2B2C商城小程序为例
人工智能·小程序·开源