果蔬识别系统性能优化之路

目录

一级目录

二级目录

三级目录

前情提要

超详细前端AI蔬菜水果生鲜识别应用优化之路

当前问题

  1. indexddb在webview中确实性能有限,存储量上来后每次读取数据会有明显卡顿
  2. 目前的余弦相邻算法是基于所有特征向量数据进行的计算,一旦数据量大了后计算量也是一个十分消耗性能的点
  3. 由于机器性能问题,本地化加载模型在每次进入页面后需要消耗很长一段时间进行模型的加载,体验相当不好

优化方案

1. 内存化

方式:为了解决indexddb读取速度的问题,最直接的方式就是把数据放内存对象中,提前将indexddb的数据取出来,然后在数据变化时进行同步

结果:确实省去了读取这一步会快很多,但是仍有隐患,webview分配的内存是否这种占用内存的方式

2. 原生化

本地化模型加载可以在原生层面进行模型的加载,识别,学习,相当于把整套方案在原生端实现一次,通过bridge调用原生相关方法完成识别,但目前暂无学习原生的意向,所以搁置

3. 接口化

方式:将数据存储在mysql,识别单独起一个python服务,通过接口调用,利用服务器的性能优势使识别速度提升,保证网络消耗在合理范围即可

行动

最终选择了接口化,在保证网络的情况下,识别速度在300ms内可被接受

  1. nestjs搭建服务端,分feature和img两个模块
  2. mysql搭建数据库,建立feature和img两个表,通过imgId和feature表关联,同时feature表包含storeCode字段,用来管理门店
  3. 搭建redis,代替内存化方案,实现快速读取
  4. 搭建python服务端,与nestjs服务端进行通信,进行识别结果的传输
  5. 通过IVF方式提升计算速度,保证大量特征值情况下仍然可以快速算出相似结果

实现

  1. python端,flask搭建http服务,当然后续可以改成其他和服务端通信方式,提升速度
  2. 实现识别接口恶化同步接口,用于识别图片特征值和同步数据库存储的特征向量
    app.py
python 复制代码
from flask import Flask, request, jsonify
from flask_cors import CORS
from detect import MainDetect
from tensorflow.keras import layers, models

app = Flask(__name__)
CORS(app)  # 允许所有路由上的跨域请求
detector = MainDetect()


@app.route('/')
def home():
    return "Welcome to the Vegetable Recognize App!"


@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return 'No file part', 400

    file = request.files['file']
    store_code = request.form["storeCode"]
    if file.filename == '':
        return 'No selected file', 400

    try:
        image_data = file.read()
        data = detector.classify_image(image_data, store_code)
        outputs = data["outputs"]
        time = data["time"]
        index = data["index"]
        return jsonify({"predictTime": time, "features": outputs, "index": index})

    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/sync', methods=['POST'])
def sync():
    data = request.get_json()
    arr = data.get('data')
    store_code = data.get('storeCode')
    detector.sync(store_code, arr)
    return jsonify({"message": 'ok'})

detect.py

python 复制代码
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
from PIL import Image
import numpy as np
import cv2
import time
import io
import gc
from ivf import IVFPQ

model = MobileNet(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')


class MainDetect:
    # 初始化
    def __init__(self):
        super().__init__()
        # 模型初始化
        self.image_id = None
        self.image_features = None
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        self.ivfObj = {}

    def classify_image(self, image_data, store_code):
        # Load and preprocess image
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)  # Add batch dimension

        # Run model prediction
        start_time = time.time()
        outputs = model.predict(img)
        # outputs = self.model.predict(outputs)
        # prediction = tf.divide(outputs, tf.norm(outputs))
        i = []
        if store_code + '-featureDatabase' in self.ivfObj:
            i = self.ivfObj[store_code + '-featureDatabase'].search(outputs)
            i = i.flatten().tolist()

        end_time = time.time()

        # Calculate elapsed time
        elapsed_time = end_time - start_time

        # Flatten the outputs and return them
        # output_data = prediction.numpy().flatten().tolist()
        output_data = outputs.flatten().tolist()

        # Force garbage collection to free up memory
        del img, outputs, end_time, start_time  # Ensure variables are deleted
        gc.collect()

        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}

    def sync(self, store_code, data):
        if store_code + '-featureDatabase' in self.ivfObj:
            del self.ivfObj[store_code + '-featureDatabase']
        self.ivfObj[store_code + '-featureDatabase'] = IVFPQ()
        for item in data:
            feature = item['features']
            self.ivfObj[store_code + '-featureDatabase'].add(np.array([feature], dtype=np.float32))
        return 'ok'

