1. 前言
上一篇博客留了一个坑,打算尝试用Rust代替Go来调用链接库试试效果。不过由于这次在家只有一个Mac无法使用TensorRT来加速推理,因此这次的推理框架选择的是MNN。本来想用TNN来着,不过居然编译一直有问题,所以果断转MNN。大致内容和流程和上一篇差不多,首先撸一套MNN的推理代码,然后编译成动态库;Rust调用这个动态库并写个简单的api测试一下。
2. 开始
2.1 C++部分
这次就直接跳过模型讲解(Yolov8基本都被讲烂了),直接用官方的export demo导出onnx。然后编译MNN,最好把转换工具和cv也一起编译了,有的时候MNN自带的cv还挺好用的。
将onnx转为mnn文件,MNNConvert -f ONNX --modelFile yolov8n.onnx --MNNModel yolov8n.mnn --fp16 --optimizePrefer 2
,这里用fp16存储并且弄个加速优化,具体参数可以看官网介绍,MNN的文档写的还是挺详细的。
然后根据文档教程,用Session模式进行推理。需要注意的是,MNN的模型构建除了加载模型之外,还需要设置线程数,推理后端设置以及精度设置,这些在构造函数的时候设置一下就行,我这里直接利用参数配置初始化构造函数。
ini
struct MNNDetConfigs
{
std::string mnnPath;
cv::Size inputSize{640,640};
int classNum = 80 ;
int threads = 4;
int forwardType = MNN_FORWARD_CPU;
float scoreThr = 0.25;
float nmsThr = 0.5;
std::array<float, 3> means = {0.0f, 0.0f, 0.0f};
std::array<float, 3> norms = {1.0f, 1.0f, 1.0f};
};
这里的推理后端有很多种,不过我编译的在Mac上只有CPU、Metal和Auto是可以用的,分别对应0,1,4
前、后处理部分还是letterbox和nms,推理部分其实各个框架差不多,将输入输出设置好之后通过Session得到结果。
scss
cv::Mat prImg;
preProcess(img,prImg);
interpreter_->resizeTensor(inputTensor_, {1, 3, inputH, inputW});
interpreter_->resizeSession(session_);
pretreat_->convert(prImg.data, inputW, inputH, prImg.step[0], inputTensor_);
interpreter_->runSession(session_);
MNN::Tensor* outputTensor = interpreter_->getSessionOutput(session_, NULL);
MNN::Tensor outputTensorHost(outputTensor, outputTensor->getDimensionType());
outputTensor->copyToHostTensor(&outputTensorHost);
大致C++上的推理部分就实现了,来看看效果
推理耗时大概在20ms左右,也没做啥对比实验不知道其他框架的性能如何,以后有空再测测其他的。
2.2 Rust调用
先看一下动态库的头文件定义
arduino
//
// Created by shelgi on 2023/12/30.
//
#include <stdio.h>
#include <stdlib.h>
#ifndef YOLO_MNN_RUSTWRAPPER_H
#define YOLO_MNN_RUSTWRAPPER_H
#ifdef __cplusplus
extern "C"{
#endif
extern const char* detect(void* model,char* base64Img);
extern void* YOLO(const char* mnnPath,int classNums,float scoreThresh,float nmsThresh,int forwardType,int numThreads);
extern void release(void* model);
#ifdef __cplusplus
};
#endif
#endif //YOLO_MNN_RUSTWRAPPER_H
和之前Go的那个基本一致,不过这里我把结果图片也转为base64作为结果传出来了。在C++动态库编译那里也可以隐藏一些不想被看到的函数,只传出这三个函数接口就行。
然后可以通过bindgen
去根据头文件生成bindings,也可以自己写。不过第一次不太熟悉Rust对于C的类型定义,还是用工具生成吧。
bindgen rustwrapper.h -o bindings.rs
这样就生成了一个bindings.rs的文件,里面包含很多定义,直接滑到最下面找到这三个函数的定义,其他的可以删掉。
rust
extern "C" {
pub fn detect(
model: *mut ::std::os::raw::c_void,
base64Img: *mut ::std::os::raw::c_char,
) -> *mut *const ::std::os::raw::c_char;
}
extern "C" {
pub fn YOLO(
mnnPath: *const ::std::os::raw::c_char,
classNums: ::std::os::raw::c_int,
scoreThresh: f32,
nmsThresh: f32,
forwardType: ::std::os::raw::c_int,
numThreads: ::std::os::raw::c_int,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
pub fn release(model: *mut ::std::os::raw::c_void);
}
然后根据这个就可以开始写相关的功能函数了,比如图片转base64,json解析等;不过后续发现json解析没啥用,可以把结果字符串在前端直接展示,无非就是丑点😂。
这次Rust的后端部分,一开始想用axum不过一想既然都不了解就随便选吧。恰好看到有个评论推荐了salvo,那就看看文档直接拿来用(其实挺香的,国人开源必须支持 )。根据文件上传的案例直接魔改,用Tera渲染一下检测界面的模板。不得不说,Rust真的挺有意思的,几个陌生的框架上手会在编译期解决大部分问题,运行时确实很少出问题(昨晚一个C++导出char**在Rust调用的时候运行时二次释放除外)。经过一顿瞎操作,得到了一个勉强能用的玩意儿,效果如下
平均检测时间也就20ms左右,后期再去研究一下Rust相关的生态,看看有啥好用的框架。
结尾
这一篇虽说是填坑,不过确实内容很水(工作量不小,Rust里面就c_char转换成String第一次上手的时候都有点懵,但是感觉没啥好讲的)。本来服务方面想试试leptos,不过看了一下文档没看太明白(不过沙盒实时演示很好),加上这几天大部分时间都在医院,也就放弃这个想法,弄了个拿来即用的。等后续空闲一定会整理一下,弄个像样点的分享一下。不知不觉一年又过去马上快元旦了,祝大家元旦快乐!!!