果蔬识别系统性能优化之路(五)

目录

前情提要

果蔬识别系统性能优化之路(四)

剩下问题

  1. 新建store_feature表,用来关联store表和feature表(即storeCode与featureId):通过对数据库进行规范化,由这张新表专门映射storeCode与feature的关系,从而可以使用简单的WHERE条件来充分利用索引
  2. 实现对特征向量ivf的增删改查

解决方案

新建storeFeature表

  1. 新建store表,storeFeature表
typescript 复制代码
import { Entity, PrimaryGeneratedColumn, Column, OneToMany } from 'typeorm';
import { StoreFeature } from '../../feature/entities/store-feature.entity';

/**
 * TypeORM entity describing one store.
 *
 * A store is addressed by its business code (`storeCode`) and owns a
 * collection of `StoreFeature` link rows.
 */
@Entity()
export class Store {
  /** Auto-generated surrogate primary key. */
  @PrimaryGeneratedColumn()
  id: number;

  /** Business identifier of the store; declared unique. */
  @Column({ unique: true })
  storeCode: string;

  /** Display name of the store; the column is nullable. */
  @Column({ nullable: true })
  storeName: string;

  /** Link rows of this store; inverse side of `StoreFeature.store`. */
  @OneToMany(() => StoreFeature, (storeFeature) => storeFeature.store)
  storeFeatures: StoreFeature[];
}
typescript 复制代码
import { Entity, ManyToOne, JoinColumn, PrimaryGeneratedColumn } from 'typeorm';
import { Store } from '../../store/entities/store.entity';
import { Feature } from './feature.entity';

/**
 * TypeORM entity for the link table between `Store` and `Feature`.
 *
 * Each row ties one feature to one store.
 */
@Entity()
export class StoreFeature {
  /** Auto-generated surrogate primary key. */
  @PrimaryGeneratedColumn()
  id: number;

  /**
   * Owning store. The join column is named `storeCode` and references
   * `Store.storeCode` (not the numeric id); the relation is declared with
   * `onDelete: 'CASCADE'`.
   */
  @ManyToOne(() => Store, { onDelete: 'CASCADE' })
  @JoinColumn({ name: 'storeCode', referencedColumnName: 'storeCode' })
  store: Store;

  /**
   * Linked feature. The join column is named `featureId` and references
   * `Feature.id`; the relation is declared with `onDelete: 'CASCADE'`.
   */
  @ManyToOne(() => Feature, { onDelete: 'CASCADE' })
  @JoinColumn({ name: 'featureId', referencedColumnName: 'id' })
  feature: Feature;
}

storeFeature表关联store表和feature表

  2. feature.service大改造
typescript 复制代码
import { Injectable } from '@nestjs/common';
import { CreateFeatureDto } from './dto/create-feature.dto';
import { Feature } from './entities/feature.entity';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository, In } from 'typeorm';
import { RedisService } from '../redis/redis.service';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';
import * as FormData from 'form-data';
import { Img } from '../img/entities/img.entity';
import { Store } from '../store/entities/store.entity';
import { StoreFeature } from './entities/store-feature.entity';

/**
 * Manages feature vectors per store.
 *
 * Responsibilities visible in this class:
 * - persisting features, their images and the store/feature link rows;
 * - asking the Python service to rebuild a store's vector index and caching
 *   the returned id list in Redis under `<storeCode>-featureDatabase`;
 * - turning the positions returned by the Python service into `Feature`
 *   rows for prediction results.
 */
@Injectable()
export class FeatureService {
  /** Base URL of the Python recognition service. */
  private static readonly PYTHON_SERVICE_URL = 'http://localhost:5000';
  /** Suffix appended to a store code to form its Redis cache key. */
  private static readonly REDIS_KEY_SUFFIX = '-featureDatabase';
  /** Fallback for `predict` when `num` is not a parsable integer. */
  private static readonly DEFAULT_TOP_N = 5;
  /** Ids interpolated into SQL must match this (digits, letters, `_`, `-`). */
  private static readonly SAFE_ID_PATTERN = /^[\w-]+$/;

