基于ONNX的Web端YOLOv8模型部署与推理

之前的一些实验,主要集中在模型的搭建、训练和调优上,没有涉及部署的环节,所以这次尝试将模型部署到端侧设备(浏览器),来熟悉一下部署的流程。

整个过程大致包含如下步骤:

  • YOLOv8的自定义数据集训练
  • YOLOv8模型导出onnx格式
  • 搭配onnxruntime的数据预处理
  • 搭配onnxruntime的数据后处理
  • 基于onnxruntime的模型量化
  • 针对CPU (wasm) execution provider的性能优化措施

最终效果: 让模型在PC端浏览器中运行,通过调用端侧摄像头,实现了目标检测的功能。

项目完整代码参考: 这个仓库

demo地址: Paper, Rock, Scissors Webcam Detection

arduino 复制代码
建议:
1.使用PC端浏览器开启demo,手机端算力资源有限,推理速度较慢
2.需要给予浏览器摄像头权限
3.由于部署在vercel海外节点,对于大文件(模型)的传输,速度不太稳定,如超过2min仍在loading,可以刷新再试

一、前期准备

1.数据集

考虑到最终效果需要直观、清晰、易理解,这次选择了roboflow上的石头剪刀布目标检测数据集:

roboflow官方已经对该数据集做了几个版本的迭代,并且添加了预处理和数据增强(如下图)

我下载的是v14版本,包含3个类别(Rock、Paper、Scissors),7335张图片(训练集6445、验证集576、测试集304),图片尺寸统一reshape成了640x640,虽然图片数量很多,但每张只有几十kb,所以整个数据集也就238M。

2.端侧设备与预训练模型选择

端侧设备:这次选择了浏览器端来部署,一是因为自己对js比较熟悉(老本行嘛),二是浏览器端有其天然优势:

  • 更快:直接把模型放在浏览器里运行,省去了请求后端的时间
  • 更安全:因为模型下载到本地,即使离线也可运行,保证了一定的数据隐私
  • 更便宜:直接使用客户端资源做推理,省去了租赁云端算力的开销

预训练模型选择:考虑到浏览器加载模型的耗时,以及任务本身的精度要求,选择YOLOv8中最小的模型YOLOv8n比较合适

二、模型训练

训练方法可以参考我之前的基于YOLOv8的菜品检测实验,只是数据集换成了石头剪刀布,去掉了超参数调优的环节。

使用yolov8n模型,跑100轮,通过wandb记录训练过程,核心notebook代码如下:

python 复制代码
model = YOLO('yolov8n.pt')
results = model.train(project="yolov8_rps", data='./datasets/data.yaml', epochs=100, imgsz=640)

最终训练结果如下,除了mAP50-95低了点,其他都在90以上:

使用训练好的模型,跑一下视频推理(如下),效果还可以

三、模型导出onnx格式

1.什么是ONNX和ONNX Runtime

ONNX是 Open Neural Network Exchange 的缩写,是一种用于表示深度学习模型的开放格式。它使得不同的深度学习框架(如 TensorFlow、PyTorch、Caffe2 等)可以相互转换和共享模型。

ONNX Runtime 是由微软开发的一款推理框架,用于在多种运行后端(如 CPU、GPU、TensorRT、DML 等)上运行 ONNX 格式的模型 。ONNX Runtime支持跨平台、高性能的深度学习推理,可以加载和运行 ONNX 格式的模型 。它可以与各种深度学习框架无缝集成,提供了简单易用的 API,并支持多种编程语言和平台

2.导出onnx格式

导出可以直接使用YOLOv8官方封装好的export方法,支持多种格式导出,其中就包括ONNX

具体参数选项可参考官方文档,运行如下代码,导出为ONNX格式:

python 复制代码
from ultralytics import YOLO

model = YOLO('./rps_best.pt') # 加载之前训练好的模型
model.export(format='onnx')

控制台输出如下:

shell 复制代码
Ultralytics YOLOv8.1.45 🚀 Python-3.10.12 torch-2.2.1+cu121 CPU (Intel Xeon 2.20GHz)
Model summary (fused): 168 layers, 3006233 parameters, 0 gradients, 8.1 GFLOPs

