目录
前情提要
剩下问题
- 新建 store_feature 关联表,对数据库进行规范化:用这张表来映射 storeCode 与 featureId 的关系,这样按门店查询时只需要简单的 WHERE 条件,就能充分利用索引
- 实现对特征向量ivf的增删改查
解决方案
新建storeFeature表
- 新建store表,storeFeature表
typescript
import { Entity, PrimaryGeneratedColumn, Column, OneToMany } from 'typeorm';
import { StoreFeature } from '../../feature/entities/store-feature.entity';
/**
 * A store, identified by its unique `storeCode`.
 */
@Entity()
export class Store {
/** Auto-generated primary key. */
@PrimaryGeneratedColumn()
id: number;
/** Business code of the store; the column is declared unique. */
@Column({ unique: true })
storeCode: string;
/** Display name of the store; the column is nullable. */
@Column({ nullable: true })
storeName: string;
/** Link rows that attach features to this store (inverse side of `StoreFeature.store`). */
@OneToMany(() => StoreFeature, (storeFeature) => storeFeature.store)
storeFeatures: StoreFeature[];
}
typescript
import { Entity, ManyToOne, JoinColumn, PrimaryGeneratedColumn } from 'typeorm';
import { Store } from '../../store/entities/store.entity';
import { Feature } from './feature.entity';
/**
 * Link row that attaches one feature to one store.
 */
