在微信小程序部署AI模型的几种方法

前言

本文只是分享思路,不提供可完整运行的项目代码

onnx部署

以目标检测类模型为例,该类模型会输出类别信息,置信度,包含检测框的4个坐标信息

但不是所有的onnx模型都能在微信小程序部署,有些算子不支持,这种情况需要点特殊操作。

微信小程序提供的接口相当于使用onnxruntime的接口运行onnx模型,我们要做的就是将视频帧数据(包含RGBA的一维像素数组)转换成对应形状的数组(比如3*224*224的一维Float32Array),然后调用接口并将图像输入得到运行的结果(比如一个1*10*6的一维Float32Array,代表着10个预测框的类别,置信度和框的4个坐标),然后将结果处理(比如行人检测,给置信度设置一个阈值0.5,筛选置信度大于阈值的数组的index,然后按照index取出相应的类别和框坐标),最后在wxml中显示类别名或置信度或在canvas绘制框。

代码框架

这里采用的是实时帧数据,按预设频率调用一帧数据并后处理得到结果

初始化session

首先得将onnx上传至云端,获得一个存储路径(比如cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/rtdetrWorker.onnx)

当用户首次使用该小程序时,手机里没有onnx模型的存储,需要从云端下载;而已经非第一次使用该小程序的用户手机里已经保存了之前下载的onnx模型,就无需下载。所以此处代码逻辑是需要检测用户的存储里是否有该onnx模型,不存在就下载,下载完并保存模型文件后就执行下一步;存在就直接执行下一步。