  constructor(
    @InjectRepository(Feature)
    private readonly featureRepository: Repository<Feature>,
    @InjectRepository(Img)
    private readonly imgRepository: Repository<Img>,
    @InjectRepository(Store)
    private readonly storeRepository: Repository<Store>,
    @InjectRepository(StoreFeature)
    private readonly storeFeatureRepository: Repository<StoreFeature>,
    private readonly httpService: HttpService,
    private readonly redisService: RedisService,
  ) {
  }

  /** Builds the Redis key that caches a store's ordered feature-id list. */
  private redisKey(storeCode: string): string {
    return `${storeCode}${FeatureService.REDIS_KEY_SUFFIX}`;
  }

  /** Returns the store with the given code, creating and saving it when absent. */
  private async findOrCreateStore(storeCode: string, storeName?: string): Promise<Store> {
    const existing = await this.storeRepository.findOne({ where: { storeCode } });
    if (existing) {
      return existing;
    }
    const created = this.storeRepository.create({ storeCode, storeName });
    await this.storeRepository.save(created);
    return created;
  }

  /** Starts a `feature` query restricted to features linked to `storeCode`. */
  private createStoreScopedQuery(storeCode: string) {
    return this.featureRepository
      .createQueryBuilder('feature')
      .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
      .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
      .where('store.storeCode = :storeCode', { storeCode });
  }

  /**
   * Stores an uploaded image and a feature row, links the feature to its
   * store (creating the store on first use) and optionally refreshes the
   * store's index/cache.
   *
   * @param file uploaded image; its buffer is persisted as an `Img` row
   * @param createFeatureDto feature fields plus `storeCode` / `storeName`
   * @param needSync whether to call `syncRedis` afterwards (default `true`)
   * @returns the saved feature
   */
  async create(file: Express.Multer.File, createFeatureDto: CreateFeatureDto, needSync: boolean = true): Promise<Feature> {
    const img = this.imgRepository.create({
      img: file.buffer,
    });
    await this.imgRepository.save(img);

    // Plain promises instead of `new Promise(async ...)`: a rejection in
    // either branch now propagates to the caller instead of leaving the
    // outer promise pending forever.
    const [feature, store] = await Promise.all([
      this.featureRepository.save(
        this.featureRepository.create({
          ...createFeatureDto,
          imgId: img.id,
        }),
      ),
      this.findOrCreateStore(createFeatureDto.storeCode, createFeatureDto.storeName),
    ]);

    const storeFeature = this.storeFeatureRepository.create({
      feature,
      store,
    });
    await this.storeFeatureRepository.save(storeFeature);

    if (needSync) {
      await this.syncRedis(createFeatureDto.storeCode);
    }
    return feature;
  }

  /**
   * Asks the Python service to rebuild the index of one store and caches the
   * id list it returns in Redis. Logs the elapsed time.
   *
   * @param storeCode store whose index should be rebuilt
   */
  async syncRedis(storeCode: string) {
    const url = `${FeatureService.PYTHON_SERVICE_URL}/sync`;
    const startedAt = Date.now();
    const response = await firstValueFrom(this.httpService.post(url, { storeCode }));
    // A missing/invalid `ids` payload is cached as an empty list rather than
    // passing `undefined` to Redis.
    const ids = Array.isArray(response.data?.ids) ? response.data.ids : [];
    await this.redisService.set(this.redisKey(storeCode), JSON.stringify(ids));
    const finishedAt = Date.now();
    console.log(`门店:${storeCode},同步redis耗时:${finishedAt - startedAt}ms`);
  }

  /**
   * Lists all features linked to a store.
   *
   * @param storeCode store to query
   * @param selectP optional selection list passed to the query builder
   */
  async findAll(storeCode: string, selectP?: string[]) {
    return await this.createStoreScopedQuery(storeCode)
      .select(selectP)
      .getMany();
  }

