Onnx使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上

目录

一、整体功能概述

二、函数分析

[2.1 resnet() 函数:](#2.1 resnet() 函数:)

[2.2 pre_process(img_path) 函数:](#2.2 pre_process(img_path) 函数:)

[2.3 loadOnnx(img_path) 函数:](#2.3 loadOnnx(img_path) 函数:)

三、代码执行流程


一、整体功能概述

这段代码实现了一个图像分类系统,使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上。它包括以下主要步骤:

读取一个包含类别名称和对应编号的文本文件,并将其存储在字典中。

定义了几个函数,包括模型导出函数 resnet()、图像预处理函数 pre_process() 和加载 ONNX 模型进行分类的函数 loadOnnx()。

在主程序中,指定输入图像路径,调用 loadOnnx() 函数对图像进行分类并显示结果。

二、函数分析

2.1 resnet() 函数:

使用 torchvision 中的预训练 ResNet18 模型,并设置为评估模式。

生成一个随机输入张量 x,并将模型导出为 ONNX 格式,保存为 models/resnet18.onnx 文件。

复制代码
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])

2.2 pre_process(img_path) 函数:

读取输入图像 img_path。

调整图像大小为 224x224。

将图像颜色通道从 BGR 转换为 RGB。

对图像像素值进行归一化处理。

交换图像维度顺序,并增加一个维度。

返回预处理后的图像张量。

复制代码
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img

2.3 loadOnnx(img_path) 函数:

创建一个 ONNX 推理会话,加载预导出的 ResNet18 ONNX 模型。

调用 pre_process() 函数对输入图像进行预处理。

准备输入数据并进行推理。

获取推理结果中概率最大的类别编号。

根据类别编号从字典中获取对应的类别名称,并进行翻译。

在输入图像上显示分类结果,并展示图像。

复制代码
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)

完整代码如下

复制代码
import cv2
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights
import onnxruntime as ort
from translate import Translator
translator=Translator(to_lang='Chinese')#翻译成中文
dict={}
with open('类别.txt','r',encoding='utf-8') as f:
    lines=f.readlines()
    for line in lines:
        name=line.split('\t')[0]
        value=line.split('\t')[1]
        dict[name]=value
# print(dict)
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img
    #RGB
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)
    pass
if __name__ == '__main__':
    img_path='dog.png'
    # resnet()#导出模型
    loadOnnx(img_path)

三、代码执行流程

在 if name == 'main': 部分:

定义输入图像路径 img_path。

可以选择调用 resnet() 函数导出模型(注释状态,通常只在第一次运行或模型更新时使用)。

调用 loadOnnx(img_path) 函数对输入图像进行分类和显示结果。

相关推荐
跳跳糖炒酸奶3 分钟前
第四章、Isaacsim在GUI中构建机器人(3):添加摄像头和传感器
人工智能·python·算法·ubuntu·机器人
求知呀1 小时前
最直观的 Cursor 使用教程
前端·人工智能·llm
飞哥数智坊2 小时前
从“工具人”到“超级个体”:程序员如何在AI协同下实现能力跃迁
人工智能
chenqi2 小时前
WebGPU和WebLLM:在浏览器中解锁端侧大模型的未来
前端·人工智能
罗西的思考2 小时前
[2W字长文] 探秘Transformer系列之(23)--- 长度外推
人工智能·算法
小杨4044 小时前
python入门系列十四(多进程)
人工智能·python·pycharm
阿坡RPA18 小时前
手搓MCP客户端&服务端:从零到实战极速了解MCP是什么?
人工智能·aigc
用户277844910499319 小时前
借助DeepSeek智能生成测试用例:从提示词到Excel表格的全流程实践
人工智能·python
机器之心19 小时前
刚刚,DeepSeek公布推理时Scaling新论文,R2要来了?
人工智能
算AI21 小时前
人工智能+牙科:临床应用中的几个问题
人工智能·算法