@Entity()
export class StoreFeature {
/** Auto-generated primary key. */
@PrimaryGeneratedColumn()
id: number;
/**
 * Owning store. The join column is named `storeCode` and references
 * `Store.storeCode` (not the numeric id); the row is deleted with the store.
 */
@ManyToOne(() => Store, { onDelete: 'CASCADE' })
@JoinColumn({ name: 'storeCode', referencedColumnName: 'storeCode' })
store: Store;
/**
 * Linked feature. The join column is named `featureId` and references
 * `Feature.id`; the row is deleted with the feature.
 */
@ManyToOne(() => Feature, { onDelete: 'CASCADE' })
@JoinColumn({ name: 'featureId', referencedColumnName: 'id' })
feature: Feature;
}
storeFeature 表是关联表:通过 storeCode 外键关联 store 表,通过 featureId 外键关联 feature 表,两个外键都设置了级联删除(onDelete: 'CASCADE')
- feature.service大改造
typescript
import { Injectable } from '@nestjs/common';
import { CreateFeatureDto } from './dto/create-feature.dto';
import { Feature } from './entities/feature.entity';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository, In } from 'typeorm';
import { RedisService } from '../redis/redis.service';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';
import * as FormData from 'form-data';
import { Img } from '../img/entities/img.entity';
import { Store } from '../store/entities/store.entity';
import { StoreFeature } from './entities/store-feature.entity';
@Injectable()
export class FeatureService {
/**
 * Receives the TypeORM repositories for Feature, Img, Store and StoreFeature,
 * together with the HTTP client and the Redis service used by this class.
 */
constructor(
@InjectRepository(Feature)
private readonly featureRepository: Repository<Feature>,
@InjectRepository(Img)
private readonly imgRepository: Repository<Img>,
@InjectRepository(Store)
private readonly storeRepository: Repository<Store>,
@InjectRepository(StoreFeature)
private readonly storeFeatureRepository: Repository<StoreFeature>,
private readonly httpService: HttpService,
private readonly redisService: RedisService,
) {
}
/**
 * Stores an uploaded image and its feature record, and links the feature to
 * the store named in the DTO.
 *
 * The image is saved first. Saving the feature and finding (or creating) the
 * store then run concurrently, and finally the StoreFeature link row is saved.
 *
 * @param file uploaded file; its `buffer` is stored as the image content
 * @param createFeatureDto feature fields plus the `storeCode` / `storeName` of the owning store
 * @param needSync whether to call `syncRedis` for the store afterwards (defaults to `true`)
 * @returns the saved feature
 */
async create(file: Express.Multer.File, createFeatureDto: CreateFeatureDto, needSync: boolean = true): Promise<Feature> {
  const img = this.imgRepository.create({
    img: file.buffer,
  });
  await this.imgRepository.save(img);

  // Plain async closures rather than `new Promise(async (resolve) => ...)`:
  // with the executor form a failing save/find never settled the promise, so
  // Promise.all stayed pending forever. Here a failure rejects and propagates.
  const saveFeature = async (): Promise<Feature> => {
    const feature: Feature = this.featureRepository.create({
      ...createFeatureDto,
      imgId: img.id,
    });
    await this.featureRepository.save(feature);
    return feature;
  };

  const findOrCreateStore = async (): Promise<Store> => {
    const existing = await this.storeRepository.findOne({ where: { storeCode: createFeatureDto.storeCode } });
    if (existing) {
      return existing;
    }
    const created = this.storeRepository.create({
      storeCode: createFeatureDto.storeCode,
      storeName: createFeatureDto.storeName,
    });
    await this.storeRepository.save(created);
    return created;
  };

  const [feature, store] = await Promise.all([saveFeature(), findOrCreateStore()]);

  const storeFeature = this.storeFeatureRepository.create({
    feature,
    store,
  });
  await this.storeFeatureRepository.save(storeFeature);

  if (needSync) {
    await this.syncRedis(createFeatureDto.storeCode);
  }
  return feature;
}
/**
 * Asks the Python service to sync the given store, then caches the `ids` it
 * returns (JSON-encoded) in Redis under `<storeCode>-featureDatabase`.
 * The elapsed time is written to the console.
 *
 * @param storeCode code of the store to sync
 */
async syncRedis(storeCode: string) {
  const startedAt = Date.now();
  const syncEndpoint = 'http://localhost:5000/sync'; // URL of the Python service
  const { data } = await firstValueFrom(this.httpService.post(syncEndpoint, { storeCode }));
  const cacheKey = `${storeCode}-featureDatabase`;
  await this.redisService.set(cacheKey, JSON.stringify(data.ids));
  const finishedAt = Date.now();
  console.log(`门店:${storeCode},同步redis耗时:${finishedAt - startedAt}ms`);
}
/**
 * Lists the features linked to a store.
 *
 * @param storeCode code of the store whose features are returned
 * @param selectP optional list of selections passed straight to the query builder's `select`
 * @returns the matching features
 */
async findAll(storeCode: string, selectP?: string[]) {
  const query = this.featureRepository.createQueryBuilder('feature');
  query.select(selectP);
  // Reach the store through the link table, then filter on its code.
  query.innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId');
  query.innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode');
  query.where('store.storeCode = :storeCode', { storeCode });
  return await query.getMany();
}
/**
 * Lists the features linked to a store, each with its `img` relation loaded.
 *
 * @param storeCode code of the store whose features are returned
 * @returns the matching features including their images
 */
async findAllWithImage(storeCode: string): Promise<Feature[]> {
  const featuresOfStore = this.featureRepository
    .createQueryBuilder('feature')
    .leftJoinAndSelect('feature.img', 'img')
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode });
  const rows = await featuresOfStore.getMany();
  return rows;
}
/**
 * Deletes a store together with its feature links, then deletes every feature
 * that is no longer linked to any store, and finally refreshes the cache.
 *
 * Does nothing when no store has the given code.
 *
 * @param storeCode code of the store to delete
 */
async removeAll(storeCode: string): Promise<void> {
  const store = await this.storeRepository.findOne({ where: { storeCode }, relations: ['storeFeatures'] });
  if (!store) {
    return;
  }

  // Delete the link rows through the repository API instead of a raw
  // `DELETE ... IN (?)`: no hard-coded table name and no reliance on the
  // driver expanding an array bound to a single placeholder.
  const linkIds = store.storeFeatures.map((link) => link.id);
  if (linkIds.length > 0) {
    await this.storeFeatureRepository.delete(linkIds);
  }
  await this.storeRepository.remove(store);

  // Features without any remaining link row are not referenced by any store.
  const unreferencedFeatures = await this.featureRepository
    .createQueryBuilder('feature')
    .leftJoinAndSelect('feature.img', 'img')
    .leftJoin('feature.storeFeatures', 'storeFeature')
    .where('storeFeature.id IS NULL')
    .getMany();
  for (const feature of unreferencedFeatures) {
    await this.remove(feature);
  }

  await this.redisService.del(`${storeCode}-featureDatabase`);
  await this.syncRedis(storeCode);
}
/**
 * Sends an image to the Python service for feature extraction and, unless
 * `justPredict` is `'true'`, resolves the returned index positions to stored
 * features.
 *
 * Positions in `index` are looked up in the id list cached in Redis under
 * `<storeCode>-featureDatabase`. Positions with no entry in that list are
 * dropped. The matching features are returned in the order of `index`,
 * reduced to at most `num` entries with distinct labels.
 *
 * @param file uploaded image
 * @param num how many distinct-label results to return, as a decimal string (default `'5'`)
 * @param storeCode code of the store whose feature database is searched
 * @param justPredict `'true'` to return only the extracted features, without any lookup
 * @param needList when `true`, also return every matched feature as `featureList`
 * @returns `{ predictTime, top<num>, features, totalTime }`, plus `featureList` when requested
 * @throws Error with the message `Prediction failed: <reason>` when any step fails
 */