javascript 复制代码
  InitSession()
  {
    return new Promise(resolve=>{
      const cloudPath = 'cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/mobilnet.onnx'
      const lastindex=cloudPath.lastIndexOf('/')
      const filename=cloudPath.substring(lastindex+1)
      const modelPath = `${wx.env.USER_DATA_PATH}/`+filename;
      // 判断之前是否已经下载过onnx模型
      wx.getFileSystemManager().access({
      path: modelPath,
      success: (res) =>
      {
        console.log("文件已经存在")
        // 创建session
        this.createInferenceSession(modelPath)
        // 监听帧,频率为1秒1次
        setInterval(this.oneFrame, 1000)
        resolve()
      },
      fail: (res) => {
        // 文件不存在
        console.error(res)
        wx.cloud.init();
        console.log("开始下载模型");
        // 调用自定义函数下载文件
        this.downloadFile(cloudPath, function(r) {
        console.log(`下载进度:${r.progress}%,已下载${r.totalBytesWritten}B,共${r.totalBytesExpectedToWrite}B`)
      }).then(result => {
          // 保存模型到本地
          wx.getFileSystemManager().saveFile({
            tempFilePath:result.tempFilePath,
            filePath: modelPath,
            success: (res) => { // 注册回调函数
              console.log(res)
              const modelPath = res.savedFilePath;
              console.log("保存模型到路径: " + modelPath)
              // 创建session
              this.createInferenceSession(modelPath)
              // 监听帧,频率为1秒1次
              setInterval(this.oneFrame, 1000)
              resolve()
            },
            fail(res) {
              console.error(res)
            }
          })
        });
      }
      })
    })

自定义的下载文件函数

javascript 复制代码
  downloadFile(fileID, onCall = () => {}) {
    return new Promise((resolve, reject) => {
      const task = wx.cloud.downloadFile({
        fileID,
        success: res => resolve(res),
      })
      task.onProgressUpdate((res) => {
        if (onCall(res) == false) {
          task.abort()
        }
      })
    })
  },

自定义创建session的函数

javascript 复制代码
  createInferenceSession(modelPath) {
    return new Promise((resolve, reject) => {
      this.session = wx.createInferenceSession({
        model: modelPath,
        precisionLevel : 4,
        allowNPU : false,
        allowQuantize: false,
      });

      // 监听error事件
      this.session.onError((error) => {
        console.error(error);
        reject(error);
      });
      this.session.onLoad(() => {
        resolve();
      });
    })
  },

自定义处理帧函数

就是上面初始化session步骤里面 创建session后 按预设频率执行的函数

开启相机监听,在回调函数内获取帧数据、处理帧数据、开始推理、关闭监听

javascript 复制代码
  oneFrame(){
    const context=wx.createCameraContext()
    const camCallback=(frame)=>{
      // 处理图片数据
      var dstInput=new Float32Array(this.data.imageChannel*this.data.imageWidth*this.data.imageHeight)
      this.preProcess(frame,dstInput)
      // 推理得到结果
      this.infer(dstInput)
      // 关闭监听
      listener.stop()
    }
    const listener=context.onCameraFrame(camCallback)
    listener.start()
  },

自定义的图像处理函数

该函数接收帧数据(RGBA一维数组)和在外面初始化的Float32Array数组,执行归一化、去除透明度通道。

javascript 复制代码
  preProcess(frame, dstInput) {
    return new Promise((resolve, reject) =>
    {
      const origData = new Uint8Array(frame.data);
      const hRatio = frame.height / this.data.imageHeight;
      const wRatio = frame.width / this.data.imageWidth;
      const origHStride = frame.width * 4;
      const origWStride = 4;
      const mean = [0.485, 0.456, 0.406]
      // Reverse of std = [0.229, 0.224, 0.225]
      const reverse_div = [4.367, 4.464, 4.444]
      const ratio = 1 / 255.0
      const normalized_div = [ratio / reverse_div[0], ratio * reverse_div[1], ratio * reverse_div[2]];
      const normalized_mean = [mean[0] * reverse_div[0], mean[1] * reverse_div[1], mean[2] * reverse_div[2]];
      var idx = 0;
      for (var c = 0; c < this.data.imageChannel; ++c)
      {
        for (var h = 0; h < this.data.imageHeight; ++h)
        {
          const origH = Math.round(h * hRatio);
          const origHOffset = origH * origHStride;
          for (var w = 0; w < this.data.imageWidth; ++w)
          {
            const origW = Math.round(w * wRatio);
            const origIndex = origHOffset + origW * origWStride + c;
            const val = origData[origIndex] * (normalized_div[c]) - normalized_mean[c];
            dstInput[idx] = val;
            idx++;
          }
        }
      } 
      resolve();
    });
  },

自定义的推理函数

推理接口接收数个键值对input,具体需要参照自己的onnx模型,在Netron查看相应的模型信息

我这里只有1个输入,对应的名字为"images",接收(1,3,300,300)性质的图像数组

我这里有2个输出,对应的名字是"794"和"output",分别对应相应类别的置信度(1*10*2)&框的坐标信息(1*10*4),这里的10对应10个预测框,2代表有2个类别

接着就是获取某一类别(比如前景)最大置信度的索引并取出其框的信息

然后绘制在canvas上

当然也可以设置阈值比如0.5,前景类别置信度大于0.5的就保留,然后根据得到的index取出框的信息,绘制到canvas上,或者只取类别和对应的置信度,根据自己的需求处理

javascript 复制代码
  infer(imgData){
    this.session.run({
      "images":{
        shape: [1, this.data.imageChannel, this.data.imageHeight, this.data.imageWidth],
        data: imgData.buffer,
        type: 'float32',
      }
    }).then((res)=>{
      let box = new Float32Array(res.output.data)
      let score = new Float32Array(res[794].data)
      // console.log(box)
      let num = new Float32Array(score)
      var maxVar = num[0];
      var index = 0;
      for (var i = 0; i < num.length; i+=2)
      {
        if (maxVar < num[i])
        {
            maxVar = num[i]   
            index = i/2   
        }
      }
      this.setData({
        xmin:box[index*4],
        xmax:box[index*4+2],
        ymin:box[index*4+1],
        ymax:box[index*4+3]
      })
      this.drawRectangle()
    })
  },

自定义的绘制框函数

这里用的是微信新的canvas接口

javascript 复制代码
  drawRectangle(){
    wx.createSelectorQuery().select('#myCanvas')
      .fields({node:true,size:true})
          .exec((res)=>{
            const canvas=res[0].node
            const ctx=canvas.getContext('2d')
            const dpr = wx.getSystemInfoSync().pixelRatio
            canvas.width = res[0].width * dpr
            canvas.height = res[0].height * dpr
            ctx.scale(dpr, dpr)
            ctx.strokeStyle='red'
            ctx.lineWidth=2
            console.log(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax)
            ctx.strokeRect(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax,canvas.width,canvas.height)
          })
  }

代码总览

index.js

javascript 复制代码
Page({
  session:null,
  data: {
    src : '',
    windowWidth:0,
    imageWidth : 300,
    imageHeight : 300,
    imageChannel : 3,
    xmin:0,
    ymin:0,
    xmax:0,
    ymax:0
  },
  onLoad(){
    this.setData({
      windowWidth:wx.getSystemInfoSync().windowWidth*0.9
    })
    this.InitSession()
  },
  oneFrame(){
    const context=wx.createCameraContext()
    const camCallback=(frame)=>{
      // 处理图片数据
      var dstInput=new Float32Array(this.data.imageChannel*this.data.imageWidth*this.data.imageHeight)
      this.preProcess(frame,dstInput)
      // 推理得到结果
      this.infer(dstInput)
      // 关闭监听
      listener.stop()
    }
    const listener=context.onCameraFrame(camCallback)
    listener.start()
  },
  downloadFile(fileID, onCall = () => {}) {
    return new Promise((resolve, reject) => {
      const task = wx.cloud.downloadFile({
        fileID,
        success: res => resolve(res),
      })
      task.onProgressUpdate((res) => {
        if (onCall(res) == false) {
          task.abort()
        }
      })
    })
  },
  preProcess(frame, dstInput) {
    return new Promise((resolve, reject) =>
    {
      const origData = new Uint8Array(frame.data);
      const hRatio = frame.height / this.data.imageHeight;
      const wRatio = frame.width / this.data.imageWidth;
      const origHStride = frame.width * 4;
      const origWStride = 4;
      const mean = [0.485, 0.456, 0.406]
      // Reverse of std = [0.229, 0.224, 0.225]
      const reverse_div = [4.367, 4.464, 4.444]
      const ratio = 1 / 255.0
      const normalized_div = [ratio / reverse_div[0], ratio * reverse_div[1], ratio * reverse_div[2]];
      const normalized_mean = [mean[0] * reverse_div[0], mean[1] * reverse_div[1], mean[2] * reverse_div[2]];
      var idx = 0;
      for (var c = 0; c < this.data.imageChannel; ++c)
      {
        for (var h = 0; h < this.data.imageHeight; ++h)
        {
          const origH = Math.round(h * hRatio);
          const origHOffset = origH * origHStride;
          for (var w = 0; w < this.data.imageWidth; ++w)
          {
            const origW = Math.round(w * wRatio);
            const origIndex = origHOffset + origW * origWStride + c;
            const val = origData[origIndex] * (normalized_div[c]) - normalized_mean[c];
            dstInput[idx] = val;
            idx++;
          }
        }
      } 
      resolve();
    });
  },
  infer(imgData){
    this.session.run({
      "images":{
        shape: [1, this.data.imageChannel, this.data.imageHeight, this.data.imageWidth],
        data: imgData.buffer,
        type: 'float32',
      }
    }).then((res)=>{
      let box = new Float32Array(res.output.data)
      let score = new Float32Array(res[794].data)
      // console.log(box)
      let num = new Float32Array(score)
      var maxVar = num[0];
      var index = 0;
      for (var i = 0; i < num.length; i+=2)
      {
        if (maxVar < num[i])
        {
            maxVar = num[i]   
            index = i/2   
        }
      }
      this.setData({
        xmin:box[index*4],
        xmax:box[index*4+2],
        ymin:box[index*4+1],
        ymax:box[index*4+3]
      })
      this.drawRectangle()
    })
  },
  InitSession()
  {
    return new Promise(resolve=>{
      const cloudPath = 'cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/mobilnet.onnx'
      const lastindex=cloudPath.lastIndexOf('/')
      const filename=cloudPath.substring(lastindex+1)
      const modelPath = `${wx.env.USER_DATA_PATH}/`+filename;
      // 判断之前是否已经下载过onnx模型
      wx.getFileSystemManager().access({
      path: modelPath,
      success: (res) =>
      {
        console.log("file already exist at: " + modelPath)
        this.createInferenceSession(modelPath)
        setInterval(this.oneFrame, 1000)
        resolve()
      },
      fail: (res) => {
        console.error(res)
        wx.cloud.init();
        console.log("begin download model");
        this.downloadFile(cloudPath, function(r) {
        console.log(`下载进度:${r.progress}%,已下载${r.totalBytesWritten}B,共${r.totalBytesExpectedToWrite}B`)
      }).then(result => {
          wx.getFileSystemManager().saveFile({
            tempFilePath:result.tempFilePath,
            filePath: modelPath,
            success: (res) => { // 注册回调函数
              console.log(res)
              const modelPath = res.savedFilePath;
              console.log("save onnx model at path: " + modelPath)
              this.createInferenceSession(modelPath)
              setInterval(this.oneFrame, 1000)
              resolve()
            },
            fail(res) {
              console.error(res)
            }
          })
        });
      }
      })
    })
  },
  createInferenceSession(modelPath) {
    return new Promise((resolve, reject) => {
      this.session = wx.createInferenceSession({
        model: modelPath,
        precisionLevel : 4,
        allowNPU : false,
        allowQuantize: false,
      });

      // 监听error事件
      this.session.onError((error) => {
        console.error(error);
        reject(error);
      });
      this.session.onLoad(() => {
        resolve();
      });
    })
  },
  drawRectangle(){
    wx.createSelectorQuery().select('#myCanvas')
      .fields({node:true,size:true})
          .exec((res)=>{
            const canvas=res[0].node
            const ctx=canvas.getContext('2d')
            const dpr = wx.getSystemInfoSync().pixelRatio
            canvas.width = res[0].width * dpr
            canvas.height = res[0].height * dpr
            ctx.scale(dpr, dpr)
            ctx.strokeStyle='red'
            ctx.lineWidth=2
            console.log(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax)
            ctx.strokeRect(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax,canvas.width,canvas.height)
          })
  }
})

index.wxss

javascript 复制代码
.c1{
  width: 100%;
  align-items: center;
  text-align: center;
  display: flex;
  flex-direction: column;
}
.camera{
  width: 100%;
}
#myCanvas{
  width: 100%;
  height: 100%;
}