PyTorch: starting from 'rps_best.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 7, 8400) (6.0 MB)
requirements: Ultralytics requirement ['onnx>=1.12.0'] not found, attempting AutoUpdate...
Collecting onnx>=1.12.0
  Downloading onnx-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 15.9/15.9 MB 41.0 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.12.0) (1.25.2)
Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx>=1.12.0) (3.20.3)
Installing collected packages: onnx
Successfully installed onnx-1.16.0

requirements: AutoUpdate success ✅ 18.0s, installed 1 package: ['onnx>=1.12.0']
requirements: ⚠️ Restart runtime or rerun command for updates to take effect


ONNX: starting export with onnx 1.16.0 opset 17...
ONNX: export success ✅ 20.5s, saved as 'rps_best.onnx' (11.7 MB)

Export complete (25.5s)
Results saved to /content
Predict:         yolo predict task=detect model=rps_best.onnx imgsz=640  
Validate:        yolo val task=detect model=rps_best.onnx imgsz=640 data=./datasets/data.yaml  
Visualize:       https://netron.app

'rps_best.onnx'

导出的模型大概有12M,输出信息中有一个opset 17比较重要,表示对模型算子的支持版本,较新的版本支持更多算子,较老的则对平台兼容性更好,详情可参考官方说明

四、搭配onnxruntime的数据预处理

1.YOLOv8和onnxruntime在使用上的区别

对于目标检测模型(比如YOLOv8)的推理过程,通常是这样的:

  • (1)读取图片
  • (2)对图片数据做预处理,以符合模型input层的格式
  • (3)模型加载处理后的图片数据
  • (4)模型输出
  • (5)解析模型输出,将相关信息绘制到原始图片上

比如YOLOv8官方已经将上述步骤都封装在了predict方法里,直接调用就行。但是onnxruntime不同,因为它需要支持多种深度学习框架、多个平台、各种不同的模型,它实现的是一个通用的API,只包含上述过程的(3)和(4),其他需要我们自己来实现。

2.python版的数据预处理

先介绍python的实现,因为python在对数据张量的处理方面,有很多高效的库,能让整个流程看起来更清晰易懂。

安装并导入onnxruntime

python 复制代码
!pip install onnxruntime
import onnxruntime as ort

加载onnx模型并实例化

python 复制代码
ort_model = ort.InferenceSession('./rps_best.onnx')

查看模型input层shape

python 复制代码
for input in ort_model.get_inputs():
    print("input name: ", input.name)
    print("input shape: ", input.shape)
    print("input type: ", input.type)

打印结果如下:

python 复制代码
Name: images
Type: tensor(float)
Shape: [1, 3, 640, 640]

输入为四维浮点数据,包含一张3通道(RGB)的图片,尺寸640x640,像下面这样:

以单张图片为例,使用PIL调整图片尺寸

python 复制代码
!pip install pillow
from PIL import Image

img = Image.open("test.jpg")
img_width, img_height = img.size
img = img.resize((640,640))

删除alpha通道

python 复制代码
img = img.convert("RGB")

将图片数据转换为浮点数矩阵

python 复制代码
import numpy as np

input = np.array(img)

查看当前数据的shape

python 复制代码
input.shape
output:(640, 640, 3)

由于我们需要的是(3, 640, 640),得做一下转换

python 复制代码
input = input.transpose(2,0,1)

input.shape
output:(3, 640, 640)

由于需要的格式是四维的,还要做一下升维

python 复制代码
input = input.reshape(1,3,640,640) #或者用expand_dims

input.shape
output:(1, 3, 640, 640)

shape对了,还要对数据做归一化

python 复制代码
input = input/255.0

input[0,0,0,0]
output: 0.9137254901960784

至此,python版的数据预处理完成

3.javascript版的数据预处理

首先,通过canvas的getImageData获取图片的像素数组

javascript 复制代码
function prepare_input(img) {
  const canvas = document.createElement("canvas");
  canvas.width = 640;
  canvas.height = 640;
  const context = canvas.getContext("2d");
  context.drawImage(img, 0, 0, 640, 640);

  const data = context.getImageData(0, 0, 640, 640).data;
}

此时拿到的data是一个一维数组,需要按下图做转换:

python 复制代码
  const red = [],
  green = [],
  blue = [];
  
  for (let index = 0; index < data.length; index += 4) {
    red.push(data[index] / 255);
    green.push(data[index + 1] / 255);
    blue.push(data[index + 2] / 255);
  }
  return [...red, ...green, ...blue];

