deeplearning4j使用vgg19图片向量比对springboot+es环境

文章目录


一、桌面创建两个目录读图

二、POM

xml 复制代码
<dependency>
    <groupId>org.springframework.data</groupId>
    <artifactId>spring-data-elasticsearch</artifactId>
</dependency>

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch</groupId>
    <artifactId>elasticsearch</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>transport</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-client</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.plugin</groupId>
    <artifactId>transport-netty4-client</artifactId>
</dependency>

三、code

java 复制代码
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.model.VGG19;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.PostConstruct;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

@Service("vgg19Service")
public class Vgg19ServiceImpl implements Vgg19Service {

    private static ComputationGraph vgg19Model;

    @PostConstruct
    public void init() throws IOException {
        VGG19 vgg19 = VGG19.builder().build();
        vgg19Model = (ComputationGraph) vgg19.initPretrained();
    }

    @Autowired
    private INDArrayPojoRepository indArrayPojoRepository;

    @Override
    public String find(MultipartFile file) throws IOException {

//        VGG19 vgg19 = VGG19.builder().build();
//         vgg19Model = (ComputationGraph) vgg19.initPretrained();


        String templateImagePath = "C:\\Users\\Administrator\\Desktop\\template\\1.png";

        // 图像文件夹路径
        String imageFolder = "C:\\Users\\Administrator\\Desktop\\target";

        // 加载模板图像
        NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3);
        INDArray templateImage = imageLoader.asMatrix(new File(templateImagePath));

        // 提取模板图像的特征向量
        INDArray templateFeatures = vgg19Model.outputSingle(templateImage);

        // 存储图像相似度的映射
        Map<String, Double> similarityMap = new HashMap<>();

        // 遍历图像文件夹
        File folder = new File(imageFolder);
        File[] imageFiles = folder.listFiles();
        long i = 1L;
        indArrayPojoRepository.deleteAll();
        if (imageFiles != null) {
            for (File imageFile : imageFiles) {
                // 加载当前图像
//                INDArray currentImage = imageLoader.asMatrix(imageFile);
//                // 提取当前图像的特征向量
//                INDArray currentFeatures = vgg19Model.outputSingle(currentImage);
//                long[] longVector = currentFeatures.toLongVector();
//                System.out.println(longVector);
//                double[] doubleVector = currentFeatures.toDoubleVector();
//                System.out.println(new ImagesArrayPojo(i,doubleVector));
                indArrayPojoRepository.save( new ImagesArrayPojo(i,new double[]{1,11.11,1}));
//                indArrayPojoRepository.findBySimilarity(templateFeatures.toDoubleVector(), PageRequest.of(1, 20));
//                System.out.println(currentFeatures);
//                // 计算余弦相似度
//                double similarityScore = Transforms.cosineSim(templateFeatures, currentFeatures);
//
//                // 将图像名称和相似度存储到映射中
//                similarityMap.put(imageFile.getName(), similarityScore);

                i ++;
            }
        }

        // 打印相似度最高的三张图像名称
//        similarityMap.entrySet().stream()
//                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
//                .limit(3)
//                .forEach(entry -> System.out.println("Image: " + entry.getKey() + ", Similarity: " + entry.getValue()));
return null;
    }
}

java实体类

java 复制代码
@Data
@AllArgsConstructor
@NoArgsConstructor
@Document(indexName = "images_double")
public class ImagesArrayPojo {

    @Id
    private Long id;

    @Field(type = FieldType.Dense_Vector,dims = 1000)
    private double[] ndDoubleArray;
}

搭配

xml 复制代码
<dependency>
     <groupId>org.springframework.boot</groupId>
     <artifactId>spring-boot-starter-data-elasticsearch</artifactId>
 </dependency>

四、es查询脚本

这里注意查看官方文档,不同的es脚本写法稍有不同,这里使用的是7.4.2

java 复制代码
docker run -d -e ES_JAVA_OPTS="-Xms128m -Xmx128m" -e "discovery.type=single-node" -e "script.disable_dynamic: false" -p 9200:9200 -p 9300:9300 -e ES_MIN_MEM=128m -e ES_MAX_MEM=4096m --name es elasticsearch:7.4.2 
powershell 复制代码
{
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "source": "cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0",
        "params": {
          "query_vector": [维度数组]
        }
      }
    }
  }
}

五、没测试的代码

java 复制代码
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.annotations.Query;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

@Repository
public interface INDArrayPojoRepository extends ElasticsearchRepository<ImagesArrayPojo,Long> {

    @Query("{\n" +
            "  \"size\": 10,\n" +
            "  \"from\": 0,\n" +
            "  \"query\": {\n" +
            "    \"script_score\": {\n" +
            "      \"query\": {\n" +
            "        \"match_all\": {}\n" +
            "      },\n" +
            "      \"script\": {\n" +
            "        \"source\": \"cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0\",\n" +
            "        \"params\": {\n" +
            "          \"query_vector\": [?1]\n" +
            "        }\n" +
            "      }\n" +
            "    }\n" +
            "  }\n" +
            "}")
    Page<ImagesArrayPojo> findBySimilarity(@Param("queryVector") double[] queryVector, Pageable pageable);
}

总结

思路:首先使用deeplearning4j加载vgg19采集图片的向量值,然后将向量值存储到es中,然后后续搜索使用es的余弦脚本查询

相关推荐
只有左边一个小酒窝17 分钟前
(六)卷积神经网络:深度学习在计算机视觉中的应用
深度学习·计算机视觉·cnn
weixin_438335401 小时前
Spring Boot实现接口时间戳鉴权
java·spring boot·后端
carpell2 小时前
【语义分割专栏】3:Segnet实战篇(附上完整可运行的代码pytorch)
人工智能·python·深度学习·计算机视觉·语义分割
mengyoufengyu2 小时前
DeepSeek11-Ollama + Open WebUI 搭建本地 RAG 知识库全流程指南
人工智能·深度学习·deepseek
vlln3 小时前
2025年与2030年AI及AI智能体 (Agent) 市场份额分析报告
人工智能·深度学习·神经网络·ai
GiantGo3 小时前
信息最大化(Information Maximization)
深度学习·无监督学习·信息最大化
风象南3 小时前
SpringBoot的4种死信队列处理方式
java·spring boot·后端
coderSong256810 小时前
Java高级 |【实验八】springboot 使用Websocket
java·spring boot·后端·websocket
Blossom.11810 小时前
使用Python和Scikit-Learn实现机器学习模型调优
开发语言·人工智能·python·深度学习·目标检测·机器学习·scikit-learn
Mr_Air_Boy11 小时前
SpringBoot使用dynamic配置多数据源时使用@Transactional事务在非primary的数据源上遇到的问题
java·spring boot·后端