async predict(
  file: Express.Multer.File,
  num: string = '5',
  storeCode: string,
  justPredict: string = 'false',
  needList: boolean = false,
) {
  const PYTHON_SERVICE_URL = 'http://localhost:5000/predict'; // Python service URL
  const REDIS_KEY_SUFFIX = '-featureDatabase';
  const startTime = Date.now();
  const numInt = parseInt(num, 10);
  const isJustPredict = justPredict === 'true';
  try {
    // Prepare form data
    const formData = new FormData();
    formData.append('file', file.buffer, file.originalname);
    formData.append('storeCode', storeCode);
    formData.append('justPredict', justPredict);

    // Send request to Python service
    const response = await firstValueFrom(this.httpService.post(PYTHON_SERVICE_URL, formData));
    const { features, index, predictTime } = response.data;
    if (isJustPredict) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // Retrieve feature database from Redis
    const featureDatabaseStr = await this.redisService.get(`${storeCode}${REDIS_KEY_SUFFIX}`);
    if (!featureDatabaseStr) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // Map index positions to ids and drop positions that have no entry in the
    // cached list (out-of-range or negative positions yield `undefined`).
    const featureDatabase = JSON.parse(featureDatabaseStr);
    const ids = index
      .map((idx: number) => featureDatabase[idx])
      .filter((id: unknown) => id !== undefined && id !== null);
    if (!ids.length) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // Query for features in the database
    const featureList = await this.featureRepository.createQueryBuilder('feature')
      .where('feature.id IN (:...ids)', { ids })
      .getMany();

    // Restore the ranking given by `ids` in memory, so no SQL is built by
    // string interpolation. Keys are compared as strings so numeric and
    // string ids match.
    const rankById = new Map<string, number>();
    ids.forEach((id: unknown, position: number) => {
      const key = String(id);
      if (!rankById.has(key)) {
        rankById.set(key, position);
      }
    });
    const rankOf = (feature: { id: unknown }): number => {
      const rank = rankById.get(String(feature.id));
      return rank === undefined ? Number.MAX_SAFE_INTEGER : rank;
    };
    featureList.sort((a, b) => rankOf(a) - rankOf(b));

    // Filter to ensure unique labels
    const uniqueList = this.filterUniqueFeatures(featureList, numInt);
    const result = this.buildResponse(uniqueList, features, predictTime, startTime, numInt);
    return needList ? { ...result, featureList: featureList.map(({ features, ...rest }) => rest) } : result;
  } catch (error) {
    // The caught value is not guaranteed to be an Error instance.
    const reason = error instanceof Error ? error.message : String(error);
    throw new Error(`Prediction failed: ${reason}`);
  }
}
/**
 * Keeps the first entry for each distinct `label`, in input order, and stops
 * as soon as `limit` entries have been kept.
 *
 * The size check runs after every inspected element, so if the kept count
 * never equals `limit` every distinct-label entry is returned.
 *
 * @param featureList candidates, already ordered by preference
 * @param limit number of entries at which to stop
 * @returns the kept entries
 */
private filterUniqueFeatures(featureList: any[], limit: number) {
  const kept = [];
  for (const candidate of featureList) {
    const labelIsNew = kept.every(existing => existing.label !== candidate.label);
    if (labelIsNew) {
      kept.push(candidate);
    }
    if (kept.length === limit) {
      break;
    }
  }
  return kept;
}
/**
 * Builds the prediction response object.
 *
 * @param list result entries; each entry's own `features` property is stripped
 * @param features feature vector returned by the Python service, passed through unchanged
 * @param predictTime prediction time reported by the Python service
 * @param startTime `Date.now()` value taken when the request started
 * @param num used to name the result key, `top<num>`
 * @returns `{ predictTime, top<num>, features, totalTime }`
 */