javascript版的数据预处理就是这些

五、搭配onnxruntime的数据后处理

1.python版的数据后处理

先查看模型输出的shape

python 复制代码
outputs = ort_model.get_outputs()
output = outputs[0]
print("Name:",output.name)
print("Type:",output.type)
print("Shape:",output.shape)

打印结果如下:模型返回一个三维浮点数组,代表一张图片上有8400个检测框(8400是YOLOv8可以检测的最大边界框数量,并且无论实际检测到多少个对象,它都会为任何图像返回 8400 行),每个框有4个坐标点信息+3个类别置信度

shell 复制代码
Name: output0
Type: tensor(float)
Shape: [1, 7, 8400]

可以用模型跑一下推理,看一下实际的输出

python 复制代码
input = input.astype(np.float32) # input代表预处理后的数据,这里先转成单精度浮点
outputs = ort_model.run(["output0"], {"images":input})
output = outputs[0]
output.shape # (1, 7, 8400)

取出其中的二维数据,做转置,方便后续处理

python 复制代码
output = output[0]
output.shape # (7, 8400)
output = output.transpose() # 转置
output.shape # (8400, 7)

这里的(8400, 7)代表有8400行7列,可以看下第一行的数据

python 复制代码
row = output[0]
print(row) 
# [     14.623      22.475      29.157      45.249  1.4901e-07  3.2783e-07  5.9605e-08]

前4项代表检测框的xcenter、ycenter、width、height,后面3项代表'Paper', 'Rock', 'Scissors'的概率,这里需要将坐标信息转换为x1y1(矩形框左上角),x2y2(矩形框右下角)形式,方便画图,还要获取最大概率值和对应的类别

python 复制代码
yolo_classes = ['Paper', 'Rock', 'Scissors']

def parse_row(row):
    xc,yc,w,h = row[:4]
    x1 = (xc-w/2)/640*img_width
    y1 = (yc-h/2)/640*img_height
    x2 = (xc+w/2)/640*img_width
    y2 = (yc+h/2)/640*img_height
    prob = row[4:].max()
    class_id = row[4:].argmax()
    label = yolo_classes[class_id]
    return [x1,y1,x2,y2,label,prob]

对于模型返回的概率较低的检测框我们需要过滤掉,这里设置丢弃概率<0.5的框

python 复制代码
boxes = [row for row in [parse_row(row) for row in output] if row[5]>0.5]

len(boxes) # 20

还有20个框,将这20个框绘制到测试图片上,可以发现它们很多都是重叠的

python 复制代码
from PIL import ImageDraw
img = Image.open("test.jpg")
draw = ImageDraw.Draw(img)

for box in boxes:
    x1,y1,x2,y2,class_id,prob = box
    draw.rectangle((x1,y1,x2,y2),None,"#00ff00")

这时候就需要使用NMS算法做处理

2.IoU与NMS的实现

IoU:可以使用IoU(Intersection over Union,交并比)来判断检测框定位的好坏。所谓交并比,是指预测边框与实际边框的交集和并集的比率,取值范围为0~1,越接近1越好 代码实现:

python 复制代码
def intersection(box1,box2):
    box1_x1,box1_y1,box1_x2,box1_y2 = box1[:4]
    box2_x1,box2_y1,box2_x2,box2_y2 = box2[:4]
    x1 = max(box1_x1,box2_x1)
    y1 = max(box1_y1,box2_y1)
    x2 = min(box1_x2,box2_x2)
    y2 = min(box1_y2,box2_y2)
    return (x2-x1)*(y2-y1)
    
def union(box1,box2):
    box1_x1,box1_y1,box1_x2,box1_y2 = box1[:4]
    box2_x1,box2_y1,box2_x2,box2_y2 = box2[:4]
    box1_area = (box1_x2-box1_x1)*(box1_y2-box1_y1)
    box2_area = (box2_x2-box2_x1)*(box2_y2-box2_y1)
    return box1_area + box2_area - intersection(box1,box2)
    
def iou(box1,box2):
    return intersection(box1,box2)/union(box1,box2) 

NMS:预测结果中,可能多个预测结果间存在重叠部分,需要保留交并比(IoU)最大的、去掉非最大的预测结果,这就是非极大值抑制(Non-Maximum Suppression,简写作NMS)