  /**
   * Lists all features linked to a store together with their images.
   *
   * @param storeCode store to query
   */
  async findAllWithImage(storeCode: string): Promise<Feature[]> {
    return await this.createStoreScopedQuery(storeCode)
      .leftJoinAndSelect('feature.img', 'img')
      .getMany();
  }

  /**
   * Deletes a store, its link rows and every feature that is no longer
   * linked to any store, then clears and refreshes the store's cache.
   * Does nothing when the store does not exist.
   *
   * @param storeCode store to delete
   */
  async removeAll(storeCode: string): Promise<void> {
    const store = await this.storeRepository.findOne({ where: { storeCode }, relations: ['storeFeatures'] });
    if (!store) {
      return;
    }

    const linkIds = store.storeFeatures.map((link: StoreFeature) => link.id);
    if (linkIds.length > 0) {
      await this.storeFeatureRepository
        .query('DELETE FROM store_feature WHERE id IN (?)', [linkIds]);
    }
    await this.storeRepository.remove(store);

    // Features without any remaining link row are not used by any store.
    const unreferencedFeatures = await this.featureRepository
      .createQueryBuilder('feature')
      .leftJoinAndSelect('feature.img', 'img')
      .leftJoin('feature.storeFeatures', 'storeFeature')
      .where('storeFeature.id IS NULL')
      .getMany();
    for (const feature of unreferencedFeatures) {
      await this.remove(feature);
    }

    await this.redisService.del(this.redisKey(storeCode));
    await this.syncRedis(storeCode);
  }

  /**
   * Sends an image to the Python service and resolves the returned index
   * positions into feature rows, de-duplicated by label.
   *
   * @param file image to classify
   * @param num how many distinct labels to return, as a decimal string;
   *   falls back to 5 when it cannot be parsed
   * @param storeCode store whose index is searched
   * @param justPredict `'true'` to return only the extracted vector
   * @param needList when `true`, also return every matched feature
   * @returns `predictTime`, `top<num>`, `features`, `totalTime` and, when
   *   requested, `featureList`
   * @throws Error prefixed with `Prediction failed:` on any failure
   */
  async predict(
    file: Express.Multer.File,
    num: string = '5',
    storeCode: string,
    justPredict: string = 'false',
    needList: boolean = false,
  ) {
    const startTime = Date.now();
    const parsedNum = Number.parseInt(num, 10);
    const numInt = Number.isNaN(parsedNum) ? FeatureService.DEFAULT_TOP_N : parsedNum;
    const isJustPredict = justPredict === 'true';

    try {
      const formData = new FormData();
      formData.append('file', file.buffer, file.originalname);
      formData.append('storeCode', storeCode);
      formData.append('justPredict', justPredict);

      const response = await firstValueFrom(
        this.httpService.post(`${FeatureService.PYTHON_SERVICE_URL}/predict`, formData),
      );
      const { features, index, predictTime } = response.data;

      if (isJustPredict) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      const featureDatabaseStr = await this.redisService.get(this.redisKey(storeCode));
      if (!featureDatabaseStr) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      // Positions that do not map to a cached id (e.g. -1 or positions past
      // the end of the list) are dropped here instead of reaching the SQL.
      const ids = this.resolveFeatureIds(index, JSON.parse(featureDatabaseStr));
      if (!ids.length) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      // FIELD(...) keeps the rows in the ranking order returned by the index.
      // Every id was validated by resolveFeatureIds before being interpolated.
      const featureList = await this.featureRepository.createQueryBuilder('feature')
        .where('feature.id IN (:...ids)', { ids })
        .orderBy(`FIELD(feature.id, ${ids.map((id) => `'${id}'`).join(', ')})`, 'ASC')
        .getMany();

      const uniqueList = this.filterUniqueFeatures(featureList, numInt);

      const result = this.buildResponse(uniqueList, features, predictTime, startTime, numInt);
      return needList
        ? { ...result, featureList: featureList.map(({ features: _vector, ...rest }) => rest) }
        : result;
    } catch (error: unknown) {
      const message = error instanceof Error ? error.message : String(error);
      throw new Error(`Prediction failed: ${message}`);
    }
  }