private buildResponse(list: any[], features: any, predictTime: string, startTime: number, num: number) {
  const elapsedMs = Date.now() - startTime;
  const withoutVectors = list.map((entry) => {
    const { features: _vector, ...rest } = entry;
    return rest;
  });
  const resultKey = `top${num}`;
  return {
    predictTime,
    [resultKey]: withoutVectors,
    features,
    totalTime: `${elapsedMs}ms`,
  };
}
/**
 * Computes the cosine similarity of two vectors.
 *
 * @param vecA first vector
 * @param vecB second vector, same length as `vecA`
 * @returns dot(vecA, vecB) / (|vecA| * |vecB|), or `0` when either vector has
 *   zero magnitude (including empty vectors), where the ratio is undefined
 * @throws Error when the vectors differ in length
 */
cosineSimilarity(vecA: number[], vecB: number[]): number {
  if (vecA.length !== vecB.length) {
    throw new Error('Vectors must be of the same length');
  }
  let dotProduct = 0;
  let squaredNormA = 0;
  let squaredNormB = 0;
  for (let i = 0; i < vecA.length; i++) {
    dotProduct += vecA[i] * vecB[i];
    squaredNormA += vecA[i] * vecA[i];
    squaredNormB += vecB[i] * vecB[i];
  }
  const denominator = Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB);
  // Without this guard a zero vector produces 0 / 0 = NaN, which callers
  // cannot sort on.
  if (denominator === 0) {
    return 0;
  }
  return dotProduct / denominator;
}
/**
 * Ranks the cached entries of a store by cosine similarity to the given vector
 * and returns the best entry for each distinct label, up to `num` entries.
 *
 * Entries without `features` get a similarity of 0. Returned similarities are
 * rounded to two decimals. An empty array is returned when nothing is cached.
 *
 * NOTE(review): `syncRedis` in this class stores the `ids` array returned by
 * the Python service under the same key, while this method destructures
 * `features` / `label` from each cached entry — confirm the cached payload shape.
 *
 * @param inputFeatures vector to compare against
 * @param num number of entries at which to stop
 * @param storeCode code of the store whose cache is read
 * @returns label/similarity pairs, most similar first
 */
async findTopNSimilar(inputFeatures: number[], num: number, storeCode: string): Promise<{
  label: string;
  similarity: number
}[]> {
  const cached = await this.redisService.get(`${storeCode}-featureDatabase`);
  if (!cached) {
    return [];
  }
  const entries = JSON.parse(cached);
  const scored = entries.map(({ features, label }) => {
    const score = features ? this.cosineSimilarity(inputFeatures, features) : 0;
    return { label: label as string, similarity: score as number };
  });
  // Highest similarity first.
  scored.sort((a: { similarity: number; }, b: { similarity: number; }) => b.similarity - a.similarity);

  const seenLabels = new Set<string>();
  const best: { label: string; similarity: number; }[] = [];
  for (const candidate of scored) {
    if (seenLabels.has(candidate.label as string)) {
      continue;
    }
    seenLabels.add(candidate.label);
    candidate.similarity = Math.round(candidate.similarity * 100) / 100;
    best.push(candidate);
    if (best.length === num) {
      break;
    }
  }
  return best;
}
/**
 * Finds the features of a store that carry the given label, each with its
 * `img` relation loaded.
 *
 * @param label label to match exactly
 * @param storeCode code of the store to search in
 * @returns the matching features including their images
 */
async getByName(label: string, storeCode: string): Promise<Feature[]> {
  const query = this.featureRepository.createQueryBuilder('feature');
  query.leftJoinAndSelect('feature.img', 'img');
  query.innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId');
  query.innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode');
  query.where('store.storeCode = :storeCode', { storeCode });
  query.andWhere('feature.label = :label', { label });
  return await query.getMany();
}
/**
 * Counts the features of a store that carry the given label.
 *
 * The image table is not joined: only a count is returned, and the optional
 * image join did not filter any feature.
 *
 * @param label label to match exactly
 * @param storeCode code of the store to search in
 * @returns number of matching features
 */