NMS的算法步骤如下:

  • 将所有框放入队列中
  • 先找到置信度最高的框(假设为A)
  • 将A放入结果数组中
  • 依次计算其他框与A的IoU值
  • 如果某个框(假设为B)的IoU大于给定阈值(比如0.7),则认为B和A框定的是同一个物体,删除B
  • 循环上述步骤,直到队列中没有框了

代码实现:

python 复制代码
# NMS
boxes.sort(key=lambda x: x[5], reverse=True)

result = []

while len(boxes)>0:
    result.append(boxes[0])
    boxes = [box for box in boxes if iou(box,boxes[0])<0.7] #<0.7则不是同一物体,要保留

经过NMS处理后,result中只剩2个框了

python 复制代码
print(result)

[[316.81634521484375,
  242.34803009033203,
  517.7810668945312,
  431.6143112182617,
  'Scissors',
  0.9387282],
 [69.72752380371094,
  226.46387481689453,
  205.6693878173828,
  435.36669158935547,
  'Rock',
  0.9296297]]

3.javascript版的数据后处理

计算IoU的实现

javascript 复制代码
function iou(box1, box2) {
  return intersection(box1, box2) / union(box1, box2);
}

function union(box1, box2) {
  const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
  const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
  const box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1);
  const box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1);
  return box1_area + box2_area - intersection(box1, box2);
}

function intersection(box1, box2) {
  const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
  const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
  const x1 = Math.max(box1_x1, box2_x1);
  const y1 = Math.max(box1_y1, box2_y1);
  const x2 = Math.min(box1_x2, box2_x2);
  const y2 = Math.min(box1_y2, box2_y2);
  return (x2 - x1) * (y2 - y1);
}

获取x1、y1、x2、y2、最大置信度和对应的类别,这里没有做行列的转置,而是直接使用了绝对索引来定位

javascript 复制代码
for (let index = 0; index < 8400; index++) {
    const [class_id, prob] = [...Array(yolo_classes.length).keys()]
      .map((col) => [col, output[8400 * (col + 4) + index]])
      .reduce((accum, item) => (item[1] > accum[1] ? item : accum), [0, 0]);
    if (prob < 0.5) {
      continue;
    }
    const label = yolo_classes[class_id];
    const xc = output[index];
    const yc = output[8400 + index];
    const w = output[2 * 8400 + index];
    const h = output[3 * 8400 + index];
    const x1 = ((xc - w / 2) / 640) * img_width;
    const y1 = ((yc - h / 2) / 640) * img_height;
    const x2 = ((xc + w / 2) / 640) * img_width;
    const y2 = ((yc + h / 2) / 640) * img_height;
    boxes.push([x1, y1, x2, y2, label, prob]);
}

NMS算法实现:

javascript 复制代码
let boxes = [];
boxes = boxes.sort((box1, box2) => box2[5] - box1[5]);

const result = [];
while (boxes.length > 0) {
    result.push(boxes[0]);
    boxes = boxes.filter((box) => iou(boxes[0], box) < 0.7);
}

六、模型部署与推理

在浏览器上部署,用到的是onnxruntime-web这个library,它可以调用端侧的cpu(wasm)、webgl或webGPU来执行模型推理。我这里使用的是cpu来执行,因为兼容性比较好。

onnx官方推荐把模型推理这部分代码,放到web worker中执行,因为是cpu密集型操作,可以有效防止阻塞主线程,具体参考官方文档

worker.js的代码如下

javascript 复制代码
importScripts(
  "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/ort.min.js",
);

let model = null;
ort.InferenceSession.create("./rps_best.onnx", {
  executionProviders: ["wasm"],
  graphOptimizationLevel: "all",
}).then((res) => {
  model = res;
  console.log("model", model);
  postMessage({ type: "modelLoaded" });
});

async function run_model(input) {
  if (!model) {
    model = await model;
  }
  input = new ort.Tensor(Float32Array.from(input), [1, 3, 640, 640]);
  const outputs = await model.run({ images: input });
  return outputs["output0"].data;
}

onmessage = async (event) => {
  const { input, startTime } = event.data;
  const output = await run_model(input);
  postMessage({ type: "modelResult", result: output, startTime });
};

绘制最终检测框的代码如下:

javascript 复制代码
function draw_boxes(canvas, boxes) {
  const ctx = canvas.getContext("2d");
  ctx.strokeStyle = "#00FF00";
  ctx.lineWidth = 3;
  ctx.font = "18px serif";
  boxes.forEach(([x1, y1, x2, y2, label]) => {
    ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);
    ctx.fillStyle = "#00ff00";
    const width = ctx.measureText(label).width;
    ctx.fillRect(x1, y1, width + 10, 25);
    ctx.fillStyle = "#000000";
    ctx.fillText(label, x1, y1 + 18);
  });

  // 绘制 Infer count 和 Average infer time
  ctx.font = "16px Arial";
  ctx.fillStyle = "black";
  ctx.fillText(`Infer count: ${inferCount}`, 10, 20);
  ctx.fillText(
    `Average infer time: ${
      inferCount ? parseInt(totalInferTime / inferCount) : 0
    } ms`,
    10,
    40,
  );
}

刚开始运行,除了会加载模型文件(12.2M),还会加载一个wasm library(2.6M)作为backend来调用cpu资源

在我的 Mac M1 Chrome浏览器上,平均推理一次差不多要600ms(如下图)

后续主要从两个方向着手优化:

  • 减小模型文件尺寸,提升加载速度
  • 挖掘端侧性能,从而降低推理耗时

七、优化

1.模型量化

模型量化是指将模型中的浮点权重和激活值转换为低精度整数类型,例如 int8 或 int16 的过程,这可以显著减小模型的大小和内存占用

我这里使用了动态uint8量化,在量化前还需要做一次preprocess,其中包含了对model的optimize,具体可参考官方文档

python 复制代码
!pip install onnxruntime

# https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md
!python -m onnxruntime.quantization.preprocess --input rps_best.onnx --output rps_best_infer.onnx

from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType
model_fp32 = './rps_best_infer.onnx'
model_quant = './rps_best_uint8.onnx'
quantized_model = quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)

量化后的模型,大小从12.2M,减少到3.3M,体积减小了四分之一

2.使用SIMD+多线程加速wasm backend

SIMD 是 Single Instruction, Multiple Data 的缩写,指的是单指令多数据。SIMD 指令可以同时对多个数据进行操作,从而提高数据处理速度

多线程可以有效利用多核CPU资源,从而提高性能

onnxruntime-web的SIMD是默认开启的,而多线程需要服务端返回COOP/COEP响应头,以在浏览器端启用跨域隔离(如下图),详情参考这里

由于我使用的vercel部署,直接在vercel.json中添加配置即可让vercel返回相关http头

json 复制代码
{
  "headers": [
    {
      "source": "/(.*)",
      "headers": [
        {
          "key": "Cross-Origin-Embedder-Policy",
          "value": "require-corp"
        },
        {
          "key": "Cross-Origin-Opener-Policy",
          "value": "same-origin"
        }
      ]
    }
  ]
}

启用wasm多线程后,在PC端Chrome上的平均推理时间,从600ms左右减少到269ms

八、有待改进的地方

1.最新版onnx支持int4量化,可以进一步压缩模型体积,减少通过网络传输的数据量

2.尝试将模型资源部署到国内节点,提升传输速度和稳定性

3.在手机浏览器上测试,平均推理时间需要1600ms~2000ms,不确定是手机计算资源孱弱,还是某些加速功能未开启,有待验证

4.在尝试使用WebGL和WebGPU provider时,遇到了一些问题,还未解决,如果能利用端侧设备的GPU资源,则可以更好得加速模型推理

5.可以尝试onnx以外的解决方案,也许会有更好的性能

版权声明:本文为博主原创文章,转载请注明作者和出处

参考:

使用 SIMD 和多线程增强 TensorFlow.js WebAssembly 后端

关于启用跨域隔离的指南

AI模型部署 | onnxruntime部署RT-DETR目标检测模型

How to create YOLOv8-based object detection web service using Python, Julia, Node.js, JavaScript, Go and Rust

ONNX Web官方文档

How to detect objects in videos in a web browser using YOLOv8 neural network and JavaScript

YOLOv8初体验:检测、跟踪、模型部署

相关推荐
Halo咯咯1 分钟前
无限免费 OpenClaw:接入本地模型后,你的 AI Agent 就可以 24 小时自动干活(Mac Mini 可用)
人工智能
NAGNIP12 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab13 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab13 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP17 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年17 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼17 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS17 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区18 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈19 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能