随着 PyTorch 2.x 版本的发布,官方对模型序列化和反序列化的逻辑进行了优化。其中一个重要变化是:默认情况下,`torch.load()` 将只加载模型的权重(即等价于设置 `weights_only=True`),而不再支持通过 `weights_only=False` 显式控制是否加载完整模型。
1. 问题现象
在使用如下代码加载模型时:
model = torch.load('model.pth', map_location=device, weights_only=False)
会触发以下警告:
FutureWarning: The argument `weights_only` is deprecated and will be removed in a future release. Use `map_location` to control device placement instead.
这个警告提示开发者,未来版本将移除 `weights_only=False` 参数,并建议调整模型加载方式以适配新行为。
2. 背景知识:PyTorch 模型保存方式
| 保存方式 | 典型用法 | 适用场景 |
|---|---|---|
| 仅保存模型权重 | torch.save(model.state_dict(), 'model_weights.pth') |
轻量、便于迁移、适合部署 |
| 保存整个模型 | torch.save(model, 'full_model.pth') |
保留结构+参数,适合快速恢复训练或推理 |
3. 原因分析:为何弃用 `weights_only=False`?
- PyTorch 团队发现大部分用户都使用的是仅加载权重的方式(state_dict)。
- 为提升安全性,默认只加载权重,避免执行任意 Python 代码的风险。
- 简化 API 接口设计,统一加载逻辑。
4. 解决方案:如何适配新版行为?
根据模型保存方式的不同,应采用不同的加载策略:
4.1 加载仅包含权重的文件(推荐方式)
-
# 先实例化模型结构 -
model = MyModel() -
# 然后加载权重 -
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
4.2 加载完整模型文件(需要显式设置)
-
# 需要确认模型类定义存在 -
model = torch.load('full_model.pth', map_location=device)
5. 最佳实践建议
- 优先使用 `state_dict` 方式保存模型,提高可移植性和安全性。
- 避免使用 `weights_only=False`,提前适配未来版本行为。
- 若必须加载完整模型,请确保模型类定义一致,且不在生产环境中使用。
- 测试阶段应开启严格模式检查,防止潜在兼容性问题。