  /**
   * Maps index positions to cached feature ids, keeping the original order.
   * Skips non-integer/negative/out-of-range positions, ids that are not safe
   * to embed in SQL, and duplicates.
   */
  private resolveFeatureIds(index: unknown, featureDatabase: unknown): Array<string | number> {
    if (!Array.isArray(index) || !Array.isArray(featureDatabase)) {
      return [];
    }
    const seen = new Set<string>();
    const ids: Array<string | number> = [];
    for (const position of index) {
      if (!Number.isInteger(position) || position < 0) {
        continue;
      }
      const id = featureDatabase[position];
      if (id === undefined || id === null) {
        continue;
      }
      const key = String(id);
      if (!FeatureService.SAFE_ID_PATTERN.test(key) || seen.has(key)) {
        continue;
      }
      seen.add(key);
      ids.push(id);
    }
    return ids;
  }

  /**
   * Keeps the first feature of every label, in input order, stopping once
   * `limit` entries have been collected.
   */
  private filterUniqueFeatures<T extends { label?: unknown }>(featureList: T[], limit: number): T[] {
    const seenLabels = new Set<unknown>();
    const uniqueList: T[] = [];
    for (const feature of featureList) {
      if (!seenLabels.has(feature.label)) {
        seenLabels.add(feature.label);
        uniqueList.push(feature);
      }
      if (uniqueList.length === limit) break;
    }
    return uniqueList;
  }

  /**
   * Assembles the prediction payload. The raw `features` vector of every
   * list entry is stripped; the query vector is returned once at top level.
   */
  private buildResponse<T extends { features?: unknown }>(
    list: T[],
    features: unknown,
    predictTime: string,
    startTime: number,
    num: number,
  ) {
    const totalTime = `${Date.now() - startTime}ms`;
    return {
      predictTime,
      [`top${num}`]: list.map(({ features: _vector, ...rest }) => rest),
      features,
      totalTime,
    };
  }

  /**
   * Cosine similarity of two equally long vectors.
   *
   * @returns the similarity, or `0` when either vector has zero magnitude
   *   (previously this produced `NaN`, which also broke sorting)
   * @throws Error when the vectors differ in length
   */
  cosineSimilarity(vecA: number[], vecB: number[]): number {
    if (vecA.length !== vecB.length) {
      throw new Error('Vectors must be of the same length');
    }
    let dotProduct = 0;
    let squaredA = 0;
    let squaredB = 0;
    for (let i = 0; i < vecA.length; i++) {
      dotProduct += vecA[i] * vecB[i];
      squaredA += vecA[i] * vecA[i];
      squaredB += vecB[i] * vecB[i];
    }
    const denominator = Math.sqrt(squaredA) * Math.sqrt(squaredB);
    return denominator === 0 ? 0 : dotProduct / denominator;
  }

  /**
   * Ranks cached entries by cosine similarity to `inputFeatures` and returns
   * the best `num` distinct labels (similarity rounded to two decimals).
   *
   * NOTE(review): this reads cache entries as `{ features, label }` objects,
   * while `syncRedis` in this class stores a plain id list under the same
   * key — confirm whether this method is still in use.
   *
   * @param inputFeatures query vector
   * @param num maximum number of distinct labels
   * @param storeCode store whose cache is read
   */
  async findTopNSimilar(inputFeatures: number[], num: number, storeCode: string): Promise<{
    label: string;
    similarity: number
  }[]> {
    const featureDatabaseStr = await this.redisService.get(this.redisKey(storeCode));
    if (!featureDatabaseStr) {
      return [];
    }
    const featureDatabase: Array<{ features?: number[]; label: string }> = JSON.parse(featureDatabaseStr);
    const similarities = featureDatabase.map(({ features, label }) => ({
      label,
      similarity: features ? this.cosineSimilarity(inputFeatures, features) : 0,
    }));

    similarities.sort((a, b) => b.similarity - a.similarity);

    const uniqueLabels = new Set<string>();
    const topNUnique: { label: string; similarity: number; }[] = [];
    for (const item of similarities) {
      if (uniqueLabels.has(item.label)) {
        continue;
      }
      uniqueLabels.add(item.label);
      item.similarity = Math.round(item.similarity * 100) / 100;
      topNUnique.push(item);
      if (topNUnique.length === num) break;
    }
    return topNUnique;
  }

