前言
中秋即将来临,预祝大家中秋快乐!!!中秋是中国的传统节日,是一个与家人团聚、赏月、品尝美食的重要时刻。在这个特殊的节日中,有很多元素都与中秋息息相关,像玉兔,月亮,蟾蜍,吃螃蟹,月饼等,为了能够准确地找出语义相关的图片,本文教大家搭建一个简易的文搜图模型。
该模型的主要思路是将图像及其文本描述的表示投影到同一个嵌入空间中,使文本的嵌入向量落在它所描述的图像的嵌入向量附近;检索时只需计算向量相似度,返回最相近的 topK 张图片即可。
demo 很简陋……
基础
- java
- milvus
- 能够向量化的工具
干货
1、文件上传并向量化:为简单起见,文件直接保存在本地。由于使用 OpenAI 进行向量化,而 OpenAI 提供的 embeddings 模型只能向量化文本,因此这里采用为图片附加文字描述的方式(也可以换用能直接向量化图片、视频的模型,思路是通用的)
java
/**
 * Saves an uploaded image under {@code fileTempPath} and indexes its description in Milvus.
 *
 * <p>The text embedding obtained from {@code getFloats} is computed from the file name without
 * its extension, so the file name doubles as the image description.
 *
 * @param file the uploaded image
 * @return {@code code} 200 with the stored file name and path; 400 for an empty upload or a
 *     missing file name; 500 if saving the file or indexing it fails
 */
