幽冥大陆(八十四)Python 水果识别PTH 转 ONNX 脚本新 —东方仙盟练气期

的 PTH 模型实际训练的是131 类 水果,但转换脚本中手动设置了NUM_CLASSES = 208,导致分类头的权重维度不匹配(131≠208),这是典型的「模型结构和权重维度不一致」问题。

修正后的完整 PTH 转 ONNX 脚本(适配 131 类)

python

运行

复制代码
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import MobileNet_V2_Weights

# ==================== 配置项(修正类别数为131) ====================
PTH_MODEL_PATH = "wlzcfruit_mobilenetv2.pth"  # 训练好的pth路径
ONNX_MODEL_PATH = "wlzcfruit_mobilenetv2.onnx"  # 输出onnx路径
NUM_CLASSES = 131  # ✅ 修正为实际训练的131类(不是208)
INPUT_SIZE = (100, 100)  # 和训练脚本的Resize一致
DEVICE = torch.device("cpu")
# ===================================================================

# 1. 重建和训练时完全一致的模型结构
def build_model(num_classes):
    # 重建MobileNetV2(和训练脚本结构完全一致)
    model = models.mobilenet_v2(pretrained=True)  # 保持和训练脚本一致的写法
    # 修改分类头(维度匹配131类)
    model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    return model

# 2. 加载pth权重并设置为推理模式
model = build_model(NUM_CLASSES)
# 加载权重(CPU环境)
model.load_state_dict(torch.load(PTH_MODEL_PATH, map_location=DEVICE))
model.eval()  # 必须设为推理模式

# 3. 构造示例输入(1,3,100,100)
dummy_input = torch.randn(1, 3, INPUT_SIZE[0], INPUT_SIZE[1], device=DEVICE)

# 4. 转换PTH到ONNX
torch.onnx.export(
    model,
    dummy_input,
    ONNX_MODEL_PATH,
    opset_version=12,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    },
    verbose=False
)

# 5. 验证ONNX模型
try:
    import onnx
    onnx_model = onnx.load(ONNX_MODEL_PATH)
    onnx.checker.check_model(onnx_model)
    print(f"✅ ONNX模型转换成功!路径:{ONNX_MODEL_PATH}")
    print(f"📌 输入维度:1x3x{INPUT_SIZE[0]}x{INPUT_SIZE[1]}(动态批次)")
    print(f"📌 输出维度:1x{NUM_CLASSES}(对应{NUM_CLASSES}类水果)")
except ImportError:
    print("⚠️ 未安装onnx库,跳过验证(执行 pip install onnx 安装)")
    print(f"✅ ONNX模型已生成:{ONNX_MODEL_PATH}")
except Exception as e:
    print(f"❌ 验证失败:{str(e)}")

额外优化:自动匹配类别数(避免手动填错)

如果担心后续类别数变化,可从训练集目录自动读取类别数,彻底避免维度不匹配:

python

运行

复制代码
import os
# 自动读取训练集类别数(和训练脚本逻辑一致)
TRAIN_DIR = "./wlzcfruit/Training"
NUM_CLASSES = len([d for d in os.listdir(TRAIN_DIR) if os.path.isdir(os.path.join(TRAIN_DIR, d))])
print(f"🔍 自动检测到训练集类别数:{NUM_CLASSES}")

执行步骤

  1. 替换脚本中NUM_CLASSES = 208131(或用上面的自动检测代码);
  2. 确保PTH_MODEL_PATH指向正确的wlzcfruit_mobilenetv2.pth
  3. 重新运行脚本:python wlzc_fruit_pth2onnx.py

阿雪技术观

在科技发展浪潮中,我们不妨积极投身技术共享。不满足于做受益者,更要主动担当贡献者。无论是分享代码、撰写技术博客,还是参与开源项目维护改进,每一个微小举动都可能蕴含推动技术进步的巨大能量。东方仙盟是汇聚力量的天地,我们携手在此探索硅基生命,为科技进步添砖加瓦。

Hey folks, in this wild tech - driven world, why not dive headfirst into the whole tech - sharing scene? Don't just be the one reaping all the benefits; step up and be a contributor too. Whether you're tossing out your code snippets, hammering out some tech blogs, or getting your hands dirty with maintaining and sprucing up open - source projects, every little thing you do might just end up being a massive force that pushes tech forward. And guess what? The Eastern FairyAlliance is this awesome place where we all come together. We're gonna team up and explore the whole silicon - based life thing, and in the process, we'll be fueling the growth of technology

相关推荐
逐米时代1 分钟前
企业AI智能体是什么?如何解决制造型企业信息孤岛问题
人工智能·制造
用什么都重名1 分钟前
Python文本匹配利器:FlashText与RapidFuzz深度对比
python·flash text·rapidfuzz
@Ma2 分钟前
Python 实现企业微信外部群主动消息发送及成功接入后如何避坑,避免风控封号
开发语言·python·企业微信
标书畅畅行2 分钟前
深度解析钛投标AI标书工具:全流程企业级AI投标解决方案,重构投标数字化生产力
大数据·数据库·人工智能
DXM05212 分钟前
第10期| 卷积神经网络CNN通俗详解:AI遥感的底层核心
人工智能·python·神经网络·机器学习·arcgis·cnn·文心一言
o561路6o623o73 分钟前
陈,CPP条件位置偏爱系统
深度学习
ShyanZh4 分钟前
【skill】Agent-Browser:AI代理的浏览器自动化实战指南
运维·人工智能·自动化·skill·agent-browser
Hello:CodeWorld4 分钟前
AI Agent:从核心原理、架构框架到工程实战,大模型时代的自主智能革命
大数据·人工智能·python·架构
mowei4 分钟前
MCP 配了 20 分钟,CLI 一句话:我给 Agent 选工具的真实取舍
人工智能
Chengbei116 分钟前
CTF & 红队专用 AI 求解AI 引擎 Cairn 系统,化轻量化部署,红队、CTF、漏洞研究一站式解决方案
java·人工智能·安全·web安全·网络安全·系统安全