  /**
   * Lists a store's features with the given label, including their images.
   *
   * @param label label to match
   * @param storeCode store to query
   */
  async getByName(label: string, storeCode: string): Promise<Feature[]> {
    return await this.createStoreScopedQuery(storeCode)
      .leftJoinAndSelect('feature.img', 'img')
      .andWhere('feature.label = :label', { label })
      .getMany();
  }

  /**
   * Counts a store's features with the given label.
   *
   * @param label label to match
   * @param storeCode store to query
   */
  async getCountByLabel(label: string, storeCode: string): Promise<number> {
    // The image join was dropped: it does not affect the count and only
    // made the database read image rows.
    return await this.createStoreScopedQuery(storeCode)
      .andWhere('feature.label = :label', { label })
      .getCount();
  }

  /**
   * Extracts and stores a feature for every uploaded file, then refreshes
   * the store's index once. A failing file is logged and skipped.
   *
   * @param files images to learn
   * @param createFeatureDto fields shared by all created features
   * @returns the created features without their `features` vector
   */
  async batchStudy(files: Express.Multer.File[], createFeatureDto: CreateFeatureDto) {
    const list = [];
    for (const file of files) {
      try {
        const { features: extracted } = await this.predict(file, '5', createFeatureDto.storeCode, 'true');
        const feature = await this.create(file, {
          ...createFeatureDto,
          features: extracted,
        }, false);
        // Return a copy without the (large) vector.
        const { features, ...featureWithoutFeatures } = feature;
        list.push(featureWithoutFeatures);
      } catch (e) {
        console.error(e);
      }
    }
    await this.syncRedis(createFeatureDto.storeCode);
    return list;
  }

  /**
   * Deletes a feature row and, when its image relation is loaded, the image.
   *
   * @param feature feature to delete
   */
  async remove(feature: Feature) {
    const { img } = feature;
    await this.featureRepository.remove(feature);
    if (img) {
      await this.imgRepository.remove(img);
    }
  }

  /**
   * Removes the given features from one store and refreshes its index.
   *
   * Only the link rows of `storeCode` are deleted; the feature itself (and
   * its image) is deleted only when no other store still references it,
   * matching what `removeAll` does. Links are removed before the feature.
   *
   * @param ids comma-separated feature ids; blank and non-integer entries
   *   are ignored
   * @param storeCode store the features are removed from
   */
  async batchRemove(ids: string, storeCode: string) {
    const list = ids
      .split(',')
      .map((raw) => raw.trim())
      .filter((raw) => raw !== '')
      .map((raw) => Number(raw))
      .filter((id) => Number.isInteger(id));

    if (list.length > 0) {
      const features = await this.featureRepository.find({
        where: { id: In(list) },
        relations: ['img', 'storeFeatures', 'storeFeatures.store'],
      });
      for (const feature of features) {
        const links: StoreFeature[] = feature.storeFeatures ?? [];
        const ownLinks = links.filter((link: StoreFeature) => link.store?.storeCode === storeCode);
        if (ownLinks.length > 0) {
          await this.storeFeatureRepository.remove(ownLinks);
        }
        if (ownLinks.length === links.length) {
          await this.remove(feature);
        }
      }
    }
    await this.syncRedis(storeCode);
  }

