Integrating DJL with Spring Boot for Image Classification

1. What is DJL?

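DJL (Deep Java Library) is a fairly new project, officially released at AWS re:Invent in early December 2019. In short, DJL is an open-source deep learning toolkit, licensed under Apache-2.0, that uses Java APIs to simplify training, testing, deploying, and running inference with deep learning models. It lets Java developers build and apply machine learning and deep learning models natively in Java, and it lowers the barrier to deep learning development. With DJL's intuitive, high-level API, Java developers can train their own models or run inference with models that data scientists have pre-trained in Python. If you are a Java developer interested in deep learning, DJL is a good place to start.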

2. Data Preparation

Download the training set

wget https://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images-square.zip

Unzip the archive so it can be used for model training later

unzip ut-zap50k-images-square.zip  
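
Before training, it can help to check what the unpacked directory looks like. DJL's ImageFolder dataset, used later in the service, takes the name of each first-level sub-directory as a class label. The sketch below counts files per label using only the JDK. The class name DatasetLayoutCheck and the default directory name are assumptions made for this illustration and are not part of the project.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper, not part of the article's project.
final class DatasetLayoutCheck {

    // Label of an image = name of the first directory below the dataset root.
    static String labelOf(Path root, Path image) {
        return root.relativize(image).getName(0).toString();
    }

    public static void main(String[] args) throws IOException {
        // Assumption: the archive was unpacked into the current working directory.
        Path root = Paths.get(args.length > 0 ? args[0] : "ut-zap50k-images-square");
        if (!Files.isDirectory(root)) {
            System.out.println("Dataset directory not found: " + root);
            return;
        }
        Map<String, Long> filesPerLabel;
        try (Stream<Path> files = Files.walk(root)) {
            filesPerLabel = files
                    .filter(Files::isRegularFile)
                    // ignore files that sit directly in the root, they belong to no class
                    .filter(p -> root.relativize(p).getNameCount() > 1)
                    .collect(Collectors.groupingBy(
                            p -> labelOf(root, p), TreeMap::new, Collectors.counting()));
        }
        filesPerLabel.forEach((label, count) -> System.out.println(label + ": " + count));
    }
}
```

If the number of printed labels differs from the djl.num-of-output value used by the service (4 by default), the dataset path or that property probably needs adjusting.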

3. Code Project

Goal

Implement image classification with DJL

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
    </parent>

    <modelVersion>4.0.0</modelVersion>

    <artifactId>djl</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
    </properties>
    <dependencies>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- DJL -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>

        <!-- pytorch-engine-->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <scope>runtime</scope>
        </dependency>



    </dependencies>

    <profiles>
        <profile>
            <id>windows</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <dependencies>
                <!-- Windows CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>win-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>centos7</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For Pre-CXX11 build (CentOS7)-->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <version>2.0.1</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>linux</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- Linux CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>aarch64</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For aarch64 build-->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-aarch64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.23.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>



</project>

controller

package com.et.controller;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.et.service.ImageClassificationService;

import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;


@RestController
@RequiredArgsConstructor
public class ImageClassificationController {

    private final ImageClassificationService imageClassificationService;

    @PostMapping(path = "/analyze")
    public String predict(@RequestPart("image") MultipartFile image,
                          @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException,
            MalformedModelException,
            IOException {
        return imageClassificationService.predict(image, modePath);
    }

    @PostMapping(path = "/training")
    public String training(@RequestParam(defaultValue = "/home/djl-test/images-test")
                           String datasetRoot,
                           @RequestParam(defaultValue = "/home/djl-test/models") String modePath) throws TranslateException, IOException {
        return imageClassificationService.training(datasetRoot, modePath);
    }

    @GetMapping("/download")
    public ResponseEntity<Resource> downloadFile(@RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
        List<String> imgPathList = new ArrayList<>();
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            // Filter only regular files (excluding directories)
            paths.filter(Files::isRegularFile)
                    .forEach(c-> imgPathList.add(c.toString()));
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
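        if (imgPathList.isEmpty()) {
            // no regular files under the directory, so there is nothing to pick from
            return ResponseEntity.notFound().build();
        }
        Random random = new Random();
        String filePath = imgPathList.get(random.nextInt(imgPathList.size()));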
        Path file = Paths.get(filePath);
        Resource resource = new FileSystemResource(file.toFile());

        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }
        HttpHeaders headers = new HttpHeaders();
        headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=" + file.getFileName().toString());
        headers.add(HttpHeaders.CONTENT_TYPE, MediaType.IMAGE_JPEG_VALUE);

        try {
            return ResponseEntity.ok()
                    .headers(headers)
                    .contentLength(resource.contentLength())
                    .body(resource);
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
    }

}

service

Interface

package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

public interface ImageClassificationService {
    public String predict(MultipartFile image, String modePath) throws IOException, MalformedModelException, TranslateException;
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException;
}

Implementation class

package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.et.Models;
import lombok.Cleanup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;


@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {

    // represents number of training samples processed before the model is updated
    private static final int BATCH_SIZE = 32;

    // the number of passes over the complete dataset
    private static final int EPOCHS = 2;

    //the number of classification labels: boots, sandals, shoes, slippers
    @Value("${djl.num-of-output:4}")
    public int numOfOutput;

    @Override
    public String predict(MultipartFile image, String modePath) throws IOException, MalformedModelException, TranslateException {
        @Cleanup
        InputStream is = image.getInputStream();
        Path modelDir = Paths.get(modePath);
        BufferedImage bi = ImageIO.read(is);
        Image img = ImageFactory.getInstance().fromImage(bi);
        // empty model instance
        try (Model model = Models.getModel(numOfOutput)) {
            // load the model
            model.load(modelDir, Models.MODEL_NAME);
            // define a translator for pre- and post-processing:
            // it resizes each image to the model's input size and converts it to a tensor
            Translator<Image, Classifications> translator =
                    ImageClassificationTranslator.builder()
                            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                            .addTransform(new ToTensor())
                            .optApplySoftmax(true)
                            .build();
            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                log.info("result={}",predictResult.toJson());
                return predictResult.toJson();
            }
        }
    }

    @Override
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException {
        log.info("Image dataset training started...Image dataset address path:{}",datasetRoot);
        // the location to save the model
        Path modelDir = Paths.get(modePath);

        // create ImageFolder dataset from directory
        ImageFolder dataset = initDataset(datasetRoot);
        // Split the dataset set into training dataset and validate dataset
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // set loss function, which seeks to minimize errors
        // loss function evaluates model's predictions against the correct answer (during training)
        // higher numbers are bad - means model performed poorly; indicates more errors; want to
        // minimize errors (loss)
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // setting training parameters (ie hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(numOfOutput); // empty model instance to hold patterns
             Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            // NCHW layout: batch size, channels, height, width
            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);

            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // set model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the model after done training for inference later
            // model saved as shoeclassifier-0000.params
            model.save(modelDir, Models.MODEL_NAME);
            // save labels into model directory
            Models.saveSynset(modelDir, dataset.getSynset());
            log.info("Image dataset training completed......");
            return String.join("\n", dataset.getSynset());
        }
    }

    private ImageFolder initDataset(String datasetRoot)
            throws IOException, TranslateException {
        ImageFolder dataset =
                ImageFolder.builder()
                        // retrieve the data
                        .setRepositoryPath(Paths.get(datasetRoot))
                        .optMaxDepth(10)
                        .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        // random sampling; don't process the data in order
                        .setSampling(BATCH_SIZE, true)
                        .build();
        dataset.prepare();
        return dataset;
    }

    private TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }

}
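
The service calls Models.saveSynset(modelDir, dataset.getSynset()), but the Models class itself is not listed here. As a rough idea of what such a helper might do, the sketch below writes the labels one per line into the model directory and reads them back. The class name SynsetFiles, the loadSynset method, and the file name synset.txt are assumptions for illustration only; the real implementation lives in the repository and may differ.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

// Hypothetical stand-in for the label-saving part of the unlisted Models class.
final class SynsetFiles {

    // Assumed file name; not taken from the article.
    static final String SYNSET_FILE = "synset.txt";

    // Writes one label per line into modelDir, creating the directory if needed.
    static void saveSynset(Path modelDir, List<String> synset) {
        try {
            Files.createDirectories(modelDir);
            Files.write(modelDir.resolve(SYNSET_FILE), synset, StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Reads the labels back in the order they were written.
    static List<String> loadSynset(Path modelDir) {
        try {
            return Files.readAllLines(modelDir.resolve(SYNSET_FILE), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("model");
        saveSynset(dir, List.of("Boots", "Sandals", "Shoes", "Slippers"));
        System.out.println(loadSynset(dir));
    }
}
```

Keeping the labels next to the saved parameters matters because the model only outputs class indices; the label order written at training time is what maps an index back to a name at prediction time.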

application.yaml

server:
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher

Application entry class

package com.et;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class DemoApplication {

   public static void main(String[] args) {
      SpringApplication.run(DemoApplication.class, args);
   }
}

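The listings above cover only the key code; see the repository below for the complete project.

Code repository

4. Testing

Start the Spring Boot application

Train the model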

Use the dataset downloaded earlier
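
One way to trigger training is to POST to the /training endpoint defined in the controller, passing datasetRoot and modePath as request parameters. The sketch below does this with the JDK's built-in HttpClient. It assumes the application runs locally on port 8888 (as in application.yaml); the class name TrainingRequestExample and the command-line arguments are made up for this example.

```java
import java.io.IOException;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

// Hypothetical client, not part of the article's project.
final class TrainingRequestExample {

    // Builds the /training URI; values are URL-encoded because they are file-system paths.
    static URI trainingUri(String baseUrl, String datasetRoot, String modelDir) {
        String query = "datasetRoot=" + URLEncoder.encode(datasetRoot, StandardCharsets.UTF_8)
                + "&modePath=" + URLEncoder.encode(modelDir, StandardCharsets.UTF_8);
        return URI.create(baseUrl + "/training?" + query);
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        if (args.length < 2) {
            System.out.println("usage: TrainingRequestExample <datasetRoot> <modelDir>");
            return;
        }
        // Assumption: the application is reachable at localhost:8888.
        URI uri = trainingUri("http://localhost:8888", args[0], args[1]);
        HttpRequest request = HttpRequest.newBuilder(uri)
                .POST(HttpRequest.BodyPublishers.noBody())
                .build();
        // The call blocks until training has finished on the server.
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode());
        System.out.println(response.body());
    }
}
```

On success the response body is the list of class labels, one per line, because the service returns the dataset's synset joined with newlines.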

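The console prints the log below. Without a GPU, training is slow: in this run the first epoch took about 1 hour 43 minutes on CPU.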

2024-10-11T21:00:05.407+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] c.e.s.ImageClassificationServiceImpl : Image dataset training started...Image dataset address path:/Users/liuhaihua/ai/ut-zap50k-images-square
2024-10-11T21:00:08.455+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.util.Platform : Ignore mismatching platform from: jar:file:/Users/liuhaihua/.m2/repository/ai/djl/pytorch/pytorch-native-cpu/2.0.1/pytorch-native-cpu-2.0.1-win-x86_64.jar!/native/lib/pytorch.properties
2024-10-11T21:00:09.240+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of inter-op threads is 4
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of intra-op threads is 4
2024-10-11T21:00:09.287+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Training on: cpu().
2024-10-11T21:00:09.290+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Load PyTorch Engine Version 1.13.1 in 0.044 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
Validating: 100% |████████████████████████████████████████|
2024-10-11T22:42:48.142+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Epoch 1 finished.
2024-10-11T22:42:48.187+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
2024-10-11T22:42:48.189+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Validate: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.24
Training: 5% |███ | Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
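
The two figures reported per epoch can be read as follows: Accuracy is the fraction of images whose top-scoring class is the correct one, and SoftmaxCrossEntropyLoss is the average of -ln(p), where p is the probability the model assigns to the correct class. A training loss of 0.38 therefore corresponds to a geometric-mean probability of roughly exp(-0.38) ≈ 0.68 on the correct class. The sketch below computes both metrics for a tiny batch. It is a simplification that starts from probabilities, whereas DJL computes the loss from the raw network outputs; the class name MetricsExample and the sample values are made up.

```java
// Hypothetical illustration, not DJL code.
final class MetricsExample {

    // Mean of -ln(p), where p is the probability given to the true class.
    static double crossEntropy(double[][] probabilities, int[] labels) {
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            total -= Math.log(probabilities[i][labels[i]]);
        }
        return total / labels.length;
    }

    // Fraction of samples whose highest-probability class equals the true class.
    static double accuracy(double[][] probabilities, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            int best = 0;
            for (int c = 1; c < probabilities[i].length; c++) {
                if (probabilities[i][c] > probabilities[i][best]) {
                    best = c;
                }
            }
            if (best == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        // Made-up probabilities for three samples over two classes.
        double[][] probabilities = {{0.9, 0.1}, {0.2, 0.8}, {0.6, 0.4}};
        int[] labels = {0, 1, 1};
        System.out.printf("accuracy=%.2f loss=%.2f%n",
                accuracy(probabilities, labels), crossEntropy(probabilities, labels));
    }
}
```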

Predict the image class

Use the model trained in the previous step to make a prediction
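
The /analyze endpoint expects a multipart/form-data POST whose file part is named image. The JDK's HttpClient has no multipart helper, so the sketch below assembles the body by hand. It assumes the application runs on localhost:8888 and relies on the controller's default modePath; the class name AnalyzeRequestExample and the boundary string are made up for this example.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical client, not part of the article's project.
final class AnalyzeRequestExample {

    // Encodes one file as a multipart/form-data body containing a single part.
    static byte[] multipartBody(String boundary, String partName, String fileName, byte[] content) {
        String header = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"" + partName
                + "\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: application/octet-stream\r\n\r\n";
        String footer = "\r\n--" + boundary + "--\r\n";
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.writeBytes(header.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(footer.getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        if (args.length == 0) {
            System.out.println("usage: AnalyzeRequestExample <image-file>");
            return;
        }
        Path image = Paths.get(args[0]);
        String boundary = "djl-demo-" + System.nanoTime();
        byte[] body = multipartBody(boundary, "image",
                image.getFileName().toString(), Files.readAllBytes(image));
        // Assumption: the application is reachable at localhost:8888.
        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8888/analyze"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        // The body is the JSON produced by Classifications.toJson() in the service.
        System.out.println(response.body());
    }
}
```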

In the returned result, Shoes has the highest probability, so the image is classified as Shoes.
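
The probabilities come from the softmax step that the translator applies (optApplySoftmax(true)): the raw network outputs are exponentiated and normalized so they sum to 1, and the class with the largest value is the prediction. A minimal sketch of that calculation follows. The class name SoftmaxExample, the raw scores, and the label order are made up for illustration and do not come from a real run.

```java
// Hypothetical illustration, not DJL code.
final class SoftmaxExample {

    // Converts raw scores into probabilities that sum to 1.
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sum = 0.0;
        double[] result = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // subtracting the maximum avoids overflow and does not change the result
            result[i] = Math.exp(scores[i] - max);
            sum += result[i];
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= sum;
        }
        return result;
    }

    // Index of the largest value.
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String[] labels = {"Boots", "Sandals", "Shoes", "Slippers"};
        double[] scores = {0.5, -1.0, 3.0, 0.1}; // made-up raw outputs
        double[] probabilities = softmax(scores);
        System.out.println("predicted: " + labels[argMax(probabilities)]);
    }
}
```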

