Onnx使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上

目录

一、整体功能概述

二、函数分析

[2.1 resnet() 函数:](#2.1 resnet() 函数:)

[2.2 pre_process(img_path) 函数:](#2.2 pre_process(img_path) 函数:)

[2.3 loadOnnx(img_path) 函数:](#2.3 loadOnnx(img_path) 函数:)

三、代码执行流程


一、整体功能概述

这段代码实现了一个图像分类系统,使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上。它包括以下主要步骤:

读取一个包含类别名称和对应编号的文本文件,并将其存储在字典中。

定义了几个函数,包括模型导出函数 resnet()、图像预处理函数 pre_process() 和加载 ONNX 模型进行分类的函数 loadOnnx()。

在主程序中,指定输入图像路径,调用 loadOnnx() 函数对图像进行分类并显示结果。

二、函数分析

2.1 resnet() 函数:

使用 torchvision 中的预训练 ResNet18 模型,并设置为评估模式。

生成一个随机输入张量 x,并将模型导出为 ONNX 格式,保存为 models/resnet18.onnx 文件。

复制代码
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])

2.2 pre_process(img_path) 函数:

读取输入图像 img_path。

调整图像大小为 224x224。

将图像颜色通道从 BGR 转换为 RGB。

对图像像素值进行归一化处理。

交换图像维度顺序,并增加一个维度。

返回预处理后的图像张量。

复制代码
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img

2.3 loadOnnx(img_path) 函数:

创建一个 ONNX 推理会话,加载预导出的 ResNet18 ONNX 模型。

调用 pre_process() 函数对输入图像进行预处理。

准备输入数据并进行推理。

获取推理结果中概率最大的类别编号。

根据类别编号从字典中获取对应的类别名称,并进行翻译。

在输入图像上显示分类结果,并展示图像。

复制代码
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)

完整代码如下

复制代码
import cv2
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights
import onnxruntime as ort
from translate import Translator
translator=Translator(to_lang='Chinese')#翻译成中文
dict={}
with open('类别.txt','r',encoding='utf-8') as f:
    lines=f.readlines()
    for line in lines:
        name=line.split('\t')[0]
        value=line.split('\t')[1]
        dict[name]=value
# print(dict)
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img
    #RGB
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)
    pass
if __name__ == '__main__':
    img_path='dog.png'
    # resnet()#导出模型
    loadOnnx(img_path)

三、代码执行流程

在 if name == 'main': 部分:

定义输入图像路径 img_path。

可以选择调用 resnet() 函数导出模型(注释状态,通常只在第一次运行或模型更新时使用)。

调用 loadOnnx(img_path) 函数对输入图像进行分类和显示结果。

相关推荐
JoannaJuanCV8 分钟前
自动驾驶—CARLA仿真(19)automatic_control demo
人工智能·机器学习·自动驾驶
热爱生活的五柒12 分钟前
PolSAR Image Registration——极化合成孔径雷达(PolSAR)图像配准
人工智能·计算机视觉·sar
qq_2337727113 分钟前
**给复杂机器“装上行车记录仪”:一篇量子论文如何照亮AI时代的信任之路**
人工智能
美林数据Tempodata13 分钟前
案例分享|西安财经大学打造全覆盖、全链条人工智能通识教育培养体系
人工智能
O561 6O623O7 安徽正华露16 分钟前
露,生物信号采集处理系统一体机 生物机能实验系统 生物信号采集处理系统 生理机能实验
人工智能
AI营销快线20 分钟前
原圈科技如何引领AI营销内容生产升级:行业进化路线与闭环创新洞察
人工智能
AI营销先锋23 分钟前
2025 AI市场舆情分析行业报告:原圈科技如何帮助企业穿越迷雾,寻找增长北极星
大数据·人工智能
找方案26 分钟前
hello-agents 学习笔记:智能体发展史 —— 从符号逻辑到 AI 协作的进化之旅
人工智能·笔记·学习·智能体·hello-agents
skywalk816329 分钟前
Auto-Coder用Qwen3-Coder-30B-A3B-Instruct模型写一个学习汉字的项目
人工智能·学习·auto-coder
Alluxio42 分钟前
Alluxio正式登陆Oracle云市场,为AI工作负载提供TB级吞吐量与亚毫秒级延迟
人工智能·分布式·机器学习·缓存·ai·oracle