safetensors ^[1]^ 号称提供一种更安全的存数据方式,支持多种框架,见 [2]。不过在处理玄数据(metadata)时:
- 只支持 Dict[str, str] 的形式,即值必须是字符串,而不能是 int、float 或嵌套 dict,而这些在 PyTorch 原先的 torch.save、torch.load 是支持的。考虑用
json.dumps
将 dict 转写成字符串,读时则用json.loads
恢复回 dict。 - 没有专门从 checkpoint 文件读出 metadata 的方法。考虑采用 [3] 中 Ok_Storage_1799 的回答所讲利用
safetensors.safe_open
的方法读 metadata。
下面是存、取 PyTorch 模型参数、metadata 的简例:
python
import time, json, pprint
import torch
from safetensors import safe_open # to read metadata
from safetensors.torch import save_model, load_model
print("建模型")
model = torch.nn.Linear(2, 3)
# 初始参数值
for pn, p in model.named_parameters():
print(pn, p)
print("存模型、metadata")
# 将模型参数置零 (模拟 training)
for p in model.parameters():
p.data.zero_()
# 存模型
save_model(
model,
"ckpt.safetensors",
# metadata 用 json 转写成 str
{"metadata": json.dumps({
"time": time.asctime(),
"epoch": 57,
"acc": 0.56,
"args": {
"debug": False,
"dataset": "MNIST",
"decay_steps": [10, 20]
}
})}
)
print("读模型")
load_model(model, "ckpt.safetensors")
# 验证更新(置零)后参数值
for pn, p in model.named_parameters():
print(pn, p)
print("读 metadata")
with safe_open("ckpt.safetensors", framework="pt") as f:
print(type(f), dir(f))
print(list(f.keys())) # 模型参数的名字
print(type(f.metadata())) # dict
for k, v in f.metadata().items():
print(k, v)
# 用 json 恢复 metadata 成 dict
if "metadata" == k:
metadata = json.loads(v)
pprint.pprint(metadata)