async getCountByLabel(label: string, storeCode: string): Promise<number> {
  return await this.featureRepository
    .createQueryBuilder('feature')
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode })
    .andWhere('feature.label = :label', { label })
    .getCount();
}
/**
 * Learns a batch of images: for each file, extracts its feature vector via
 * `predict` (predict-only mode) and stores it via `create` without syncing.
 * Redis is synced once after the whole batch.
 *
 * A file that fails is logged with `console.error` and skipped; the remaining
 * files are still processed.
 *
 * @param files uploaded images
 * @param createFeatureDto fields shared by every created feature (including `storeCode`)
 * @returns the created features, each without its `features` property
 */
async batchStudy(files: Express.Multer.File[], createFeatureDto: CreateFeatureDto) {
  const created = [];
  for (const file of files) {
    try {
      const prediction = await this.predict(file, '5', createFeatureDto.storeCode, 'true');
      const dtoWithVector = { ...createFeatureDto, features: prediction.features };
      const saved = await this.create(file, dtoWithVector, false);
      // Return a copy that leaves out the `features` property.
      const { features, ...summary } = saved;
      created.push(summary);
    } catch (e) {
      console.error(e);
    }
  }
  await this.syncRedis(createFeatureDto.storeCode);
  return created;
}
/**
 * Deletes a feature and then, if one is loaded on it, the feature's image.
 *
 * The feature is removed first and its image afterwards.
 *
 * @param feature feature to delete; pass it with its `img` relation loaded
 *   if the image should be deleted as well
 */
async remove(feature: Feature) {
  await this.featureRepository.remove(feature);
  // `img` is empty when the relation was not loaded or the feature has no
  // image; passing that to `remove` would reject after the feature row is
  // already gone.
  if (feature.img) {
    await this.imgRepository.remove(feature.img);
  }
}
/**
 * Deletes several features (with their store links and images) and then
 * syncs Redis for the given store.
 *
 * @param ids comma-separated feature ids; blank or non-integer entries are ignored
 * @param storeCode code of the store to sync afterwards
 */
async batchRemove(ids: string, storeCode: string) {
  // `+''` is 0 and `+'abc'` is NaN, so blank and malformed entries are dropped
  // instead of being sent to the database as ids.
  const list = ids
    .split(',')
    .map((part) => part.trim())
    .filter((part) => part !== '')
    .map((part) => Number(part))
    .filter((id) => Number.isInteger(id));

  // Load all requested features in one query; skip it when there is nothing to look up.
  const features = list.length > 0
    ? await this.featureRepository.find({
      where: { id: In(list) },
      relations: ['img', 'storeFeatures'],
    })
    : [];

  for (const feature of features) {
    // Remove the link rows before the feature they point to.
    if (feature.storeFeatures && feature.storeFeatures.length > 0) {
      await this.storeFeatureRepository.remove(feature.storeFeatures);
    }
    await this.remove(feature);
  }
  await this.syncRedis(storeCode);
}
/**
 * Links existing features to a target store without copying them.
 *
 * Candidates are all features that appear in the link table and are not yet
 * linked to the target store. When `sourceStoreCode` is given, only candidates
 * linked to that store are imported. The target store is created when it does
 * not exist, and Redis is synced for it at the end.
 *
 * @param storeCode code of the target store
 * @param sourceStoreCode optional code of the store to import from
 * @param storeName name used if the target store has to be created
 * @returns a message stating how many features were imported
 */
