Onnx使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上

目录

一、整体功能概述

二、函数分析

[2.1 resnet() 函数:](#2.1 resnet() 函数:)

[2.2 pre_process(img_path) 函数:](#2.2 pre_process(img_path) 函数:)

[2.3 loadOnnx(img_path) 函数:](#2.3 loadOnnx(img_path) 函数:)

三、代码执行流程


一、整体功能概述

这段代码实现了一个图像分类系统,使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上。它包括以下主要步骤:

读取一个包含类别名称和对应编号的文本文件,并将其存储在字典中。

定义了几个函数,包括模型导出函数 resnet()、图像预处理函数 pre_process() 和加载 ONNX 模型进行分类的函数 loadOnnx()。

在主程序中,指定输入图像路径,调用 loadOnnx() 函数对图像进行分类并显示结果。

二、函数分析

2.1 resnet() 函数:

使用 torchvision 中的预训练 ResNet18 模型,并设置为评估模式。

生成一个随机输入张量 x,并将模型导出为 ONNX 格式,保存为 models/resnet18.onnx 文件。

def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])

2.2 pre_process(img_path) 函数:

读取输入图像 img_path。

调整图像大小为 224x224。

将图像颜色通道从 BGR 转换为 RGB。

对图像像素值进行归一化处理。

交换图像维度顺序,并增加一个维度。

返回预处理后的图像张量。

def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img

2.3 loadOnnx(img_path) 函数:

创建一个 ONNX 推理会话,加载预导出的 ResNet18 ONNX 模型。

调用 pre_process() 函数对输入图像进行预处理。

准备输入数据并进行推理。

获取推理结果中概率最大的类别编号。

根据类别编号从字典中获取对应的类别名称,并进行翻译。

在输入图像上显示分类结果,并展示图像。

def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)

完整代码如下

import cv2
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights
import onnxruntime as ort
from translate import Translator
translator=Translator(to_lang='Chinese')#翻译成中文
dict={}
with open('类别.txt','r',encoding='utf-8') as f:
    lines=f.readlines()
    for line in lines:
        name=line.split('\t')[0]
        value=line.split('\t')[1]
        dict[name]=value
# print(dict)
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img
    #RGB
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)
    pass
if __name__ == '__main__':
    img_path='dog.png'
    # resnet()#导出模型
    loadOnnx(img_path)

三、代码执行流程

在 if name == 'main': 部分:

定义输入图像路径 img_path。

可以选择调用 resnet() 函数导出模型(注释状态,通常只在第一次运行或模型更新时使用)。

调用 loadOnnx(img_path) 函数对输入图像进行分类和显示结果。

相关推荐
AI-入门2 分钟前
AI 产品经理:2024 年职场新航标 ——AI 产品经理的未来与契机
人工智能·chatgpt·prompt·产品经理·agi
北极冰雨13 分钟前
Agent、RAG、LangChain的概念及作用
人工智能
极客小张14 分钟前
基于OpenCV与MQTT的停车场车牌识别系统:结合SQLite和Flask的设计流程
arm开发·人工智能·单片机·opencv·物联网·flask·毕业设计
z千鑫33 分钟前
【数据分析】利用Python+AI+工作流实现自动化数据分析-全流程讲解
人工智能·python·ai·数据分析·自动化·ai编程·ai工作流
雪碧有白泡泡40 分钟前
Stable Diffusion AI算法,实现一键式后期处理与图像修复魔法
人工智能·算法·stable diffusion
高登先生44 分钟前
科技之光,照亮未来之路“2024南京国际人工智能展会”
大数据·人工智能·科技·数学建模·能源
.5481 小时前
AI基础 L19 Quantifying Uncertainty and Reasoning with Probabilities I 量化不确定性和概率推理
人工智能
菌菌的快乐生活1 小时前
树莓派安装 OpenCV 教程
人工智能·opencv·计算机视觉
SEVEN-YEARS1 小时前
深度学习入门:探索神经网络、感知器与损失函数
人工智能·深度学习·神经网络
DA树聚1 小时前
ChatGPT的底层逻辑
人工智能·深度学习·语言模型·自然语言处理·chatgpt·数据挖掘