deeplearning4j使用vgg19图片向量比对springboot+es环境

文章目录


一、桌面创建两个目录读图

二、POM

xml 复制代码
<dependency>
    <groupId>org.springframework.data</groupId>
    <artifactId>spring-data-elasticsearch</artifactId>
</dependency>

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch</groupId>
    <artifactId>elasticsearch</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>transport</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-client</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.plugin</groupId>
    <artifactId>transport-netty4-client</artifactId>
</dependency>

三、code

java 复制代码
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.model.VGG19;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.PostConstruct;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

@Service("vgg19Service")
public class Vgg19ServiceImpl implements Vgg19Service {

    private static ComputationGraph vgg19Model;

    @PostConstruct
    public void init() throws IOException {
        VGG19 vgg19 = VGG19.builder().build();
        vgg19Model = (ComputationGraph) vgg19.initPretrained();
    }

    @Autowired
    private INDArrayPojoRepository indArrayPojoRepository;

    @Override
    public String find(MultipartFile file) throws IOException {

//        VGG19 vgg19 = VGG19.builder().build();
//         vgg19Model = (ComputationGraph) vgg19.initPretrained();


        String templateImagePath = "C:\\Users\\Administrator\\Desktop\\template\\1.png";

        // 图像文件夹路径
        String imageFolder = "C:\\Users\\Administrator\\Desktop\\target";

        // 加载模板图像
        NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3);
        INDArray templateImage = imageLoader.asMatrix(new File(templateImagePath));

        // 提取模板图像的特征向量
        INDArray templateFeatures = vgg19Model.outputSingle(templateImage);

        // 存储图像相似度的映射
        Map<String, Double> similarityMap = new HashMap<>();

        // 遍历图像文件夹
        File folder = new File(imageFolder);
        File[] imageFiles = folder.listFiles();
        long i = 1L;
        indArrayPojoRepository.deleteAll();
        if (imageFiles != null) {
            for (File imageFile : imageFiles) {
                // 加载当前图像
//                INDArray currentImage = imageLoader.asMatrix(imageFile);
//                // 提取当前图像的特征向量
//                INDArray currentFeatures = vgg19Model.outputSingle(currentImage);
//                long[] longVector = currentFeatures.toLongVector();
//                System.out.println(longVector);
//                double[] doubleVector = currentFeatures.toDoubleVector();
//                System.out.println(new ImagesArrayPojo(i,doubleVector));
                indArrayPojoRepository.save( new ImagesArrayPojo(i,new double[]{1,11.11,1}));
//                indArrayPojoRepository.findBySimilarity(templateFeatures.toDoubleVector(), PageRequest.of(1, 20));
//                System.out.println(currentFeatures);
//                // 计算余弦相似度
//                double similarityScore = Transforms.cosineSim(templateFeatures, currentFeatures);
//
//                // 将图像名称和相似度存储到映射中
//                similarityMap.put(imageFile.getName(), similarityScore);

                i ++;
            }
        }

        // 打印相似度最高的三张图像名称
//        similarityMap.entrySet().stream()
//                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
//                .limit(3)
//                .forEach(entry -> System.out.println("Image: " + entry.getKey() + ", Similarity: " + entry.getValue()));
return null;
    }
}

java实体类

java 复制代码
@Data
@AllArgsConstructor
@NoArgsConstructor
@Document(indexName = "images_double")
public class ImagesArrayPojo {

    @Id
    private Long id;

    @Field(type = FieldType.Dense_Vector,dims = 1000)
    private double[] ndDoubleArray;
}

搭配

xml 复制代码
<dependency>
     <groupId>org.springframework.boot</groupId>
     <artifactId>spring-boot-starter-data-elasticsearch</artifactId>
 </dependency>

四、es查询脚本

这里注意查看官方文档,不同的es脚本写法稍有不同,这里使用的是7.4.2

java 复制代码
docker run -d -e ES_JAVA_OPTS="-Xms128m -Xmx128m" -e "discovery.type=single-node" -e "script.disable_dynamic: false" -p 9200:9200 -p 9300:9300 -e ES_MIN_MEM=128m -e ES_MAX_MEM=4096m --name es elasticsearch:7.4.2 
powershell 复制代码
{
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "source": "cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0",
        "params": {
          "query_vector": [维度数组]
        }
      }
    }
  }
}

五、没测试的代码

java 复制代码
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.annotations.Query;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

@Repository
public interface INDArrayPojoRepository extends ElasticsearchRepository<ImagesArrayPojo,Long> {

    @Query("{\n" +
            "  \"size\": 10,\n" +
            "  \"from\": 0,\n" +
            "  \"query\": {\n" +
            "    \"script_score\": {\n" +
            "      \"query\": {\n" +
            "        \"match_all\": {}\n" +
            "      },\n" +
            "      \"script\": {\n" +
            "        \"source\": \"cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0\",\n" +
            "        \"params\": {\n" +
            "          \"query_vector\": [?1]\n" +
            "        }\n" +
            "      }\n" +
            "    }\n" +
            "  }\n" +
            "}")
    Page<ImagesArrayPojo> findBySimilarity(@Param("queryVector") double[] queryVector, Pageable pageable);
}

总结

思路:首先使用deeplearning4j加载vgg19采集图片的向量值,然后将向量值存储到es中,然后后续搜索使用es的余弦脚本查询

相关推荐
FIN技术铺4 分钟前
Spring Boot框架Starter组件整理
java·spring boot·后端
余生H24 分钟前
transformer.js(三):底层架构及性能优化指南
javascript·深度学习·架构·transformer
小码的头发丝、30 分钟前
Spring Boot 注解
java·spring boot
午觉千万别睡过32 分钟前
RuoYI分页不准确问题解决
spring boot
2301_811274311 小时前
大数据基于Spring Boot的化妆品推荐系统的设计与实现
大数据·spring boot·后端
罗小罗同学1 小时前
医工交叉入门书籍分享:Transformer模型在机器学习领域的应用|个人观点·24-11-22
深度学习·机器学习·transformer
孤独且没人爱的纸鹤1 小时前
【深度学习】:从人工神经网络的基础原理到循环神经网络的先进技术,跨越智能算法的关键发展阶段及其未来趋势,探索技术进步与应用挑战
人工智能·python·深度学习·机器学习·ai
阿_旭1 小时前
TensorFlow构建CNN卷积神经网络模型的基本步骤:数据处理、模型构建、模型训练
人工智能·深度学习·cnn·tensorflow
羊小猪~~1 小时前
tensorflow案例7--数据增强与测试集, 训练集, 验证集的构建
人工智能·python·深度学习·机器学习·cnn·tensorflow·neo4j
极客代码1 小时前
【Python TensorFlow】进阶指南(续篇三)
开发语言·人工智能·python·深度学习·tensorflow