index.wxml

javascript 复制代码
<view class="c1">
<camera class="camera" binderror="error" mode="normal" style="width: 90%; height: {{windowWidth}}px;">
  <canvas id="myCanvas" type="2d"></canvas>
</camera>
</view> 

flask部署

微信小程序负责把图像数据或帧数据传到服务器,在服务器用falsk搭建相关模型运行环境,将接收到的图像数据或帧数据预处理后输入模型里,在将结果返回给微信小程序,微信小程序再显示结果。

我这里给的例子是传送帧数据的,也就是实时检测。

前端

在前端,获得帧数据后,因为帧数据的格式是一维RGBA数组,为了将其转成png,方便服务器处理,把帧数据绘制到画布上,再导出为png送入服务器。接收到服务器的结果后,将检测框绘制到相机的界面,需要在<camera>标签里加上<canvas>标签,然后画上矩形框,并在下方显示分类结果。

主体代码框架

javascript 复制代码
Page({
  data: {
    windowWidth:wx.getSystemInfoSync().windowWidth*1.33,
    boxNum:'',
  },
  // 自定义实时检测的频率,这里是800ms检测一次
  // http://t.csdnimg.cn/rLLLw 具体见此地址
  onLoad(){
    setInterval(this.oneProcessFrame, 800);
  },
})
javascript 复制代码
  oneProcessFrame(){
    const context = wx.createCameraContext();
    const data={"pngData":null}
    const CamFramCall = (frame)=>{
      // 调整显示页面的相机画面,为了使显示页面的横宽比等于frame数据的横宽比
      // 在画框的时候,模型跑出来的检测框坐标是相对于输入的图像的大小
      // 如果显示画面和输入框的比例不匹配,就会出现检测框不完整或者检测框有部分跑到画面外的情况
      // 微信小程序的frame,我没有找到官方提供的可以修改尺寸的API,所以用了这个办法
      //当然还有一种思路,将frame进行裁剪,使frame包含的图片信息正好对应显示画面的信息(像素一一对应)
      this.setData({
        windowWidth:frame.height/frame.width*wx.getSystemInfoSync().windowWidth*0.9
      })
      // 调用自定义函数将frame转png,然后把png数据绑定到传送给服务器的data
      // 再将data传给服务器
      // 这里用了异步编程,只有帧数据顺利转成png才发送给服务器,确保模型接收正确数据
      this.base64ToPNG(frame).then((pngData)=>{
        data["pngData"]=pngData
        this.interWithServer(data)
      })
      // 这里已经处理完一帧的数据,如果不关闭监听相机,那么微信小程序会持续触发相机帧数据回调函数,导致小程序卡顿,资源浪费
      console.log('完成一次帧循环')
      listener.stop()
    }
    // 定义相机帧回调函数
    const listener = context.onCameraFrame(CamFramCall);
    开启监听
    listener.start()
  },

