Onnx使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上

目录

一、整体功能概述

二、函数分析

[2.1 resnet() 函数:](#2.1 resnet() 函数:)

[2.2 pre_process(img_path) 函数:](#2.2 pre_process(img_path) 函数:)

[2.3 loadOnnx(img_path) 函数:](#2.3 loadOnnx(img_path) 函数:)

三、代码执行流程


一、整体功能概述

这段代码实现了一个图像分类系统,使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上。它包括以下主要步骤:

读取一个包含类别名称和对应编号的文本文件,并将其存储在字典中。

定义了几个函数,包括模型导出函数 resnet()、图像预处理函数 pre_process() 和加载 ONNX 模型进行分类的函数 loadOnnx()。

在主程序中,指定输入图像路径,调用 loadOnnx() 函数对图像进行分类并显示结果。

二、函数分析

2.1 resnet() 函数:

使用 torchvision 中的预训练 ResNet18 模型,并设置为评估模式。

生成一个随机输入张量 x,并将模型导出为 ONNX 格式,保存为 models/resnet18.onnx 文件。

复制代码
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])

2.2 pre_process(img_path) 函数:

读取输入图像 img_path。

调整图像大小为 224x224。

将图像颜色通道从 BGR 转换为 RGB。

对图像像素值进行归一化处理。

交换图像维度顺序,并增加一个维度。

返回预处理后的图像张量。

复制代码
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img

2.3 loadOnnx(img_path) 函数:

创建一个 ONNX 推理会话,加载预导出的 ResNet18 ONNX 模型。

调用 pre_process() 函数对输入图像进行预处理。

准备输入数据并进行推理。

获取推理结果中概率最大的类别编号。

根据类别编号从字典中获取对应的类别名称,并进行翻译。

在输入图像上显示分类结果,并展示图像。

复制代码
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)

完整代码如下

复制代码
import cv2
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights
import onnxruntime as ort
from translate import Translator
translator=Translator(to_lang='Chinese')#翻译成中文
dict={}
with open('类别.txt','r',encoding='utf-8') as f:
    lines=f.readlines()
    for line in lines:
        name=line.split('\t')[0]
        value=line.split('\t')[1]
        dict[name]=value
# print(dict)
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img
    #RGB
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)
    pass
if __name__ == '__main__':
    img_path='dog.png'
    # resnet()#导出模型
    loadOnnx(img_path)

三、代码执行流程

在 if name == 'main': 部分:

定义输入图像路径 img_path。

可以选择调用 resnet() 函数导出模型(注释状态,通常只在第一次运行或模型更新时使用)。

调用 loadOnnx(img_path) 函数对输入图像进行分类和显示结果。

相关推荐
ocr_sinosecu13 分钟前
OCR定制识别:解锁文字识别的无限可能
人工智能·机器学习·ocr
奋斗者1号15 分钟前
分类数据处理全解析:从独热编码到高维特征优化
人工智能·机器学习·分类
契合qht53_shine30 分钟前
深度学习 视觉处理(CNN) day_02
人工智能·深度学习·cnn
就叫飞六吧1 小时前
如何判断你的PyTorch是GPU版还是CPU版?
人工智能·pytorch·python
zsffuture1 小时前
opencv 读取3G大图失败,又不想重新编译opencv ,可以如下操作
人工智能·opencv·webpack
AntBlack1 小时前
别说了别说了 ,Trae 已经在不停优化迭代了
前端·人工智能·后端
訾博ZiBo1 小时前
AI日报 - 2025年04月28日
人工智能
annus mirabilis1 小时前
解析Suna:全球首款开源通用AI智能体
人工智能·开源·suna
riveting2 小时前
SD2351核心板:重构AI视觉产业价值链的“超级节点”
大数据·linux·图像处理·人工智能·重构·智能硬件
Lilith的AI学习日记2 小时前
大语言模型中的幻觉现象深度解析:原理、评估与缓解策略
人工智能·语言模型·自然语言处理·aigc·ai编程