deeplearning4j使用vgg19图片向量比对springboot+es环境

文章目录


一、桌面创建两个目录读图

二、POM

xml 复制代码
<dependency>
    <groupId>org.springframework.data</groupId>
    <artifactId>spring-data-elasticsearch</artifactId>
</dependency>

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch</groupId>
    <artifactId>elasticsearch</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>transport</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-client</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.plugin</groupId>
    <artifactId>transport-netty4-client</artifactId>
</dependency>

三、code

java 复制代码
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.model.VGG19;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.PostConstruct;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

@Service("vgg19Service")
public class Vgg19ServiceImpl implements Vgg19Service {

    private static ComputationGraph vgg19Model;

    @PostConstruct
    public void init() throws IOException {
        VGG19 vgg19 = VGG19.builder().build();
        vgg19Model = (ComputationGraph) vgg19.initPretrained();
    }

    @Autowired
    private INDArrayPojoRepository indArrayPojoRepository;

    @Override
    public String find(MultipartFile file) throws IOException {

//        VGG19 vgg19 = VGG19.builder().build();
//         vgg19Model = (ComputationGraph) vgg19.initPretrained();


        String templateImagePath = "C:\\Users\\Administrator\\Desktop\\template\\1.png";

        // 图像文件夹路径
        String imageFolder = "C:\\Users\\Administrator\\Desktop\\target";

        // 加载模板图像
        NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3);
        INDArray templateImage = imageLoader.asMatrix(new File(templateImagePath));

        // 提取模板图像的特征向量
        INDArray templateFeatures = vgg19Model.outputSingle(templateImage);

        // 存储图像相似度的映射
        Map<String, Double> similarityMap = new HashMap<>();

        // 遍历图像文件夹
        File folder = new File(imageFolder);
        File[] imageFiles = folder.listFiles();
        long i = 1L;
        indArrayPojoRepository.deleteAll();
        if (imageFiles != null) {
            for (File imageFile : imageFiles) {
                // 加载当前图像
//                INDArray currentImage = imageLoader.asMatrix(imageFile);
//                // 提取当前图像的特征向量
//                INDArray currentFeatures = vgg19Model.outputSingle(currentImage);
//                long[] longVector = currentFeatures.toLongVector();
//                System.out.println(longVector);
//                double[] doubleVector = currentFeatures.toDoubleVector();
//                System.out.println(new ImagesArrayPojo(i,doubleVector));
                indArrayPojoRepository.save( new ImagesArrayPojo(i,new double[]{1,11.11,1}));
//                indArrayPojoRepository.findBySimilarity(templateFeatures.toDoubleVector(), PageRequest.of(1, 20));
//                System.out.println(currentFeatures);
//                // 计算余弦相似度
//                double similarityScore = Transforms.cosineSim(templateFeatures, currentFeatures);
//
//                // 将图像名称和相似度存储到映射中
//                similarityMap.put(imageFile.getName(), similarityScore);

                i ++;
            }
        }

        // 打印相似度最高的三张图像名称
//        similarityMap.entrySet().stream()
//                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
//                .limit(3)
//                .forEach(entry -> System.out.println("Image: " + entry.getKey() + ", Similarity: " + entry.getValue()));
return null;
    }
}

java实体类

java 复制代码
@Data
@AllArgsConstructor
@NoArgsConstructor
@Document(indexName = "images_double")
public class ImagesArrayPojo {

    @Id
    private Long id;

    @Field(type = FieldType.Dense_Vector,dims = 1000)
    private double[] ndDoubleArray;
}

搭配

xml 复制代码
<dependency>
     <groupId>org.springframework.boot</groupId>
     <artifactId>spring-boot-starter-data-elasticsearch</artifactId>
 </dependency>

四、es查询脚本

这里注意查看官方文档,不同的es脚本写法稍有不同,这里使用的是7.4.2

java 复制代码
docker run -d -e ES_JAVA_OPTS="-Xms128m -Xmx128m" -e "discovery.type=single-node" -e "script.disable_dynamic: false" -p 9200:9200 -p 9300:9300 -e ES_MIN_MEM=128m -e ES_MAX_MEM=4096m --name es elasticsearch:7.4.2 
powershell 复制代码
{
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "source": "cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0",
        "params": {
          "query_vector": [维度数组]
        }
      }
    }
  }
}

五、没测试的代码

java 复制代码
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.annotations.Query;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

@Repository
public interface INDArrayPojoRepository extends ElasticsearchRepository<ImagesArrayPojo,Long> {

    @Query("{\n" +
            "  \"size\": 10,\n" +
            "  \"from\": 0,\n" +
            "  \"query\": {\n" +
            "    \"script_score\": {\n" +
            "      \"query\": {\n" +
            "        \"match_all\": {}\n" +
            "      },\n" +
            "      \"script\": {\n" +
            "        \"source\": \"cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0\",\n" +
            "        \"params\": {\n" +
            "          \"query_vector\": [?1]\n" +
            "        }\n" +
            "      }\n" +
            "    }\n" +
            "  }\n" +
            "}")
    Page<ImagesArrayPojo> findBySimilarity(@Param("queryVector") double[] queryVector, Pageable pageable);
}

总结

思路:首先使用deeplearning4j加载vgg19采集图片的向量值,然后将向量值存储到es中,然后后续搜索使用es的余弦脚本查询

相关推荐
红色的山茶花13 分钟前
YOLOv9-0.1部分代码阅读笔记-loss_tal.py
笔记·深度学习·yolo
小蜗牛慢慢爬行28 分钟前
有关异步场景的 10 大 Spring Boot 面试问题
java·开发语言·网络·spring boot·后端·spring·面试
Allen Bright35 分钟前
Spring Boot 整合 RabbitMQ:手动 ACK 与 QoS 配置详解
spring boot·rabbitmq·java-rabbitmq
喝醉酒的小白40 分钟前
Elasticsearch相关知识@1
大数据·elasticsearch·搜索引擎
一位小说男主1 小时前
编码器与解码器:从‘乱码’到‘通话’
人工智能·深度学习
goTsHgo1 小时前
在 Spring Boot 的 MVC 框架中 路径匹配的实现 详解
spring boot·后端·mvc
钱多多_qdd1 小时前
spring cache源码解析(四)——从@EnableCaching开始来阅读源码
java·spring boot·spring
飞的肖1 小时前
前端使用 Element Plus架构vue3.0实现图片拖拉拽,后等比压缩,上传到Spring Boot后端
前端·spring boot·架构
Q_19284999061 小时前
基于Spring Boot的摄影器材租赁回收系统
java·spring boot·后端