自定义帧数据转base64的函数

参考http://t.csdnimg.cn/2hc7k

这里增加了异步编程的语句,更合理

javascript 复制代码
  base64ToPNG(frame){
    return new Promise(resolve=>{
      const query = wx.createSelectorQuery()
      query.select('#canvas')
        .fields({node:true,size:true})
        .exec((res)=>{
          const canvas=res[0].node
          const ctx=canvas.getContext('2d')
          canvas.width=frame.width
          canvas.height=frame.height
          var imageData=ctx.createImageData(canvas.width,canvas.height)
          var ImgU8Array = new Uint8ClampedArray(frame.data);
          for(var i=0;i<ImgU8Array.length;i+=4){
            imageData.data[0+i]=ImgU8Array[i+0]
            imageData.data[1+i]=ImgU8Array[i+1]
            imageData.data[2+i]=ImgU8Array[i+2]
            imageData.data[3+i]=ImgU8Array[i+3]
          }
          ctx.putImageData(imageData,0,0,0,0,canvas.width,canvas.height)
          resolve(canvas.toDataURL())
        })
    })
  },

自定义传数据到服务器函数

javascript 复制代码
  interWithServer(data){
    const header = {
      'content-type': 'application/x-www-form-urlencoded'
    };
    wx.request({
      // 填上自己的服务器地址(下面这个是我的服务器内网地址,仅供展示)
      url: 'http://172.16.3.186:5000/predict',
      method: 'POST',
      header: header,
      data: data,
      success: (res) => {
        console.log(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])
        // 调用自定义的画框函数
       this.drawRect(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])
      },
      fail: () => {
        wx.showToast({
          title: 'Failed to process frame!',
          icon: 'none',
        });
        // 如果与服务器交互失败,清空画布
        ctx.clearRect(0,0,canvas.width,canvas.height)
      }
    });
  },

