deeplearning4j使用vgg19图片向量比对springboot+es环境

文章目录


一、桌面创建两个目录读图

二、POM

xml 复制代码
<dependency>
    <groupId>org.springframework.data</groupId>
    <artifactId>spring-data-elasticsearch</artifactId>
</dependency>

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch</groupId>
    <artifactId>elasticsearch</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>transport</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-client</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.plugin</groupId>
    <artifactId>transport-netty4-client</artifactId>
</dependency>

三、code

java 复制代码
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.model.VGG19;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.PostConstruct;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

@Service("vgg19Service")
public class Vgg19ServiceImpl implements Vgg19Service {

    private static ComputationGraph vgg19Model;

    @PostConstruct
    public void init() throws IOException {
        VGG19 vgg19 = VGG19.builder().build();
        vgg19Model = (ComputationGraph) vgg19.initPretrained();
    }

    @Autowired
    private INDArrayPojoRepository indArrayPojoRepository;

    @Override
    public String find(MultipartFile file) throws IOException {

//        VGG19 vgg19 = VGG19.builder().build();
//         vgg19Model = (ComputationGraph) vgg19.initPretrained();


        String templateImagePath = "C:\\Users\\Administrator\\Desktop\\template\\1.png";

        // 图像文件夹路径
        String imageFolder = "C:\\Users\\Administrator\\Desktop\\target";

        // 加载模板图像
        NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3);
        INDArray templateImage = imageLoader.asMatrix(new File(templateImagePath));

        // 提取模板图像的特征向量
        INDArray templateFeatures = vgg19Model.outputSingle(templateImage);

        // 存储图像相似度的映射
        Map<String, Double> similarityMap = new HashMap<>();

        // 遍历图像文件夹
        File folder = new File(imageFolder);
        File[] imageFiles = folder.listFiles();
        long i = 1L;
        indArrayPojoRepository.deleteAll();
        if (imageFiles != null) {
            for (File imageFile : imageFiles) {
                // 加载当前图像
//                INDArray currentImage = imageLoader.asMatrix(imageFile);
//                // 提取当前图像的特征向量
//                INDArray currentFeatures = vgg19Model.outputSingle(currentImage);
//                long[] longVector = currentFeatures.toLongVector();
//                System.out.println(longVector);
//                double[] doubleVector = currentFeatures.toDoubleVector();
//                System.out.println(new ImagesArrayPojo(i,doubleVector));
                indArrayPojoRepository.save( new ImagesArrayPojo(i,new double[]{1,11.11,1}));
//                indArrayPojoRepository.findBySimilarity(templateFeatures.toDoubleVector(), PageRequest.of(1, 20));
//                System.out.println(currentFeatures);
//                // 计算余弦相似度
//                double similarityScore = Transforms.cosineSim(templateFeatures, currentFeatures);
//
//                // 将图像名称和相似度存储到映射中
//                similarityMap.put(imageFile.getName(), similarityScore);

                i ++;
            }
        }

        // 打印相似度最高的三张图像名称
//        similarityMap.entrySet().stream()
//                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
//                .limit(3)
//                .forEach(entry -> System.out.println("Image: " + entry.getKey() + ", Similarity: " + entry.getValue()));
return null;
    }
}

java实体类

java 复制代码
@Data
@AllArgsConstructor
@NoArgsConstructor
@Document(indexName = "images_double")
public class ImagesArrayPojo {

    @Id
    private Long id;

    @Field(type = FieldType.Dense_Vector,dims = 1000)
    private double[] ndDoubleArray;
}

搭配

xml 复制代码
<dependency>
     <groupId>org.springframework.boot</groupId>
     <artifactId>spring-boot-starter-data-elasticsearch</artifactId>
 </dependency>

四、es查询脚本

这里注意查看官方文档,不同的es脚本写法稍有不同,这里使用的是7.4.2

java 复制代码
docker run -d -e ES_JAVA_OPTS="-Xms128m -Xmx128m" -e "discovery.type=single-node" -e "script.disable_dynamic: false" -p 9200:9200 -p 9300:9300 -e ES_MIN_MEM=128m -e ES_MAX_MEM=4096m --name es elasticsearch:7.4.2 
powershell 复制代码
{
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "source": "cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0",
        "params": {
          "query_vector": [维度数组]
        }
      }
    }
  }
}

五、没测试的代码

java 复制代码
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.annotations.Query;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

@Repository
public interface INDArrayPojoRepository extends ElasticsearchRepository<ImagesArrayPojo,Long> {

    @Query("{\n" +
            "  \"size\": 10,\n" +
            "  \"from\": 0,\n" +
            "  \"query\": {\n" +
            "    \"script_score\": {\n" +
            "      \"query\": {\n" +
            "        \"match_all\": {}\n" +
            "      },\n" +
            "      \"script\": {\n" +
            "        \"source\": \"cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0\",\n" +
            "        \"params\": {\n" +
            "          \"query_vector\": [?1]\n" +
            "        }\n" +
            "      }\n" +
            "    }\n" +
            "  }\n" +
            "}")
    Page<ImagesArrayPojo> findBySimilarity(@Param("queryVector") double[] queryVector, Pageable pageable);
}

总结

思路:首先使用deeplearning4j加载vgg19采集图片的向量值,然后将向量值存储到es中,然后后续搜索使用es的余弦脚本查询

相关推荐
IT学长编程10 分钟前
计算机毕业设计 Java酷听音乐系统的设计与实现 Java实战项目 附源码+文档+视频讲解
java·spring boot·毕业设计·课程设计·毕业论文·音乐系统·计算机毕业设计选题
IT学长编程27 分钟前
计算机毕业设计 基于协同过滤算法的个性化音乐推荐系统的设计与实现 Java实战项目 附源码+文档+视频讲解
java·spring boot·毕业设计·毕业论文·协同过滤算法·计算机毕业设计选题·个性化音乐推荐系统
华农第一蒟蒻42 分钟前
Java中JWT(JSON Web Token)的运用
java·前端·spring boot·json·token
计算机学姐1 小时前
基于SpringBoot+Vue的高校运动会管理系统
java·vue.js·spring boot·后端·mysql·intellij-idea·mybatis
老华带你飞1 小时前
公寓管理系统|SprinBoot+vue夕阳红公寓管理系统(源码+数据库+文档)
java·前端·javascript·数据库·vue.js·spring boot·课程设计
程序员陆通2 小时前
Spring Boot RESTful API开发教程
spring boot·后端·restful
Evand J2 小时前
深度学习的应用综述
深度学习
sp_fyf_20242 小时前
[大语言模型-论文精读] 更大且更可指导的语言模型变得不那么可靠
人工智能·深度学习·神经网络·搜索引擎·语言模型·自然语言处理
我是浮夸3 小时前
MyBatisPlus——学习笔记
java·spring boot·mybatis
杨荧3 小时前
【JAVA开源】基于Vue和SpringBoot的水果购物网站
java·开发语言·vue.js·spring boot·spring cloud·开源