async importData(storeCode: string, sourceStoreCode?: string, storeName?: string) {
  // Step 1: feature ids already linked to the target store.
  const linkedRows = await this.storeFeatureRepository
    .createQueryBuilder('storeFeature')
    .select('storeFeature.featureId')
    .where('storeFeature.storeCode = :storeCode', { storeCode })
    .getRawMany();
  const featureIdsToExclude = linkedRows.map(row => row.featureId);

  // Step 2: distinct feature ids from the link table, minus the ones above.
  // The NOT IN clause is only added when there is something to exclude.
  const candidateQuery = this.storeFeatureRepository
    .createQueryBuilder('storeFeature')
    .select('DISTINCT storeFeature.featureId');
  if (featureIdsToExclude.length > 0) {
    candidateQuery.where('storeFeature.featureId NOT IN (:...featureIdsToExclude)', { featureIdsToExclude });
  }
  const candidateRows = await candidateQuery.getRawMany();
  const featureIds = candidateRows.map(record => record.featureId);

  // Step 3: load the candidate features. Skipped entirely when there are no
  // candidates, so an empty id list is never passed to `whereInIds`.
  let featuresToImport: Feature[] = [];
  if (featureIds.length > 0) {
    const featureQuery = this.featureRepository
      .createQueryBuilder('feature')
      .leftJoinAndSelect('feature.img', 'img');
    if (sourceStoreCode) {
      featureQuery
        .innerJoin('feature.storeFeatures', 'storeFeatures')
        .whereInIds(featureIds)
        .andWhere('storeFeatures.storeCode = :storeCode', { storeCode: sourceStoreCode });
    } else {
      featureQuery.whereInIds(featureIds);
    }
    featuresToImport = await featureQuery.getMany();
  }

  // Step 4: find or create the target store.
  let targetStore = await this.storeRepository.findOne({ where: { storeCode: storeCode } });
  if (!targetStore) {
    targetStore = this.storeRepository.create({
      storeCode: storeCode,
      storeName: storeName,
    });
    await this.storeRepository.save(targetStore);
  }

  // Step 5: create one link row per imported feature; the features themselves are reused.
  if (featuresToImport.length > 0) {
    const newLinks = featuresToImport.map((feature: Feature) => ({
      store: targetStore,
      feature,
    }));
    const linkInstances = this.storeFeatureRepository.create(newLinks);
    await this.storeFeatureRepository.save(linkInstances);
  }

  await this.syncRedis(storeCode);
  return `同步完成,共导入${featuresToImport.length}条数据`;
}
/**
 * Syncs Redis for every store code found in the store table.
 * All syncs are started together and awaited as a group; a completion
 * message is logged afterwards.
 */
async init() {
  const rows = await this.storeRepository
    .createQueryBuilder('store')
    .select('store.storeCode')
    .distinct(true)
    .getRawMany();
  // Raw rows expose the selected column as `store_storeCode`.
  const pendingSyncs = rows.map((row) => this.syncRedis(row.store_storeCode));
  await Promise.all(pendingSyncs);
  console.log('初始化完成');
}
}
- 结果:性能并没有提升多少,但好在表关系更清晰了,为之后的拓展打下了基础
实现ivf的动态增删改查
- 结论:IVF 索引如果只追加向量而不重新训练,新增的向量就无法被正确识别,所以每次新增向量后都必须重新训练索引,并把全部向量重新添加一遍
- python端ivf改造
detect.py(识别和同步方法)
python
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
import numpy as np
import time
import gc
from ivf import IVFPQ
from feature import get_feature_by_store_code
import orjson
from concurrent.futures import ThreadPoolExecutor
# Load the pre-trained MobileNetV2 model (ImageNet weights, 224x224x3 input)
# without its top classification layer; the output is average-pooled.
model = MobileNetV2(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')
class MainDetect:
    """Extracts image feature vectors and searches per-store IVF indexes.

    Indexes are kept in ``self.ivfObj``, keyed by
    ``"<store_code>-featureDatabase"``.
    """

    # Initialisation
    def __init__(self):
        super().__init__()
        # Model initialisation
        self.image_id = None
        self.image_features = None
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        # Maps "<store_code>-featureDatabase" to that store's IVFPQ index.
        self.ivfObj = {}

    def classify_image(self, image_data, store_code, just_predict):
        """Extract the feature vector of an image and optionally search the store index.

        Args:
            image_data: encoded image bytes, decoded with ``tf.image.decode_image``.
            store_code: code of the store whose index is searched.
            just_predict: the string ``"false"`` enables the index search;
                any other value skips it.

        Returns:
            dict with ``outputs`` (flattened feature vector), ``time``
            (elapsed time as ``"<ms>ms"``) and ``index`` (positions returned by
            the index search; empty when the search is skipped or the store
            has no index).
        """
        # Load and preprocess image
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)  # Add batch dimension

        # Run model prediction
        start_time = time.time()
        outputs = model.predict(img)
        # outputs = self.model.predict(outputs)
        # prediction = tf.divide(outputs, tf.norm(outputs))

        i = []
        if just_predict == "false":
            index = self.ivfObj.get(store_code + '-featureDatabase')
            if index is not None:
                found = index.search(outputs).flatten().tolist()
                # faiss fills result slots it cannot satisfy with -1; those are
                # not positions of stored vectors, so they are dropped.
                i = [position for position in found if position >= 0]
        end_time = time.time()

        # Calculate elapsed time
        elapsed_time = end_time - start_time

        # Flatten the outputs and return them
        # output_data = prediction.numpy().flatten().tolist()
        output_data = outputs.flatten().tolist()

        # Force garbage collection to free up memory
        del img, outputs, end_time, start_time  # Ensure variables are deleted
        gc.collect()
        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}

    def sync(self, store_code):
        """Rebuild the index of a store from its stored features.

        The replacement index is built before the old one is dropped, so a
        failure while loading or indexing leaves the previous index in place.
        When the store has no features, its index is removed.

        Args:
            store_code: code of the store to rebuild.

        Returns:
            list of feature ids, in the same order as the vectors were indexed;
            empty when the store has no features.
        """
        key = store_code + '-featureDatabase'
        data = get_feature_by_store_code(store_code)
        if len(data) == 0:
            self.ivfObj.pop(key, None)
            return []

        def parse_features(item):
            return orjson.loads(item['features'])

        with ThreadPoolExecutor() as executor:
            features_list = list(executor.map(parse_features, data))
        # Collect all features into one float32 NumPy array
        features = np.array(features_list, dtype=np.float32)
        # Assigning the new index replaces the previous one in a single step.
        self.ivfObj[key] = IVFPQ(features)
        return [item['id'] for item in data]