自定义的画检测框函数

javascript 复制代码
  drawRect(x1,y1,x2,y2){
    wx.createSelectorQuery().select('#myCanvas')
    .fields({node:true,size:true})
        .exec((res)=>{
          const canvas=res[0].node
          const ctx=canvas.getContext('2d')
          canvas.width=wx.getSystemInfoSync().windowWidth*0.9
          canvas.height=this.data.windowWidth
          ctx.clearRect(0,0,canvas.width,canvas.height)
          ctx.strokeStyle='red'
          ctx.lineWidth=2
          ctx.strokeRect(x1,y1,x2,y2)
        })
  },

index.js

javascript 复制代码
Page({
  data: {
    windowWidth:wx.getSystemInfoSync().windowWidth*1.33,
    boxNum:'',
  },
  onLoad(){
    setInterval(this.oneProcessFrame, 800);
  },
  oneProcessFrame(){
    const context = wx.createCameraContext();
    const data={"pngData":null}
    const CamFramCall = (frame)=>{
      this.setData({
        windowWidth:frame.height/frame.width*wx.getSystemInfoSync().windowWidth*0.9
      })
      this.base64ToPNG(frame).then((pngData)=>{
        data["pngData"]=pngData
        this.interWithServer(data)
      })
      console.log('完成一次帧循环')
      listener.stop()
    }
    const listener = context.onCameraFrame(CamFramCall);
    listener.start()
  },
  base64ToPNG(frame){
    return new Promise(resolve=>{
      const query = wx.createSelectorQuery()
      query.select('#canvas')
        .fields({node:true,size:true})
        .exec((res)=>{
          const canvas=res[0].node
          const ctx=canvas.getContext('2d')
          canvas.width=frame.width
          canvas.height=frame.height
          var imageData=ctx.createImageData(canvas.width,canvas.height)
          var ImgU8Array = new Uint8ClampedArray(frame.data);
          for(var i=0;i<ImgU8Array.length;i+=4){
            imageData.data[0+i]=ImgU8Array[i+0]
            imageData.data[1+i]=ImgU8Array[i+1]
            imageData.data[2+i]=ImgU8Array[i+2]
            imageData.data[3+i]=ImgU8Array[i+3]
          }
          ctx.putImageData(imageData,0,0,0,0,canvas.width,canvas.height)
          resolve(canvas.toDataURL())
        })
    })
  },
  drawRect(x1,y1,x2,y2){
    wx.createSelectorQuery().select('#myCanvas')
    .fields({node:true,size:true})
        .exec((res)=>{
          const canvas=res[0].node
          const ctx=canvas.getContext('2d')
          canvas.width=wx.getSystemInfoSync().windowWidth*0.9
          canvas.height=this.data.windowWidth
          ctx.clearRect(0,0,canvas.width,canvas.height)
          ctx.strokeStyle='red'
          ctx.lineWidth=2
          ctx.strokeRect(x1,y1,x2,y2)
        })
  },
  interWithServer(data){
    const header = {
      'content-type': 'application/x-www-form-urlencoded'
    };
    wx.request({
      url: 'http://172.16.3.186:5000/predict',
      method: 'POST',
      header: header,
      data: data,
      success: (res) => {
        console.log(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])
        this.drawRect(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])
      },
      fail: () => {
        wx.showToast({
          title: 'Failed to process frame!',
          icon: 'none',
        });
        ctx.clearRect(0,0,canvas.width,canvas.height)
      }
    });
  },
  onUnload(){
  }
})