  /**
   * Links existing features to a target store without copying them.
   *
   * Features already linked to the target are skipped. When
   * `sourceStoreCode` is given only that store's features are imported,
   * otherwise every feature linked to any store is. The target store is
   * created when missing and its index is refreshed afterwards.
   *
   * @param storeCode target store
   * @param sourceStoreCode optional store to import from
   * @param storeName name used if the target store has to be created
   * @returns a summary message with the number of imported features
   */
  async importData(storeCode: string, sourceStoreCode?: string, storeName?: string) {
    // Step 1: feature ids the target store already has.
    const linkedRows = await this.storeFeatureRepository
      .createQueryBuilder('storeFeature')
      .select('storeFeature.featureId')
      .where('storeFeature.storeCode = :storeCode', { storeCode })
      .getRawMany();
    const featureIdsToExclude = linkedRows.map((row: { featureId: unknown }) => row.featureId);

    // Step 2: distinct feature ids that are not linked to the target yet.
    const candidateQuery = this.storeFeatureRepository
      .createQueryBuilder('storeFeature')
      .select('DISTINCT storeFeature.featureId');
    if (featureIdsToExclude.length > 0) {
      candidateQuery.where('storeFeature.featureId NOT IN (:...featureIdsToExclude)', { featureIdsToExclude });
    }
    const candidateRows = await candidateQuery.getRawMany();
    const featureIds = candidateRows.map((row: { featureId: unknown }) => row.featureId);

    // Step 3: load the features; skipped when there are no candidates so
    // `whereInIds` is never called with an empty list.
    let featuresToImport: Feature[] = [];
    if (featureIds.length > 0) {
      const featureQuery = this.featureRepository
        .createQueryBuilder('feature')
        .leftJoinAndSelect('feature.img', 'img');
      if (sourceStoreCode) {
        featureQuery.innerJoin('feature.storeFeatures', 'storeFeatures');
      }
      featureQuery.whereInIds(featureIds);
      if (sourceStoreCode) {
        featureQuery.andWhere('storeFeatures.storeCode = :storeCode', { storeCode: sourceStoreCode });
      }
      featuresToImport = await featureQuery.getMany();
    }

    const targetStore = await this.findOrCreateStore(storeCode, storeName);

    if (featuresToImport.length > 0) {
      const newStoreFeatures = featuresToImport.map((feature: Feature) => ({
        store: targetStore,
        feature, // reuse the existing feature row
      }));
      const storeFeatureInstances = this.storeFeatureRepository.create(newStoreFeatures);
      await this.storeFeatureRepository.save(storeFeatureInstances);
    }

    await this.syncRedis(storeCode);
    return `同步完成,共导入${featuresToImport.length}条数据`;
  }

  /**
   * Rebuilds the index/cache of every known store in parallel and logs when
   * all of them have finished. Rejects if any single sync fails.
   */
  async init() {
    const storeRows = await this.storeRepository
      .createQueryBuilder('store')
      .select('store.storeCode')
      .distinct(true)
      .getRawMany();
    await Promise.all(
      storeRows.map((row: { store_storeCode: string }) => this.syncRedis(row.store_storeCode)),
    );
    console.log('初始化完成');
  }
}
  3. 结果:性能并没有提升多少,但好在表关系更清晰,为之后的拓展打下了基础

实现ivf的动态增删改查

  1. 结论:ivf无法在不训练只增加的情况下进行新增向量的识别,所以每次新增向量必须重新进行训练和添加
  2. python端ivf改造
    detect.py(识别和同步方法)
python 复制代码
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
import numpy as np
import time
import gc
from ivf import IVFPQ
from feature import get_feature_by_store_code
import orjson
from concurrent.futures import ThreadPoolExecutor

