在 PyTorch 中,保存和加载模型是训练流程中的关键环节。主要涉及以下几种文件类型和概念:
1. .pt
/ .pth
文件 (最常见)
-
本质: 这些是 PyTorch 使用 Python 的
pickle
模块序列化 Python 对象后保存的文件。扩展名.pt
和.pth
是约定俗成的,PyTorch 本身没有强制要求,但强烈推荐使用它们。 -
保存内容:
-
模型的状态字典 (
state_dict
): 这是最推荐 的保存方式。state_dict
是一个 Python 字典对象,它将模型的每一层映射到其可学习参数(权重weight
、偏置bias
等)的 Tensor。它不包含模型的结构定义(类代码)。torch.save(model.state_dict(), 'model_state_dict.pt')
-
整个模型: 可以直接保存整个模型对象(包括结构和参数)。这种方式不推荐,因为它依赖于特定的 Python 环境、类定义和文件路径,导致代码难以移植且可能在不同环境或 PyTorch 版本中出错。
torch.save(model, 'entire_model.pt') # 不推荐
-
-
加载方式:
-
加载
state_dict
: 需要先实例化模型结构(类),然后将state_dict
加载到该实例中。model = MyModelClass(*args, **kwargs) # 1. 创建相同结构的模型实例 model.load_state_dict(torch.load('model_state_dict.pt')) # 2. 加载参数 model.eval() # 3. 设置为评估模式(影响 dropout, batchnorm 等层)
-
加载整个模型: 直接加载即可,但存在上述限制。
model = torch.load('entire_model.pt') # 不推荐 model.eval()
-
-
优点 (
state_dict
方式):-
文件较小(只保存参数)。
-
代码更灵活、可移植。模型类定义可以独立修改(只要结构匹配),方便在不同项目或脚本间共享参数。
-
是保存和加载模型的标准做法。
-
-
缺点 (整个模型方式):
-
文件较大(包含结构信息)。
-
严重依赖保存时的具体环境(类定义、导入路径等),难以复用。
-
在不同 PyTorch 版本间可能不兼容。
-
-
警告:
pickle
模块可能存在安全风险。只加载你信任的来源的.pt
/.pth
文件!
2. .zip
文件 (TorchScript)
-
本质: 当使用
torch.jit.save()
保存 TorchScript 模型时,默认生成一个.zip
文件(虽然也可以指定.pt
或.pth
,但.zip
是标准输出)。 -
保存内容: TorchScript 是一种 PyTorch 模型的表示形式,它可以在脱离 Python 环境 的情况下被高性能 C++ 运行时(
torch::jit::load
)或 Python 运行时(torch.jit.load
)加载和执行。它包含了模型的结构(计算图) 和参数。 -
生成方式:
-
追踪 (
torch.jit.trace
): 用一个示例输入"运行"模型,记录执行的操作。适用于没有控制流(如 if/for)的模型。example_input = torch.rand(1, 3, 224, 224) traced_model = torch.jit.trace(model, example_input) torch.jit.save(traced_model, 'traced_model.zip')
-
脚本化 (
torch.jit.script
): 直接解析模型代码(或部分代码)生成 TorchScript。适用于包含控制流的模型。scripted_model = torch.jit.script(model) torch.jit.save(scripted_model, 'scripted_model.zip')
-
-
加载方式:
# Python 中加载 loaded_model = torch.jit.load('traced_model.zip') loaded_model.eval() output = loaded_model(torch.rand(1, 3, 224, 224)) // C++ 中加载 (示例) #include <torch/script.h> torch::jit::script::Module module; module = torch::jit::load("traced_model.zip"); module.eval(); std::vector<torch::jit::IValue> inputs; inputs.push_back(torch::ones({1, 3, 224, 224})); at::Tensor output = module.forward(inputs).toTensor();
-
优点:
-
跨平台/语言: 可以在 Python 和 C++ 中加载运行。
-
独立于 Python: 不需要原始的 Python 模型类定义即可运行(在 C++ 中尤其重要)。
-
序列化优化: 针对部署进行了优化。
-
模型保护: 原始 Python 代码不是必需的(虽然可以反编译,但增加了难度)。
-
-
缺点:
-
生成过程可能复杂(尤其对于动态模型)。
-
可能不完全支持所有 Python 特性(需要调整模型代码)。
-
调试 TorchScript 有时比调试纯 Python 模型困难。
-
3. .onnx
文件 (ONNX 格式)
-
本质: Open Neural Network Exchange (ONNX) 是一种开放标准的格式,用于表示深度学习模型。它定义了一个通用的计算图表示。
-
保存内容: 模型的结构(计算图) 和参数。
-
生成方式: 使用 PyTorch 的
torch.onnx.export
函数将 PyTorch 模型转换为 ONNX 格式。torch.onnx.export(model, # 要转换的模型 torch.rand(1, 3, 224, 224), # 示例输入 "model.onnx", # 输出文件名 input_names=["input"], # 输入节点名称 output_names=["output"], # 输出节点名称 opset_version=11) # ONNX 算子集版本
-
加载方式: ONNX 模型本身不能直接在 PyTorch 中运行(除非使用 ONNX Runtime 的 PyTorch 绑定)。它主要用于:
-
导入到其他支持 ONNX 的深度学习框架(如 TensorFlow, MXNet, Caffe2)。
-
使用专门的 ONNX 运行时进行推理(如 ONNX Runtime, TensorRT),这些运行时通常针对不同硬件做了高度优化。
-
使用工具进行模型可视化、优化或格式转换。
-
-
优点:
-
框架互操作性: 实现不同深度学习框架之间模型的转换和共享。
-
硬件供应商支持: 许多硬件加速器(如 NVIDIA TensorRT, Intel OpenVINO)优先支持或优化 ONNX 模型。
-
标准化: 统一的模型表示格式。
-
-
缺点:
-
转换过程可能存在精度损失或算子不支持的问题(需要检查转换日志)。
-
ONNX 标准本身在不断发展,不同版本间可能有兼容性问题。
-
在 PyTorch 中不能直接加载运行 ONNX 模型进行训练或微调(主要用于推理或迁移到其他框架)。
-
总结与推荐
-
日常训练/研究 (PyTorch 环境内):
-
保存: 使用
torch.save(model.state_dict(), 'model.pt')
。这是最标准、最灵活的方式。 -
加载: 实例化模型结构 +
model.load_state_dict(torch.load('model.pt'))
+model.eval()
。
-
-
生产部署 (脱离 Python 或 C++ 环境):
-
首选: TorchScript (
.zip
或.pt
)。它是 PyTorch 官方的部署方案,支持 Python 和 C++,优化良好。 -
备选/互操作: ONNX (
.onnx
)。当目标平台(如特定硬件加速器)或框架(如 TensorFlow Serving)对 ONNX 有更好支持时使用。需要额外的运行时(ONNX Runtime, TensorRT 等)。
-
-
避免: 保存整个模型对象 (
torch.save(model, ...)
),除非有非常特殊且理解其风险的原因。
选择哪种格式取决于你的具体需求:在 PyTorch 内部继续工作就用 state_dict
;需要部署到非 Python 环境或 C++ 就用 TorchScript;需要与其他框架或特定硬件加速器交互就用 ONNX。