index.wxml

html 复制代码
<view class="c1">
  <camera class="camera" binderror="error" mode="normal" style="width: 90%; height: {{windowWidth}}px;">
    <canvas id="myCanvas" type="2d"></canvas>
  </camera>
  <view class="cla">类别:{{className}}</view>
  <view class="num">数量:{{boxNum}}</view>
  <canvas id="canvas" hidden="true" type="2d"></canvas>
</view> 

index.wxss

javascript 复制代码
.c1{
  width: 100%;
  align-items: center;
  text-align: center;
  display: flex;
  flex-direction: column;
}
.camera{
  width: 100%;
}
#myCanvas{
  width: 100%;
  height: 100%;
}
#canvas{
  width: 100%;
}

后端

接收数据,预处理图像,送入模型,得到初始结果,转化初始结果得到最终结果,返回数据到前端

这里仅作演示,不提供完整项目运行代码和依赖项

python 复制代码
from deploy.infer import Detector
from PIL import Image
import cv2
import numpy as np
import io
from gevent import monkey
import base64
from flask import Flask, jsonify, request
from gevent.pywsgi import WSGIServer
monkey.patch_all()
app = Flask(__name__)

model_dir = "inferer2 fewshot\infer" # 模型路径
save_path = "output"  # 推理结果保存路径