# Load MobileNetV2 with ImageNet weights and without its top classification
# layer; pooling='avg' makes the model output one pooled feature vector per image.
model = MobileNetV2(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')


class MainDetect:
    """Extracts image feature vectors and searches per-store vector indexes.

    ``ivfObj`` maps ``"<store_code>-featureDatabase"`` to the ``IVFPQ`` index
    built for that store by :meth:`sync`.
    """

    def __init__(self):
        super().__init__()
        self.image_id = None
        self.image_features = None
        # Custom model loaded from disk. classify_image currently runs the
        # module-level MobileNetV2 `model` instead; this one is kept loaded.
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        self.ivfObj = {}

    def classify_image(self, image_data, store_code, just_predict):
        """Extract the feature vector of an image and optionally search the store index.

        Args:
            image_data: encoded image bytes, decoded with ``tf.image.decode_image``.
            store_code: store whose index is searched.
            just_predict: the string ``"false"`` enables the index search; any
                other value skips it.

        Returns:
            dict with ``outputs`` (flattened feature vector), ``time``
            (elapsed milliseconds, formatted) and ``index`` (positions returned
            by the index search, empty when no search was performed).
        """
        # Decode, resize to the model's input size and add the batch dimension.
        # NOTE(review): the pixels are fed to MobileNetV2 without
        # `preprocess_input` scaling — stored vectors were produced the same
        # way, so confirm before changing it.
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)

        start_time = time.time()
        outputs = model.predict(img)

        positions = []
        if just_predict == "false":
            store_index = self.ivfObj.get(store_code + '-featureDatabase')
            if store_index is not None:
                hits = store_index.search(outputs).flatten().tolist()
                # FAISS uses -1 for "no neighbour found"; such placeholders
                # do not refer to any stored vector, so they are dropped.
                positions = [hit for hit in hits if hit >= 0]

        elapsed_time = time.time() - start_time
        output_data = outputs.flatten().tolist()

        # Release the tensors eagerly and force a collection to cap memory use.
        del img, outputs
        gc.collect()

        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": positions}

    def sync(self, store_code):
        """Rebuild the vector index of one store from the database.

        The previous index stays in place until the new one has been built, so
        a failure while loading or training leaves the store searchable. When
        the store has no features its index is removed.

        Args:
            store_code: store to rebuild.

        Returns:
            list of feature ids, in the same order as the vectors were added
            to the index (empty when the store has no features).
        """
        key = store_code + '-featureDatabase'
        data = get_feature_by_store_code(store_code)

        if len(data) == 0:
            self.ivfObj.pop(key, None)
            return []

        # Each row stores its vector as a JSON string. Parsing is CPU-bound,
        # so a plain loop is used rather than a thread pool.
        features = np.array([orjson.loads(item['features']) for item in data], dtype=np.float32)
        self.ivfObj[key] = IVFPQ(features)
        return [item['id'] for item in data]

ivf.py(ivf构造)

python 复制代码
import faiss
import numpy as np

# Number of OpenMP threads FAISS is told to use.
num_threads = 8
faiss.omp_set_num_threads(num_threads)


