yolov5模型迁移笔记

这篇文章记录笔者因工作需要将基于pytorch==1.9torchvision==0.10训练的yolov5模型迁移到新版本pytorch(本文撰写时使用pytorch=2.9torchvision==0.24)来推理的方法和注意事项。

1. 模型权重提取

pytorch保存模型时,有两种方式,一种是将权重与模型结构一起保存:

python3 复制代码
import torch
torch.save(model, 'model_full.pt')

另一种是仅存权重:

python 复制代码
import torch
torch.save(model.state_dice(), 'model_weights.pt')

前者使用 Python 的 pickle 序列化整个 nn.Module 对象。这会把模型类的路径、代码结构、依赖模块(如 torchvision 中的 ConvBNActivation) 全部固化到文件中。一旦你在不同PyTorch/TorchVision 版本中加载,只要内部模块结构有变动(比如类被移动、重命名、删除),就会出现:

typescript 复制代码
AttributeError: Can't get attribute 'XXX' on <module 'YYY'>

所以从兼容性的角度来看,强烈推荐第二种方式。

保存权重的步骤如下:

python 复制代码
# 1. 加载旧 checkpoint
ckpt = torch.load('your_old_model.pth', map_location='cpu')

# 2. 提取模型(根据你的结构)
model = (ckpt.get('ema') or ckpt['model']).float()

# 3. 【可选】fuse 和 eval(如果你希望保存的是 fused 后的权重)
model = model.fuse().eval()  # 或者不 fuse,看需求
model.float()

# 4. 保存纯 state_dict
torch.save(model.state_dict(), 'model_weights_only_fused.pth')

加载模型的方法:

python 复制代码
# 加载权重文件
state_dict = torch.load('model_weights.pt')
# 构建模型结构(这一步不能少!)
model = YourModelClass(...)  # ← 替换为你的实际模型定义,如 YOLOv5()
# 加载权重
model.load_state_dict(state_dict, strict=True)  # 或 strict=False 如果有不匹配
# 设置推理模式
model.float().eval()

2 提取yaml配置

有的模型(例如yolo的DetectionModel)在构建时需要传入cfg(yaml配置类型),可以在模型中进行提取:

python 复制代码
import torch

# 尝试加载(可能需要 weights_only=False)
ckpt = torch.load('Models/YOLOv5/pts/model_full.pt', weights_only=False)

# 检查内容
print("Keys in checkpoint:", ckpt.keys())
if 'yaml' in ckpt:
    print("Model config (yaml):", ckpt['yaml'])
    # 保存为 YAML 文件(可选)
    import yaml
    with open('extracted_model.yaml', 'w') as f:
        yaml.dump(ckpt['yaml'], f)
elif hasattr(ckpt, 'yaml'):  # 如果 ckpt 是完整模型对象
    print("Model config:", ckpt.yaml)
else:
    print("No yaml found! Model structure may be hard-coded.")

检查模型中是否包含yaml:

python 复制代码
import torch

path = 'Models/YOLOv5/pts/model_full.pt'
try:
    ckpt = torch.load(path, weights_only=False)
    if isinstance(ckpt, dict):
        print("Checkpoint keys:", list(ckpt.keys()))
        if 'yaml' in ckpt:
            print("✅ Found yaml config!")
            print(ckpt['yaml'])
        else:
            print("❌ No 'yaml' key. Check if ckpt['model'] has .yaml attribute.")
            if 'model' in ckpt and hasattr(ckpt['model'], 'yaml'):
                print("✅ Found yaml in model object:", ckpt['model'].yaml)
    else:
        # ckpt is a full model
        print("Full model loaded. Checking .yaml attribute...")
        if hasattr(ckpt, 'yaml'):
            print("✅ Found yaml:", ckpt.yaml)
        else:
            print("❌ No yaml found anywhere.")
except Exception as e:
    print("Load failed:", e)

模型加载示例:

python 复制代码
def init_model(self):
    # 1. 加载 state_dict 权重
    state_dict = torch.load(self.weights, map_location=self.device, weights_only=True)

    # 2. 从 YAML 读取模型结构
    with open(self.cfg_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    # 3. 构建模型
    self.model = DetectionModel(cfg).to(self.device)

    # 4. 加载权重
    # 如果是在模型fuse后再保存权重的,要先fuse
    self.model.fuse()
    self.model.load_state_dict(state_dict)

    # 5
    self.model = self.model.eval().float()

其它问题记录

non_max_suppression原地修改问题

typescript 复制代码
 in non_max_suppression
    x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
    ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Output 0 of SelectBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

把这行替换为:

python 复制代码
mask = ((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1)
if mask.any():
    x = x.clone()
    x[mask, 4] = 0

common.py的MobileNetv3相关

原版代码:

python 复制代码
class MobileNetv3(nn.Module):
    """
    使用mobilenet_v3_small, 并且使用通用的预训练模型
    """
    def __init__(self, slice):
        super(MobileNetv3, self).__init__()
        self.model = None
        if slice == 1:
            self.model = models.mobilenet_v3_small(pretrained=True).features[:4]
        elif slice == 2:
            self.model = models.mobilenet_v3_small(pretrained=True).features[4:9]
        elif slice == 3:
            self.model = models.mobilenet_v3_small(pretrained=True).features[9:]
    def forward(self, x):
        return self.model(x)

该代码会有警告:

typescript 复制代码
UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.

改成这样即可:

python 复制代码
class MobileNetv3(nn.Module):
    """
    使用mobilenet_v3_small, 并且使用通用的预训练模型
    """
    def __init__(self, slice):
        super(MobileNetv3, self).__init__()
        self.model = None
        if slice == 1:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[:4]
        elif slice == 2:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[4:9]
        elif slice == 3:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[9:]
    def forward(self, x):
        return self.model(x)
相关推荐
Slaughter信仰6 小时前
图解大模型_生成式AI原理与实战学习笔记前四张问答(7题)
人工智能·笔记·学习
2401_834517077 小时前
AD学习笔记-26 Active Routing
笔记·学习
断剑zou天涯7 小时前
【算法笔记】Manacher算法
java·笔记·算法
瑶光守护者8 小时前
【学习笔记】5G RedCap:智能回落5G NR驻留的接入策略
笔记·学习·5g
你想知道什么?8 小时前
Python基础篇(上) 学习笔记
笔记·python·学习
xian_wwq8 小时前
【学习笔记】可信数据空间的工程实现
笔记·学习
浩瀚地学9 小时前
【Arcpy】入门学习笔记(五)-矢量数据
经验分享·笔记·python·arcgis·arcpy
Li.CQ9 小时前
SQL学习笔记
笔记·sql·学习
云霄星乖乖的果冻9 小时前
01引言——李沐《动手学深度学习》个人笔记
人工智能·笔记·深度学习