ivf.py(ivf构造)
python
import faiss
import numpy as np
# Number of OpenMP threads faiss is told to use.
num_threads = 8
faiss.omp_set_num_threads(num_threads)
class IVFPQ:
    """IVF index over a set of feature vectors.

    Despite the class name, the index built here is a ``faiss.IndexIVFFlat``
    with an L2 quantizer; ``m`` and ``n_bits`` are accepted but only matter
    for the commented-out ``IndexIVFPQ`` variant.
    """

    def __init__(self, features, nlist=100, m=16, n_bits=8):
        """Train the index and add ``features`` to it.

        When there are fewer vectors than the training minimum, random vectors
        are generated to pad the training set. They are used for training
        only and are NOT added to the index, so every position returned by
        ``search`` refers to a row of ``features``.

        Args:
            features: 2-D float32 array of shape (n, d).
            nlist: number of IVF cells.
            m: unused by the current index type.
            n_bits: unused by the current index type.
        """
        d = features.shape[1]
        n = len(features)
        # Build the quantizer (L2 distance)
        quantizer = faiss.IndexFlatL2(d)
        self.index = faiss.IndexIVFFlat(quantizer, d, nlist)
        # self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)

        # faiss clustering expects at least 39 training points per centroid;
        # with the default nlist=100 this is the previous fixed value of 3900.
        min_train_points = 39 * nlist
        if n >= min_train_points:
            self.index.train(features)
        else:
            points = min_train_points - n
            # A local generator seeded like the former np.random.seed(points)
            # call yields the same padding without touching global RNG state.
            rng = np.random.RandomState(points)
            padding = rng.random_sample((points, d)).astype('float32')
            self.index.train(np.vstack((features, padding)))

        # Add only the real vectors, 1000 at a time.
        batch_size = 1000
        for start in range(0, n, batch_size):
            self.index.add(features[start:start + batch_size])

    def search(self, xq, k=100):
        """Return the positions of the ``k`` nearest stored vectors for each query row."""
        d, i = self.index.search(xq, k)
        return i

    def add(self, xb):
        """Add a 2-D batch of vectors to the index."""
        self.index.add(xb)

    def train(self, xb):
        """Train the index on a 2-D batch of vectors."""
        self.index.train(xb)

    def sync(self, features):
        """Add ``features`` to the index.

        Accepts a single vector or a 2-D batch. ``index.add`` needs a 2-D
        array, so a single vector is reshaped to one row instead of being
        passed as a 1-D array.
        """
        features = np.ascontiguousarray(features, dtype=np.float32)
        if features.size == 0:
            return
        if features.ndim == 1:
            features = features.reshape(1, -1)
        self.add(features)
结语
这个项目的优化到这里差不多告一段落了,后续如果还有新的优化点会继续跟进;稍后会把整个架构图和功能点都梳理一遍。