# 推理参数设置
detector = Detector(
    model_dir,
    device='CPU',
    run_mode='paddle',
    trt_min_shape=1,
    trt_max_shape=1280,
    trt_opt_shape=640,
    trt_calib_mode=False,
    cpu_threads=1,
    enable_mkldnn=False,
    enable_mkldnn_bfloat16=False,
    output_dir=save_path,
    threshold=0.1)

// 推理函数,接收预处理后的数据,返回最终结果
def infer_start(img, threshold=0.2):
    results = detector.predict_image([img[:, :, ::-1]], visual=False)
    np_boxes=results['boxes']
    expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
    np_boxes = np_boxes[expect_boxes, :]
    if len(np_boxes)>0:
        for dt in np_boxes:
            clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
            xmin, ymin, xmax, ymax = bbox
            print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
                'right_bottom:[{:.2f},{:.2f}]'.format(
                    int(clsid), score, xmin, ymin, xmax, ymax))

            return jsonify({"class_name":"行人","prob":float(score),"xmin":int(xmin),"ymin":int(ymin),"xmax":int(xmax),"ymax":int(ymax)})
    else:
        return jsonify({"class_name":"未检测到红火蚁","prob":0,"xmin":0,"ymin":0,"xmax":0,"ymax":0})


    
// 交互主函数
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        // 得到png数据,进行预处理
        img_base64 = request.form.get('frameData')
        if img_base64!='':
            img_base64 = img_base64.replace("data:image/png;base64,", "")
            img_base64 = base64.b64decode(img_base64)
            img = Image.open(io.BytesIO(img_base64))
            img=img.convert('RGB')
            img=np.array(img)
            // 调用推理函数并将结果返回
            return infer_start(img)

        else:
            return "数据为空"
        
if __name__ == '__main__':
    server = WSGIServer(('0.0.0.0', 5000), app)
    server.serve_forever()
相关推荐
985小水博一枚呀5 分钟前
【深度学习|可视化】如何以图形化的方式展示神经网络的结构、训练过程、模型的中间状态或模型决策的结果??
人工智能·python·深度学习·神经网络·机器学习·计算机视觉·cnn
guanpinkeji1 小时前
卡牌抽卡机小程序:市场发展下的创新
小程序·团队开发·小程序开发·抽卡机·抽卡机小程序·卡牌·卡牌小程序
LluckyYH2 小时前
代码随想录Day 46|动态规划完结,leetcode题目:647. 回文子串、516.最长回文子序列
数据结构·人工智能·算法·leetcode·动态规划
I592O9297832 小时前
二二复制模式小程序商城开发
小程序
古猫先生2 小时前
YMTC Xtacking 4.0(Gen5)技术深度分析
服务器·人工智能·科技·云计算
一水鉴天2 小时前
智能工厂的软件设计 “程序program”表达式,即 接口模型的代理模式表达式
开发语言·人工智能·中间件·代理模式
Hiweir ·3 小时前
机器翻译之创建Seq2Seq的编码器、解码器
人工智能·pytorch·python·rnn·深度学习·算法·lstm
Element_南笙3 小时前
数据结构_1、基本概念
数据结构·人工智能
FutureUniant3 小时前
GitHub每日最火火火项目(9.21)
人工智能·计算机视觉·ai·github·音视频
菜♕卷3 小时前
深度学习-03 Pytorch
人工智能·pytorch·深度学习