yolov5模型迁移笔记

这篇文章记录笔者因工作需要将基于pytorch==1.9torchvision==0.10训练的yolov5模型迁移到新版本pytorch(本文撰写时使用pytorch=2.9torchvision==0.24)来推理的方法和注意事项。

1. 模型权重提取

pytorch保存模型时,有两种方式,一种是将权重与模型结构一起保存:

python3 复制代码
import torch
torch.save(model, 'model_full.pt')

另一种是仅存权重:

python 复制代码
import torch
torch.save(model.state_dice(), 'model_weights.pt')

前者使用 Python 的 pickle 序列化整个 nn.Module 对象。这会把模型类的路径、代码结构、依赖模块(如 torchvision 中的 ConvBNActivation) 全部固化到文件中。一旦你在不同PyTorch/TorchVision 版本中加载,只要内部模块结构有变动(比如类被移动、重命名、删除),就会出现:

typescript 复制代码
AttributeError: Can't get attribute 'XXX' on <module 'YYY'>

所以从兼容性的角度来看,强烈推荐第二种方式。

保存权重的步骤如下:

python 复制代码
# 1. 加载旧 checkpoint
ckpt = torch.load('your_old_model.pth', map_location='cpu')

# 2. 提取模型(根据你的结构)
model = (ckpt.get('ema') or ckpt['model']).float()

# 3. 【可选】fuse 和 eval(如果你希望保存的是 fused 后的权重)
model = model.fuse().eval()  # 或者不 fuse,看需求
model.float()

# 4. 保存纯 state_dict
torch.save(model.state_dict(), 'model_weights_only_fused.pth')

加载模型的方法:

python 复制代码
# 加载权重文件
state_dict = torch.load('model_weights.pt')
# 构建模型结构(这一步不能少!)
model = YourModelClass(...)  # ← 替换为你的实际模型定义,如 YOLOv5()
# 加载权重
model.load_state_dict(state_dict, strict=True)  # 或 strict=False 如果有不匹配
# 设置推理模式
model.float().eval()

2 提取yaml配置

有的模型(例如yolo的DetectionModel)在构建时需要传入cfg(yaml配置类型),可以在模型中进行提取:

python 复制代码
import torch

# 尝试加载(可能需要 weights_only=False)
ckpt = torch.load('Models/YOLOv5/pts/model_full.pt', weights_only=False)

# 检查内容
print("Keys in checkpoint:", ckpt.keys())
if 'yaml' in ckpt:
    print("Model config (yaml):", ckpt['yaml'])
    # 保存为 YAML 文件(可选)
    import yaml
    with open('extracted_model.yaml', 'w') as f:
        yaml.dump(ckpt['yaml'], f)
elif hasattr(ckpt, 'yaml'):  # 如果 ckpt 是完整模型对象
    print("Model config:", ckpt.yaml)
else:
    print("No yaml found! Model structure may be hard-coded.")

检查模型中是否包含yaml:

python 复制代码
import torch

path = 'Models/YOLOv5/pts/model_full.pt'
try:
    ckpt = torch.load(path, weights_only=False)
    if isinstance(ckpt, dict):
        print("Checkpoint keys:", list(ckpt.keys()))
        if 'yaml' in ckpt:
            print("✅ Found yaml config!")
            print(ckpt['yaml'])
        else:
            print("❌ No 'yaml' key. Check if ckpt['model'] has .yaml attribute.")
            if 'model' in ckpt and hasattr(ckpt['model'], 'yaml'):
                print("✅ Found yaml in model object:", ckpt['model'].yaml)
    else:
        # ckpt is a full model
        print("Full model loaded. Checking .yaml attribute...")
        if hasattr(ckpt, 'yaml'):
            print("✅ Found yaml:", ckpt.yaml)
        else:
            print("❌ No yaml found anywhere.")
except Exception as e:
    print("Load failed:", e)

模型加载示例:

python 复制代码
def init_model(self):
    # 1. 加载 state_dict 权重
    state_dict = torch.load(self.weights, map_location=self.device, weights_only=True)

    # 2. 从 YAML 读取模型结构
    with open(self.cfg_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    # 3. 构建模型
    self.model = DetectionModel(cfg).to(self.device)

    # 4. 加载权重
    # 如果是在模型fuse后再保存权重的,要先fuse
    self.model.fuse()
    self.model.load_state_dict(state_dict)

    # 5
    self.model = self.model.eval().float()

其它问题记录

non_max_suppression原地修改问题

typescript 复制代码
 in non_max_suppression
    x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
    ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Output 0 of SelectBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

把这行替换为:

python 复制代码
mask = ((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1)
if mask.any():
    x = x.clone()
    x[mask, 4] = 0

common.py的MobileNetv3相关

原版代码:

python 复制代码
class MobileNetv3(nn.Module):
    """
    使用mobilenet_v3_small, 并且使用通用的预训练模型
    """
    def __init__(self, slice):
        super(MobileNetv3, self).__init__()
        self.model = None
        if slice == 1:
            self.model = models.mobilenet_v3_small(pretrained=True).features[:4]
        elif slice == 2:
            self.model = models.mobilenet_v3_small(pretrained=True).features[4:9]
        elif slice == 3:
            self.model = models.mobilenet_v3_small(pretrained=True).features[9:]
    def forward(self, x):
        return self.model(x)

该代码会有警告:

typescript 复制代码
UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.

改成这样即可:

python 复制代码
class MobileNetv3(nn.Module):
    """
    使用mobilenet_v3_small, 并且使用通用的预训练模型
    """
    def __init__(self, slice):
        super(MobileNetv3, self).__init__()
        self.model = None
        if slice == 1:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[:4]
        elif slice == 2:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[4:9]
        elif slice == 3:
            self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT).features[9:]
    def forward(self, x):
        return self.model(x)
相关推荐
Techblog of HaoWANG19 分钟前
目标检测与跟踪 (7)- YOLOv8 ONNX量化模型部署指南
yolo·目标检测·onnx·量化部署
CYTElena29 分钟前
关于JAVA异常的笔记
java·开发语言·笔记·语言基础
代码游侠31 分钟前
学习笔记——HTML网页开发基础
运维·服务器·开发语言·笔记·学习·html
FL162386312931 分钟前
电力场景输电线路电缆线异常连接处缺陷金属部件腐蚀检测数据集VOC+YOLO格式3429张5类别
人工智能·yolo·机器学习
代码游侠35 分钟前
应用——基于C语言实现的简易Web服务器开发
运维·服务器·c语言·开发语言·笔记·测试工具
学习3人组1 小时前
YOLOv8模型TensorRT量化实操步骤手册
yolo
week_泽1 小时前
OCR学习笔记,调用免费百度api
笔记·学习·ocr
week_泽1 小时前
离线OCR笔记及代码
笔记·ocr
Aliex_git1 小时前
内存堆栈分析笔记
开发语言·javascript·笔记
航Hang*2 小时前
第三章:网络系统建设与运维(中级)——交换技术
运维·笔记·计算机网络·华为·ensp·交换机