@PostMapping(value = "/local", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public Dict local(@RequestParam("file") MultipartFile file) {
    if (file.isEmpty()) {
        return Dict.create().set("code", 400).set("message", "文件内容为空");
    }
    // The client-supplied name is untrusted and may be null or carry a path; keep only the last
    // segment so it cannot escape fileTempPath via "../" or "..\".
    String originalName = file.getOriginalFilename();
    String fileName = originalName == null
            ? ""
            : originalName.substring(Math.max(originalName.lastIndexOf('/'), originalName.lastIndexOf('\\')) + 1);
    if (StrUtil.isBlank(fileName)) {
        return Dict.create().set("code", 400).set("message", "文件名为空");
    }
    // Name without extension; it is also the text that gets vectorized.
    String rawFileName = StrUtil.subBefore(fileName, ".", true);
    String fileType = StrUtil.subAfter(fileName, ".", true);
    // Full target path; the timestamp keeps repeated uploads of the same name apart.
    String localFilePath = StrUtil.appendIfMissing(fileTempPath, "/") + rawFileName + "-" + DateUtil.current(false) + "." + fileType;
    File localFile = new File(localFilePath);

    try {
        file.transferTo(localFile);
    } catch (IOException | IllegalStateException e) {
        log.error("【文件上传至本地】失败,绝对路径:{}", localFilePath, e);
        return Dict.create().set("code", 500).set("message", "文件上传失败");
    }

    try {
        List<Float> embeddingList = getFloats(rawFileName);
        Map<String, List<?>> value = new HashMap<>();
        value.put("id", Arrays.asList(UUID.randomUUID().toString()));
        // NOTE(review): this stores the original upload name, not localFilePath (which carries the
        // timestamp suffix); confirm this is what consumers of findFeature expect.
        value.put("local_file_path", Arrays.asList(fileName));
        value.put("text_feature", Arrays.asList(embeddingList));
        milvusOperateService.insert("image_test", value);
    } catch (Exception e) {
        log.error("【向量化失败】失败,绝对路径:{}", localFilePath, e);
        // The image is on disk but not searchable; remove it so retries do not leave orphaned copies.
        if (!localFile.delete()) {
            log.warn("【向量化失败】清理本地文件失败,绝对路径:{}", localFilePath);
        }
        return Dict.create().set("code", 500).set("message", "向量化失败");
    }

    log.info("【文件上传至本地】绝对路径:{}", localFilePath);
    return Dict.create().set("code", 200).set("message", "上传成功").set("data", Dict.create().set("fileName", fileName).set("filePath", localFilePath));
}
2、通过输入的文本,查询出含义最相近的图片
java
@GetMapping(value = "/findFeature")
public Dict findFeature(@RequestParam("fileDesc") String fileDesc) {
List<String> images =new ArrayList<>();
try {
List<Float> embeddingList = getFloats(fileDesc);
Map<String, List> searchByFeature = milvusOperateService.searchByFeature("image_test", 2, "text_feature",
"{"ef":10}","", Arrays.asList("local_file_path"), Arrays.asList(embeddingList));
for(Object o:searchByFeature.get("local_file_path")){
images.add((String) o);
}
} catch (Exception e) {
log.error("【向量搜索】失败,text:{}", fileDesc);
return Dict.create().set("code", 500).set("message", "向量搜索失败");
}
return Dict.create().set("code", 200).set("message", "success").set("data",images );
}
效果展示
输入一些与中秋元素相关的文本描述即可检索出对应的图片(需要事先上传相关图片及其描述;这里我偷了个懒,直接把描述写在了文件名里)。前端结构很简单,比较粗糙……
向量工具类
java
/**
* milvus操作,不考虑分区
* milvus删除操作后的数据会被标记、自动压缩,超过Time Travel保存的时间被清除
*/
@Component
@Slf4j
public class MilvusOperateService {
// 管理链接对象的池子
private GenericObjectPool<MilvusServiceClient> milvusServiceClientGenericObjectPool;
private MilvusOperateService() {
// 私有构造方法创建一个对象池工厂
MilvusPoolFactory milvusPoolFactory = new MilvusPoolFactory();
// 对象池配置 (暂时使用默认的就行了)
GenericObjectPoolConfig objectPoolConfig = new GenericObjectPoolConfig();
// int cpu = Runtime.getRuntime().availableProcessors();
// int minIdle = cpu * 6;
// objectPoolConfig.setMinIdle(minIdle);
// objectPoolConfig.setMaxIdle(minIdle * 8);
// objectPoolConfig.setMaxTotal(minIdle * 16);
//删除废弃对象的配置设置
AbandonedConfig abandonedConfig = new AbandonedConfig();
//在Maintenance的时候检查是否有泄漏
abandonedConfig.setRemoveAbandonedOnMaintenance(true);
//borrow 的时候检查泄漏
abandonedConfig.setRemoveAbandonedOnBorrow(true);
//如果一个对象borrow之后20秒还没有返还给pool,认为是泄漏的对象
abandonedConfig.setRemoveAbandonedTimeout(20);
// 对象池
milvusServiceClientGenericObjectPool = new GenericObjectPool(milvusPoolFactory, objectPoolConfig);
milvusServiceClientGenericObjectPool.setAbandonedConfig(abandonedConfig);
milvusServiceClientGenericObjectPool.setTimeBetweenEvictionRunsMillis(5000); //5秒运行一次维护任务
log.info("MilvusOperateService-对象池创建成功");
}
/**
* 创建一个Collection 类似于创建关系型数据库中的一张表
*
* @param collection 集合名称
* @param collectionDesc 集合描述
* @param fieldTypes 建表字段
* @return
*/
public Boolean createCollection(String collection, String collectionDesc, List<FieldType> fieldTypes) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder()
.withCollectionName(collection)
.withDescription(collectionDesc);
for (FieldType fieldType : fieldTypes) {
builder.addFieldType(fieldType);
}
CreateCollectionParam createCollectionReq = builder.build();
R<RpcStatus> result = milvusServiceClient.createCollection(createCollectionReq);
log.info("MilvusOperateService-创建集合结果" + result.getStatus() + " [0为成功]");
if (result.getStatus().intValue() == 0) {
return true;
}
return false;
} catch (Exception e) {
e.printStackTrace();
log.info("MilvusOperateService-创建集合结果 失败,err:{}", e.getMessage());
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 把集合加载到内存中(milvus查询前必须把数据加载到内存中)
*
* @param collection
*/
public void loadingLocation(String collection) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
R<RpcStatus> rpcStatusR = milvusServiceClient.loadCollection(
LoadCollectionParam.newBuilder()
.withCollectionName(collection)
.build());
log.info("MilvusOperateService-加载集合结果" + rpcStatusR + " [0为成功]");
} catch (Exception e) {
log.info("MilvusOperateService-加载集合结果 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 查看集合详情
*
* @param collection
*/
public GetCollectionStatisticsResponse getCollectionInfo(String collection) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
R<GetCollectionStatisticsResponse> collectionStatistics = milvusServiceClient.getCollectionStatistics(
GetCollectionStatisticsParam.newBuilder()
.withCollectionName(collection)
.build());
GetCollStatResponseWrapper wrapperCollectionStatistics = new GetCollStatResponseWrapper(collectionStatistics.getData());
log.info("Collection row count: " + wrapperCollectionStatistics.getRowCount());
return collectionStatistics.getData();
} catch (Exception e) {
log.info("MilvusOperateService-查看集合详情 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 列出所有集合
*/
public ShowCollectionsResponse getCollectionList() throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
R<ShowCollectionsResponse> collections = milvusServiceClient.showCollections(ShowCollectionsParam.newBuilder().build());
return collections.getData();
} catch (Exception e) {
log.info("MilvusOperateService-列出所有集合 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 检查集合是否存在
*
* @param collection
*/
public Boolean hasCollection(String collection) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
R<Boolean> hasCollection = milvusServiceClient.hasCollection(
HasCollectionParam.newBuilder()
.withCollectionName(collection)
.build());
return hasCollection.getData();
} catch (Exception e) {
log.info("MilvusOperateService-检查集合是否存在 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 在搜索或查询后从内存中释放集合以减少内存使用
*
* @param collection
*/
public void freedLoaction(String collection) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
R<RpcStatus> rpcStatusR = milvusServiceClient.releaseCollection(
ReleaseCollectionParam.newBuilder()
.withCollectionName(collection)
.build());
log.info("MilvusOperateService-释放集合结果" + rpcStatusR + " [0为成功]");
} catch (Exception e) {
log.info("MilvusOperateService-释放集合结果 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 删除一个集合(标记删除)
*
*
* @param collection
*/
private void delCollection(String collection) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
R<RpcStatus> rpcStatusR = milvusServiceClient.dropCollection(
DropCollectionParam.newBuilder()
.withCollectionName(collection)
.build());
log.info("MilvusOperateService-删除集合结果" + rpcStatusR.getStatus() + " [0为成功]");
} catch (Exception e) {
log.info("MilvusOperateService-删除集合结果 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
public String insert(String collectionName, Map<String, List<?>> values) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
List<InsertParam.Field> fields = new ArrayList<>();
for (String code : values.keySet()) {
fields.add(new InsertParam.Field(code, values.get(code)));
}
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(collectionName)
.withFields(fields)
.build();
R<MutationResult> insertResult = milvusServiceClient.insert(insertParam);
if (insertResult.getStatus() == 0) {
return insertResult.getData().getIDs().getStrId().getData(0);
} else {
log.info("MilvusOperateService-插入数据结果 失败,err:{}", insertResult.getMessage());
throw new RuntimeException(insertResult.getMessage());
}
// milvusServiceClient.flush()
} catch (Exception e) {
log.info("MilvusOperateService-插入数据结果 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 刷新数据
*
* @param collectionNames
* @return
*/
public void flush(List<String> collectionNames) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
FlushParam flushParam = FlushParam.newBuilder()
.withCollectionNames(collectionNames)
.build();
R<FlushResponse> responseR = milvusServiceClient.flush(flushParam);
if (responseR.getStatus() != 0) {
log.info("MilvusOperateService-flush结果 失败,err:{}", responseR.getMessage());
throw new RuntimeException(responseR.getMessage());
}
} catch (Exception e) {
log.info("MilvusOperateService-flush结果 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 删除数据
*
* @param collectionName 集合名
* @param deleteExpr 布尔表达式
* @return
*/
public void delete(String collectionName, String deleteExpr) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
DeleteParam deleteParam = DeleteParam.newBuilder()
.withCollectionName(collectionName)
.withExpr(deleteExpr)
.build();
R<MutationResult> deleteResult = milvusServiceClient.delete(deleteParam);
if (deleteResult.getStatus() != 0) {
log.info("MilvusOperateService-删除数据结果 失败,err:{}", deleteResult.getMessage());
throw new RuntimeException(deleteResult.getMessage());
}
} catch (Exception e) {
log.info("MilvusOperateService-删除数据结果 失败,Exception:{}", e.getMessage());
e.printStackTrace();
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
/**
* 根据向量搜索数据
*
* @param collection 集合名称
* @param topK 查询多少条相似结果
* @param VectorFieldName 查询的字段
* @param params 每种索引参数不同
* @param searchOutputFields 返回的字段
* @param searchVectors 用于搜索的向量
* @return
*/
public Map<String, List> searchByFeature(String collection, int topK, String VectorFieldName, String params,String expr,
List<String> searchOutputFields, List<?> searchVectors) throws Exception {
MilvusServiceClient milvusServiceClient = null;
try {
// 通过对象池管理对象
milvusServiceClient = milvusServiceClientGenericObjectPool.borrowObject();
// List<String> searchOutputFields = Arrays.asList("user_code", "user_name", "user_code");
SearchParam.Builder builder = SearchParam.newBuilder();
builder.withCollectionName(collection)
.withMetricType(MetricType.L2)
.withOutFields(searchOutputFields)
.withTopK(topK)
.withVectors(searchVectors)
.withVectorFieldName(VectorFieldName)
// .withParams("{"nprobe":10}")
.withParams(params);
if(!StringUtils.isNotBlank(expr)){
builder.withExpr(expr);
}
SearchParam searchParam = builder.build();
R<SearchResults> respSearch = milvusServiceClient.search(searchParam);
if (respSearch.getStatus() == 0) {
SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(respSearch.getData().getResults());
Map<String, List> map = new HashMap();
for (String name : searchOutputFields) {
map.put(name, wrapperSearch.getFieldData(name, 0));
}
return map;
} else {
log.info("MilvusOperateService-根据向量搜索数据 失败,err:{}", respSearch.getMessage());
throw new RuntimeException(respSearch.getMessage());
}
} catch (Exception e) {
e.printStackTrace();
log.info("MilvusOperateService-根据向量搜索数据 失败,Exception:{}", e.getMessage());
throw e;
} finally {
// 回收对象到对象池
if (milvusServiceClient != null) {
milvusServiceClientGenericObjectPool.returnObject(milvusServiceClient);
}
}
}
public static void main(String[] args) throws Exception {
Random ran = new Random();
List<Long> book_id_array = new ArrayList<>();
List<Long> word_count_array = new ArrayList<>();
List<List<Float>> book_intro_array = new ArrayList<>();
for (long i = 0L; i < 2; ++i) {
book_id_array.add(i);
word_count_array.add(i + 10000);
List<Float> vector = new ArrayList<>();
for (int k = 0; k < 1536; ++k) {
vector.add(ran.nextFloat());
}
book_intro_array.add(vector);
}
System.out.println(book_intro_array);
}
}
java
/**
 * commons-pool2 factory that creates, wraps and destroys {@link MilvusServiceClient} connections.
 *
 * <p>The no-arg constructor keeps the original hard-coded endpoint; use
 * {@link #MilvusPoolFactory(String, int)} to target another Milvus instance.
 */
public class MilvusPoolFactory extends BasePooledObjectFactory<MilvusServiceClient> {

    private static final String DEFAULT_HOST = "192.168.1.68";
    private static final int DEFAULT_PORT = 19530;

    private final String host;
    private final int port;

    /** Connects to the default endpoint {@value #DEFAULT_HOST}:{@value #DEFAULT_PORT}. */
    public MilvusPoolFactory() {
        this(DEFAULT_HOST, DEFAULT_PORT);
    }

    /**
     * @param host Milvus host name or IP
     * @param port Milvus gRPC port
     */
    public MilvusPoolFactory(String host, int port) {
        this.host = java.util.Objects.requireNonNull(host, "host");
        this.port = port;
    }

    /** Opens a new client connection to the configured endpoint. */
    @Override
    public MilvusServiceClient create() throws Exception {
        ConnectParam connectParam = ConnectParam.newBuilder()
                .withHost(host)
                .withPort(port)
                .build();
        return new MilvusServiceClient(connectParam);
    }

    @Override
    public PooledObject<MilvusServiceClient> wrap(MilvusServiceClient milvusServiceClient) {
        return new DefaultPooledObject<>(milvusServiceClient);
    }

    /**
     * Closes a client the pool is discarding (eviction, invalidation, abandoned-object removal or
     * pool close). BasePooledObjectFactory's default is a no-op, which would leak the connection.
     */
    @Override
    public void destroyObject(PooledObject<MilvusServiceClient> pooledObject) throws Exception {
        pooledObject.getObject().close();
    }
}
xml
<!-- Milvus Java SDK used by MilvusOperateService / MilvusPoolFactory. -->
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<!-- Drops the SDK's transitive log4j SLF4J binding; presumably to avoid clashing with the
     application's own SLF4J backend (TODO confirm). -->
<exclusions>
<exclusion>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
</exclusion>
</exclusions>
<version>2.2.2</version>
</dependency>