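This post records how I migrated a YOLOv5 model trained with pytorch==1.9 and torchvision==0.10 to a newer PyTorch for inference (pytorch==2.9 and torchvision==0.24 at the time of writing), along with the pitfalls to watch for.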
1. Extracting the model weights
PyTorch offers two ways to save a model. The first saves the weights together with the model structure:
```python
import torch
torch.save(model, 'model_full.pt')
```
The second saves only the weights:
```python
import torch
torch.save(model.state_dict(), 'model_weights.pt')
```
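The first approach uses Python's pickle to serialize the entire nn.Module object. pickle does not store the source code of the classes involved; it records the import path of every class in the object graph (the model class itself plus dependencies such as torchvision's ConvBNActivation) and re-imports those classes at load time. So as soon as you load the file under a different PyTorch/TorchVision version whose internal module layout has changed (a class moved, renamed, or removed), you get: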
```text
AttributeError: Can't get attribute 'XXX' on <module 'YYY'>
```
So from a compatibility standpoint, the second approach is strongly recommended.
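To make the failure mode concrete, here is a small self-contained sketch. The module name legacy_models and the classes TinyNet and NewTinyNet are made up for illustration: the script saves the same model both ways, then simulates a library upgrade by removing the class from its module. The full pickle can no longer be loaded, while the state_dict still loads into a freshly defined class with the same layout.

```python
import io
import sys
import types

import torch
import torch.nn as nn

# Hypothetical module standing in for an old library module (e.g. a torchvision submodule)
legacy = types.ModuleType("legacy_models")
sys.modules["legacy_models"] = legacy


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Register the class so pickle records it by reference as "legacy_models.TinyNet"
TinyNet.__module__ = "legacy_models"
TinyNet.__qualname__ = "TinyNet"
legacy.TinyNet = TinyNet

model = TinyNet()
full_buf, sd_buf = io.BytesIO(), io.BytesIO()
torch.save(model, full_buf)             # object graph + a reference to the class path
torch.save(model.state_dict(), sd_buf)  # only named tensors

# Simulate an upgrade in which the class was moved or removed
del legacy.TinyNet

full_buf.seek(0)
try:
    torch.load(full_buf, weights_only=False)
    full_error = None
except AttributeError as e:
    full_error = e  # e.g. "Can't get attribute 'TinyNet' on <module 'legacy_models'>"


class NewTinyNet(nn.Module):
    """Same layout as TinyNet, defined in the 'new' code base."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# The state_dict only needs matching parameter names and shapes
sd_buf.seek(0)
new_model = NewTinyNet()
new_model.load_state_dict(torch.load(sd_buf, weights_only=True))
print("full model load failed with:", full_error)
```

The state_dict path survives because it never refers to a class by import path, only to parameter names such as fc.weight.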
The steps for extracting and saving the weights:
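```python
import torch

# 1. Load the old checkpoint. This only works where the original classes can still be
#    imported (e.g. the YOLOv5 code on sys.path). On PyTorch >= 2.6 also pass
#    weights_only=False; that argument does not exist in pytorch 1.9.
ckpt = torch.load('your_old_model.pth', map_location='cpu')
# 2. Extract the model (adapt to your checkpoint layout)
model = (ckpt.get('ema') or ckpt['model']).float()
# 3. [Optional] fuse and eval (if you want to save the fused weights)
model = model.fuse().eval()  # or skip fuse, depending on your needs
model.float()
# 4. Save a plain state_dict
torch.save(model.state_dict(), 'model_weights_only_fused.pth')
```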
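The same extraction can be wrapped in a small helper. This is a sketch assuming a YOLOv5-style checkpoint (a dict with a 'model' entry and an optional 'ema' entry) or a bare nn.Module; extract_state_dict is a hypothetical name. It tests 'ema' against None explicitly instead of using `or`, because the truthiness of an nn.Module is not guaranteed (nn.Sequential defines `__len__`, so an empty one is falsy).

```python
import io

import torch
import torch.nn as nn


def extract_state_dict(ckpt):
    """Return a float32 state_dict from a YOLOv5-style checkpoint dict or a bare nn.Module.

    Hypothetical helper: assumes the dict layout {'model': ..., 'ema': module or None}.
    """
    if isinstance(ckpt, nn.Module):
        model = ckpt
    else:
        ema = ckpt.get("ema")
        model = ema if ema is not None else ckpt["model"]
    model = model.float()  # training checkpoints are often stored in fp16
    # Plain detached tensors only, so the result loads with weights_only=True
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


# Usage with a fake checkpoint: an fp16 model and no EMA
fake_ckpt = {"model": nn.Linear(3, 2).half(), "ema": None}
sd = extract_state_dict(fake_ckpt)

buf = io.BytesIO()
torch.save(sd, buf)
buf.seek(0)
restored = torch.load(buf, weights_only=True)
print({k: v.dtype for k, v in restored.items()})
```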
To load the model:
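```python
import torch

# Load the weights file (a plain state_dict is safe to load with weights_only=True)
state_dict = torch.load('model_weights.pt', map_location='cpu', weights_only=True)
# Build the model structure (this step is required!)
model = YourModelClass(...)  # ← replace with your actual model definition, e.g. YOLOv5()
# Load the weights
model.load_state_dict(state_dict, strict=True)  # or strict=False if some keys don't match
# Switch to inference mode
model.float().eval()
```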
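With strict=False, mismatched keys are skipped without an error, so it is worth inspecting what load_state_dict returns: a named tuple with missing_keys and unexpected_keys. The sketch below uses toy modules, not YOLOv5 layers: it loads a "fused" Conv-only state_dict into an unfused Conv+BN block, which is the kind of mismatch you get when fused weights meet an unfused model.

```python
import torch.nn as nn

# Toy "unfused" block: Conv (no bias) followed by BatchNorm
unfused = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8))
# Toy "fused" block: a single Conv with bias, as produced by Conv-BN fusion
fused = nn.Sequential(nn.Conv2d(3, 8, 3, bias=True))

result = unfused.load_state_dict(fused.state_dict(), strict=False)
# Keys present in the model but absent from the file (here: the BN parameters/buffers)
print("missing:", result.missing_keys)
# Keys present in the file but absent from the model (here: the conv bias)
print("unexpected:", result.unexpected_keys)
```

If either list is non-empty after a migration, the model is running with partially random weights, which is usually worse than a loud error.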
2. Extracting the yaml config
Some models (for example YOLOv5's DetectionModel) require a cfg (a yaml config) at construction time. The cfg can be extracted from the saved model:
```python
import torch
import yaml

# Try loading (may require weights_only=False)
ckpt = torch.load('Models/YOLOv5/pts/model_full.pt', weights_only=False)
# Inspect the contents (only dicts have keys)
if isinstance(ckpt, dict):
    print("Keys in checkpoint:", ckpt.keys())
if isinstance(ckpt, dict) and 'yaml' in ckpt:
    print("Model config (yaml):", ckpt['yaml'])
    # Save as a YAML file (optional)
    with open('extracted_model.yaml', 'w') as f:
        yaml.dump(ckpt['yaml'], f)
elif hasattr(ckpt, 'yaml'):  # ckpt is a full model object
    print("Model config:", ckpt.yaml)
else:
    print("No yaml found! Model structure may be hard-coded.")
```
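The weights_only=False in the snippet above matters on new PyTorch: since 2.6, torch.load defaults to weights_only=True, which only unpickles tensors and an allowlist of basic types. A full-model checkpoint is therefore rejected, while a plain state_dict loads fine. The toy check below uses nn.Linear as a stand-in for the real model; keep weights_only=False for files you trust, because it executes arbitrary pickle code.

```python
import io
import pickle

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for a real model
full_buf, sd_buf = io.BytesIO(), io.BytesIO()
torch.save(model, full_buf)
torch.save(model.state_dict(), sd_buf)

# A full nn.Module pickle is rejected by the restricted unpickler
full_buf.seek(0)
try:
    torch.load(full_buf, weights_only=True)
    rejected = False
except pickle.UnpicklingError:
    rejected = True

# ...but loads with weights_only=False (trusted files only)
full_buf.seek(0)
loaded_full = torch.load(full_buf, weights_only=False)

# A plain state_dict is accepted with weights_only=True
sd_buf.seek(0)
loaded_sd = torch.load(sd_buf, weights_only=True)
print(rejected, type(loaded_full).__name__, list(loaded_sd))
```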
To check whether, and where, the checkpoint contains a yaml config:
```python
import torch

path = 'Models/YOLOv5/pts/model_full.pt'
try:
    ckpt = torch.load(path, weights_only=False)
    if isinstance(ckpt, dict):
        print("Checkpoint keys:", list(ckpt.keys()))
        if 'yaml' in ckpt:
            print("✅ Found yaml config!")
            print(ckpt['yaml'])
        else:
            print("❌ No 'yaml' key. Check if ckpt['model'] has .yaml attribute.")
            if 'model' in ckpt and hasattr(ckpt['model'], 'yaml'):
                print("✅ Found yaml in model object:", ckpt['model'].yaml)
    else:
        # ckpt is a full model
        print("Full model loaded. Checking .yaml attribute...")
        if hasattr(ckpt, 'yaml'):
            print("✅ Found yaml:", ckpt.yaml)
        else:
            print("❌ No yaml found anywhere.")
except Exception as e:
    print("Load failed:", e)
```
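The two checks above can be folded into a helper that returns the config wherever it lives and writes it out for the loading step. This is a sketch: get_model_cfg, save_cfg and FakeModel are hypothetical names, and it assumes the config is a plain dict (as YOLOv5's model.yaml attribute is). yaml.safe_dump with sort_keys=False keeps the original key order, so the written file still reads like the original yaml.

```python
import io

import yaml


def get_model_cfg(ckpt):
    """Return the model config dict from a checkpoint dict or a full model object, or None."""
    if isinstance(ckpt, dict):
        if "yaml" in ckpt:
            return ckpt["yaml"]
        for key in ("ema", "model"):
            obj = ckpt.get(key)
            if obj is not None and hasattr(obj, "yaml"):
                return obj.yaml
        return None
    return getattr(ckpt, "yaml", None)


def save_cfg(cfg, stream):
    # safe_dump only emits plain Python types; sort_keys=False keeps the original order
    yaml.safe_dump(cfg, stream, sort_keys=False, allow_unicode=True)


class FakeModel:  # stand-in for a model object that carries a .yaml dict
    yaml = {"nc": 80, "depth_multiple": 0.33, "backbone": [[-1, 1, "Conv", [64, 6, 2, 2]]]}


cfg = get_model_cfg({"model": FakeModel(), "ema": None})
buf = io.StringIO()  # use open('extracted_model.yaml', 'w') for a real file
save_cfg(cfg, buf)
buf.seek(0)
print(yaml.safe_load(buf) == cfg)
```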
An example of loading the model:
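```python
import torch
import yaml
from models.yolo import DetectionModel  # from the YOLOv5 repository

def init_model(self):
    # 1. Load the state_dict weights
    state_dict = torch.load(self.weights, map_location=self.device, weights_only=True)
    # 2. Read the model structure from the YAML file
    with open(self.cfg_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    # 3. Build the model
    self.model = DetectionModel(cfg).to(self.device)
    # 4. Load the weights
    # If the weights were saved after fuse(), fuse first so that the module layout
    # (and therefore the state_dict keys) matches
    self.model.fuse()
    self.model.load_state_dict(state_dict)
    # 5. Switch to inference mode
    self.model = self.model.eval().float()
```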
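Why fuse() has to come before load_state_dict when the weights were saved fused: fusion folds each BatchNorm into the preceding convolution, so the fused module has a conv bias and no BN parameters, and its state_dict keys differ from those of the unfused model. The sketch below is a minimal re-implementation of the standard eval-mode Conv-BN folding (fuse_conv_bn is a hypothetical name, not YOLOv5's own fuse_conv_and_bn). With a per-channel scale s = gamma / sqrt(running_var + eps), the fused weight is s * W and the fused bias is beta + s * (b - running_mean); the script checks that the fused conv reproduces Conv followed by BN.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold an eval-mode BatchNorm2d into the preceding Conv2d (minimal sketch)."""
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True,
    )
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # one factor per output channel
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_(bn.bias + (conv_bias - bn.running_mean) * scale)
    return fused


conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
with torch.no_grad():  # give BN non-trivial statistics so the check is meaningful
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

ref = nn.Sequential(conv, bn).eval()
fused = fuse_conv_bn(conv, bn).eval()
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    max_diff = (ref(x) - fused(x)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
print("fused keys:", sorted(fused.state_dict()))
```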
3. Other issues
In-place modification error in non_max_suppression
```text
in non_max_suppression
x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Output 0 of SelectBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
```
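In YOLOv5's NMS, x comes from iterating over the prediction tensor (for xi, x in enumerate(prediction)), so while autograd is recording, x is a view produced by a multi-output operation, and newer PyTorch versions refuse to modify such views in place. Replace the line with an out-of-place version: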
```python
mask = ((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1)
if mask.any():
    x = x.clone()  # a clone is a new tensor, not a view, so it may be written in place
    x[mask, 4] = 0
```
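A toy check of the fix, assuming a prediction tensor of shape (batch, boxes, 5 + classes) with columns [cx, cy, w, h, obj, cls...] that is tracked by autograd, as a model output is when inference is not wrapped in torch.no_grad(). filter_wh is a hypothetical helper holding only the width-height step, and the min_wh/max_wh values are illustrative. Running inference inside torch.no_grad() or torch.inference_mode() generally avoids the error as well, since the restriction only applies to views tracked by autograd.

```python
import torch


def filter_wh(x, min_wh=2, max_wh=4096):
    """Out-of-place width-height filter for one image's predictions (hypothetical helper).

    Boxes whose w or h lies outside [min_wh, max_wh] get objectness 0.
    """
    mask = ((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1)
    if mask.any():
        x = x.clone()  # detach from the multi-output view before writing
        x[mask, 4] = 0
    return x


# Toy predictions: 1 image, 3 boxes, 1 class
base = torch.tensor([[[10., 10., 1., 5., 0.9, 0.8],       # w = 1 < min_wh -> filtered
                      [10., 10., 20., 30., 0.7, 0.6],     # kept
                      [10., 10., 50., 5000., 0.5, 0.4]]],  # h = 5000 > max_wh -> filtered
                    requires_grad=True)
prediction = base * 1.0  # non-leaf tensor tracked by autograd, like a model output

results = [filter_wh(x) for xi, x in enumerate(prediction)]
print(results[0][:, 4])

# The original in-place line for comparison; recent PyTorch is expected to reject it
try:
    for xi, x in enumerate(prediction):
        x[((x[..., 2:4] < 2) | (x[..., 2:4] > 4096)).any(1), 4] = 0
    naive_ok = True
except RuntimeError:
    naive_ok = False
print("in-place version ran:", naive_ok)
```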
MobileNetv3 in common.py
The original code:
```python
import torch.nn as nn
from torchvision import models


class MobileNetv3(nn.Module):
    """
    Uses mobilenet_v3_small with the generic pretrained weights
    """
    def __init__(self, slice):
        super(MobileNetv3, self).__init__()
        self.model = None
        if slice == 1:
            self.model = models.mobilenet_v3_small(pretrained=True).features[:4]
        elif slice == 2:
            self.model = models.mobilenet_v3_small(pretrained=True).features[4:9]
        elif slice == 3:
            self.model = models.mobilenet_v3_small(pretrained=True).features[9:]

    def forward(self, x):
        return self.model(x)
```
This code triggers a warning:
```text
UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
```
Switching to the weights argument fixes it:
```python
import torch.nn as nn
from torchvision import models


class MobileNetv3(nn.Module):
    """
    Uses mobilenet_v3_small with the generic pretrained weights
    """
    def __init__(self, slice):
        super(MobileNetv3, self).__init__()
        self.model = None
        if slice == 1:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[:4]
        elif slice == 2:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[4:9]
        elif slice == 3:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[9:]

    def forward(self, x):
        return self.model(x)
```