常见的神经网络权重文件格式及其详细说明的表格:
扩展名 | 所属框架/工具 | 如何生成 | 表示内容 | 使用方法 | 注意事项 |
---|---|---|---|---|---|
.pt , .pth |
PyTorch | torch.save(model.state_dict(), "model.pt") |
PyTorch模型的状态字典(权重和参数)或整个模型 | 加载方式:model.load_state_dict(torch.load("model.pt")) |
如果保存整个模型(含结构和权重),可能导致跨设备加载问题。pth 通常用于旧版本。 |
.h5 , .hdf5 |
Keras/TensorFlow | model.save("model.h5") |
完整的Keras模型(含结构、权重和优化器状态) | 加载方式:keras.models.load_model("model.h5") |
HDF5格式依赖h5py 库;部分自定义层可能需要手动定义。 |
.pkl , .pickle |
Python通用(如scikit-learn) | pickle.dump(model, open("model.pkl", "wb")) |
序列化的Python对象(如模型参数或全量模型) | 加载方式:model = pickle.load(open("model.pkl", "rb")) |
存在安全风险(反序列化恶意代码);建议仅在可信来源使用。 |
.ckpt |
TensorFlow | 使用tf.train.Checkpoint 或ModelCheckpoint 回调 |
TensorFlow检查点文件(模型参数、优化器状态等) | 恢复训练:model.load_weights("model.ckpt") |
检查点文件包含多个文件(如.index 、.data-xxx 等),需一起保留。 |
.pb |
TensorFlow(SavedModel) | tf.saved_model.save(model, "model_dir") |
TensorFlow的计算图结构和权重(Protocol Buffers格式) | 加载方式:tf.keras.models.load_model("model_dir") 或 tf.saved_model.load("model_dir") |
跨语言兼容性好,支持C++/Java等语言调用。 |
.onnx |
ONNX(跨框架标准) | torch.onnx.export(model, input, "model.onnx") |
跨框架的标准化模型(含结构和权重) | 加载工具:ONNX Runtime(ort.InferenceSession("model.onnx") )或其他框架的转换工具 |
需验证框架支持的操作;可能需调整算子兼容性。 |
.weights |
Darknet/YOLO | Darknet训练时自动生成(如YOLOv3的yolov3.weights ) |
Darknet模型的权重参数,需配合配置文件(.cfg )使用 |
加载方式:darknet.load_net("yolov3.cfg", "yolov3.weights") |
无模型结构信息,必须与对应的.cfg 文件匹配使用。 |
.bin , .safetensors |
Hugging Face Transformers库 | model.save_pretrained("dir") 会生成pytorch_model.bin |
模型权重文件,通常与配置文件(config.json )配合使用 |
加载方式:model.from_pretrained("dir") |
.safetensors 是更安全的格式(替代.bin ),避免恶意代码注入。 |
.tflite |
TensorFlow Lite | 转换工具:tf.lite.TFLiteConverter.from_saved_model("model_dir").convert() |
轻量化模型,适用于移动端/嵌入式设备 | 移动端推理:使用TensorFlow Lite的Interpreter API加载。 | 模型可能经过量化(精度降低但体积减小)。 |
.params |
MXNet | net.save_parameters("model.params") |
MXNet模型的权重参数 | 加载方式:net.load_parameters("model.params") |
需预先定义网络结构再加载参数。 |
.joblib |
scikit-learn | joblib.dump(model, "model.joblib") |
序列化后的scikit-learn模型(适用于大文件) | 加载方式:model = joblib.load("model.joblib") |
比pickle 更高效,但主要服务于传统机器学习模型,少用于神经网络。 |
使用场景总结:
- 训练/推理用途 :
.pt
(PyTorch)、.h5
(Keras)、.ckpt
(TensorFlow)常用于训练过程中的保存与恢复。 - 跨框架部署 :
.onnx
适合不同框架间的模型转换;.pb
(TensorFlow)适合生产环境部署。 - 嵌入式/移动端 :
.tflite
针对移动设备优化。 - 安全性优先 :优先选择
.safetensors
替代.pkl
或.bin
。 - 协作与共享 :
.weights
+.cfg
(YOLO)、saved_model
(目录)包含完整信息,便于团队协作。