ivf.py

python 复制代码
import faiss
import numpy as np


class IVFPQ:
    def __init__(self, d=1024, nlist=1, m=16, n_bits=8):
        # 创建量化器
        quantizer = faiss.IndexFlatL2(d)  # 使用L2距离进行量化
        self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)

        np.random.seed(1234)
        xb = np.random.random((256, d)).astype('float32')  # 模拟数据库中的特征向量
        # 训练索引
        self.index.train(xb)
        self.index.add(xb)  # 将特征向量添加到索引中

    def search(self, xq, k=50):
        d, i = self.index.search(xq, k)
        return i

    def add(self, xb):
        self.index.add(xb)

    def sync(self, features):
        for i in range(len(features)):
            self.add(features[i])
  1. nestjs进行接口的转发和数据处理
    识别的service,调用python服务,并通过返回的索引找到正确的目标label
typescript 复制代码
  /**
   * 预测
   * @param file
   * @param num
   * @param storeCode
   * @param justPredict
   */
  async predict(file: Express.Multer.File, num: string = '5', storeCode: string, justPredict: Boolean = false) {
    const url = 'http://localhost:5000/predict'; // Python 服务的 URL
    const startTime = Date.now();
    try {
      // 返回 Python 服务的响应数据
      const formData = new FormData();
      formData.append('file', file.buffer, file.originalname);
      formData.append('storeCode', storeCode);
      const response = await firstValueFrom(this.httpService.post(url, formData));
      const endTime = Date.now();
      const features = response.data.features;
      const index = response.data.index;
      if (justPredict) {
        return features;
      }
      // const top5 = await this.findTopNSimilar(features, parseInt(num), storeCode);

      const featureDatabaseStr = await this.redisService.get(`${storeCode}-featureDatabase`);
      if (!featureDatabaseStr) {
        return response.data = {
          ...response.data,
          [`top${num}`]: [],
          features,
          totalTime: `${endTime - startTime}ms`,
        };
      }
      const featureDatabase = JSON.parse(featureDatabaseStr);
      const list = [];
      index.forEach((i: number) => {
        const ide = i - 256;
        if (ide >= 0) {
          const item = featureDatabase[ide];
          if (!list.some(l => l.label === item.label)) {
            list.push({ label: item.label });
          }
        }
      });
      return response.data = {
        ...response.data,
        [`top${num}`]: list,
        features,
        totalTime: `${endTime - startTime}ms`,
      };
    } catch (error) {
      // 错误处理
      console.error('Error calling Python service:', error);
      throw error;
    }
  }

同步python端特征向量,在数据库的增删改查时进行调用

typescript 复制代码
  /**
   * 同步redis
   * @param storeCode
   */
  async syncRedis(storeCode: string) {
    const featureDatabase = await this.findAll(storeCode);
    await this.redisService.set(`${storeCode}-featureDatabase`, JSON.stringify(featureDatabase));
    const url = 'http://localhost:5000/sync'; // Python 服务的 URL
    await firstValueFrom(this.httpService.post(url, { data: featureDatabase, storeCode }));
  }

结语

没规划好,其实全部用python实现应该更自然,后续有时间再更新,先上线跑跑

相关推荐
一点媛艺2 小时前
Kotlin函数由易到难
开发语言·python·kotlin
魔道不误砍柴功3 小时前
Java 中如何巧妙应用 Function 让方法复用性更强
java·开发语言·python
_.Switch4 小时前
高级Python自动化运维:容器安全与网络策略的深度解析
运维·网络·python·安全·自动化·devops
测开小菜鸟5 小时前
使用python向钉钉群聊发送消息
java·python·钉钉
萧鼎6 小时前
Python并发编程库:Asyncio的异步编程实战
开发语言·数据库·python·异步
学地理的小胖砸6 小时前
【一些关于Python的信息和帮助】
开发语言·python
疯一样的码农6 小时前
Python 继承、多态、封装、抽象
开发语言·python
Python大数据分析@7 小时前
python操作CSV和excel,如何来做?
开发语言·python·excel
黑叶白树7 小时前
简单的签到程序 python笔记
笔记·python
Shy9604187 小时前
Bert完形填空
python·深度学习·bert