class IVFPQ:
    """Thin wrapper around a FAISS ``IndexIVFFlat`` with an L2 quantizer.

    Despite the name, the index is IVF-Flat; ``m`` and ``n_bits`` are accepted
    for the (currently disabled) IVF-PQ variant and are not used.
    """

    # FAISS' k-means expects at least this many training points per centroid.
    MIN_POINTS_PER_CENTROID = 39
    # Vectors are added in chunks of this size.
    ADD_BATCH_SIZE = 1000

    def __init__(self, features, nlist=100, m=16, n_bits=8):
        """Train the index and add ``features`` to it.

        Args:
            features: 2-D array ``(n, d)`` of vectors to index.
            nlist: number of IVF clusters.
            m: unused (reserved for IVF-PQ).
            n_bits: unused (reserved for IVF-PQ).
        """
        features = self._as_matrix(features)
        d = features.shape[1]
        quantizer = faiss.IndexFlatL2(d)
        self.index = faiss.IndexIVFFlat(quantizer, d, nlist)
        # self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)

        n_vectors = features.shape[0]
        # 39 * 100 = 3900 for the default nlist, the value used previously.
        min_train_points = self.MIN_POINTS_PER_CENTROID * nlist
        if n_vectors >= min_train_points:
            training_set = features
        else:
            # Too few vectors to train the clustering: pad the *training set*
            # with deterministic random vectors. A local generator is used so
            # NumPy's global random state is left untouched.
            missing = min_train_points - n_vectors
            rng = np.random.default_rng(missing)
            filler = rng.random((missing, d), dtype=np.float32)
            training_set = np.vstack((features, filler))
        self.index.train(training_set)

        # Only real vectors are added, so every position returned by search()
        # corresponds to a row of `features` (padding is never a search hit).
        for start in range(0, n_vectors, self.ADD_BATCH_SIZE):
            self.index.add(features[start:start + self.ADD_BATCH_SIZE])

    @staticmethod
    def _as_matrix(vectors):
        """Return ``vectors`` as a C-contiguous float32 2-D array (1-D becomes one row)."""
        matrix = np.ascontiguousarray(vectors, dtype=np.float32)
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        return matrix

    def search(self, xq, k=100):
        """Return the positions of the ``k`` nearest stored vectors per query.

        Args:
            xq: query vector(s), shape ``(d,)`` or ``(nq, d)``.
            k: number of neighbours per query.

        Returns:
            integer array of shape ``(nq, k)``; FAISS fills unused slots with -1.
        """
        _, positions = self.index.search(self._as_matrix(xq), k)
        return positions

    def add(self, xb):
        """Add vector(s) of shape ``(d,)`` or ``(n, d)`` to the trained index."""
        self.index.add(self._as_matrix(xb))

    def train(self, xb):
        """Train the index on vectors of shape ``(n, d)``."""
        self.index.train(self._as_matrix(xb))

    def sync(self, features):
        """Add every row of ``features`` to the index; empty input is a no-op.

        The rows are added in one call: FAISS requires a 2-D array, so adding
        the rows one by one as 1-D vectors fails.
        """
        if np.size(features) == 0:
            return
        self.add(features)

结语

这个项目优化到这差不多告一段落了,后续还有啥优化点会继续跟进,稍后会把整个架构图和功能点都梳理一遍

相关推荐
轻口味23 分钟前
【每日学点鸿蒙知识】AVCodec、SmartPerf工具、web组件加载、监听键盘的显示隐藏、Asset Store Kit
前端·华为·harmonyos
alikami26 分钟前
【若依】用 post 请求传 json 格式的数据下载文件
前端·javascript·json
Kai HVZ35 分钟前
python爬虫----爬取视频实战
爬虫·python·音视频
古希腊掌管学习的神37 分钟前
[LeetCode-Python版]相向双指针——611. 有效三角形的个数
开发语言·python·leetcode
m0_7482448340 分钟前
StarRocks 排查单副本表
大数据·数据库·python
B站计算机毕业设计超人1 小时前
计算机毕业设计PySpark+Hadoop中国城市交通分析与预测 Python交通预测 Python交通可视化 客流量预测 交通大数据 机器学习 深度学习
大数据·人工智能·爬虫·python·机器学习·课程设计·数据可视化
路人甲ing..1 小时前
jupyter切换内核方法配置问题总结
chrome·python·jupyter
学术头条1 小时前
清华、智谱团队:探索 RLHF 的 scaling laws
人工智能·深度学习·算法·机器学习·语言模型·计算语言学
18号房客1 小时前
一个简单的机器学习实战例程,使用Scikit-Learn库来完成一个常见的分类任务——**鸢尾花数据集(Iris Dataset)**的分类
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·sklearn
feifeikon1 小时前
机器学习DAY3 : 线性回归与最小二乘法与sklearn实现 (线性回归完)
人工智能·机器学习·线性回归