Integrating DJL with Spring Boot for Image Classification

1. What is DJL?

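DJL is a fairly new project that was officially released at AWS re:Invent in early December 2019. In short, DJL (Deep Java Library) is an open-source deep learning toolkit, licensed under Apache-2.0, that provides a Java API for training, testing, and deploying deep learning models and for running inference with them. Java developers can build and use native machine learning and deep learning models directly in Java, which makes deep learning development considerably easier. With DJL's intuitive, high-level API, Java developers can train their own models or run inference with models that data scientists have pre-trained in Python. If you are a Java developer interested in learning deep learning, DJL is a very good place to start building deep learning applications.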

2. Data Preparation

Download the training set (the UT Zappos50K shoe images)

wget https://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images-square.zip

Unzip it so it can be used for training the model later

unzip ut-zap50k-images-square.zip  
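Before training, it can help to check that the unzipped folder has the layout the training code expects. DJL's ImageFolder dataset (used in the service below) takes its class labels from the subfolder names under the dataset root, and the service sets optMaxDepth(10) so images in nested subfolders are picked up too. The plain-Java sketch below counts images per top-level folder under that assumption; the class name DatasetLayoutCheck and the set of image extensions are illustrative choices, not part of the original project.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DatasetLayoutCheck {

    // Assumed image extensions; adjust if your copy of the dataset contains other formats.
    private static final Set<String> IMAGE_EXTENSIONS = Set.of("jpg", "jpeg", "png");

    // Counts image files below each top-level (non-hidden) subfolder of root.
    // The folder name is treated as the class label; nested folders are walked up to maxDepth.
    static Map<String, Long> countImagesPerClass(Path root, int maxDepth) throws IOException {
        List<Path> classDirs;
        try (Stream<Path> entries = Files.list(root)) {
            classDirs = entries.filter(Files::isDirectory)
                    .filter(p -> !p.getFileName().toString().startsWith("."))
                    .collect(Collectors.toList());
        }
        Map<String, Long> counts = new TreeMap<>();
        for (Path classDir : classDirs) {
            try (Stream<Path> files = Files.walk(classDir, maxDepth)) {
                long n = files.filter(Files::isRegularFile)
                        .filter(DatasetLayoutCheck::isImage)
                        .count();
                counts.put(classDir.getFileName().toString(), n);
            }
        }
        return counts;
    }

    private static boolean isImage(Path file) {
        String name = file.getFileName().toString().toLowerCase(Locale.ROOT);
        int dot = name.lastIndexOf('.');
        return dot >= 0 && IMAGE_EXTENSIONS.contains(name.substring(dot + 1));
    }

    public static void main(String[] args) throws IOException {
        // Usage: java DatasetLayoutCheck.java ut-zap50k-images-square
        countImagesPerClass(Path.of(args[0]), 10)
                .forEach((label, n) -> System.out.println(label + ": " + n));
    }
}
```

For this dataset you would expect four labels (boots, sandals, shoes, slippers), matching the default djl.num-of-output of 4 in the service.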

3. Code Project

Goal

Implement image classification with DJL

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
    </parent>

    <modelVersion>4.0.0</modelVersion>

    <artifactId>djl</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
    </properties>
    <dependencies>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- DJL -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>

        <!-- pytorch-engine-->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <scope>runtime</scope>
        </dependency>



    </dependencies>

    <profiles>
        <profile>
            <id>windows</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <dependencies>
                <!-- Windows CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>win-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>centos7</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For Pre-CXX11 build (CentOS7)-->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <version>2.0.1</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>linux</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- Linux CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>aarch64</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- For aarch64 build-->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-aarch64</classifier>
                    <scope>runtime</scope>
                    <version>2.0.1</version>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>2.0.1-0.23.0</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.23.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>



</project>
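The pom pulls in the native PyTorch binaries per platform through profiles, with windows active by default. In Maven, an activeByDefault profile is switched off as soon as another profile in the same POM is activated explicitly, so on Linux you would build with something like -P linux. The hypothetical helper below only illustrates how the JVM's os.name and os.arch properties map to the profile ids declared above; it is not part of the project, and it cannot detect CentOS 7, whose pre-CXX11 profile (centos7) has to be chosen by hand.

```java
import java.util.Locale;

public class ProfileHint {

    // Hypothetical helper: maps JVM system properties to one of the profile ids in the pom
    // (windows, linux, aarch64). CentOS 7 is indistinguishable from other Linux distributions
    // via os.name, so "centos7" is never suggested. Returns null when no profile matches (e.g. macOS).
    static String suggestProfile(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        boolean x64 = arch.equals("amd64") || arch.equals("x86_64");
        if (os.startsWith("windows") && x64) {
            return "windows";
        }
        if (os.startsWith("linux")) {
            if (arch.equals("aarch64")) {
                return "aarch64";
            }
            if (x64) {
                return "linux";
            }
        }
        return null;
    }

    public static void main(String[] args) {
        String profile = suggestProfile(System.getProperty("os.name"), System.getProperty("os.arch"));
        System.out.println(profile == null
                ? "No matching profile in the pom"
                : "mvn -P " + profile + " package");
    }
}
```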

Controller

package com.et.controller;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.et.service.ImageClassificationService;

import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;


@RestController
@RequiredArgsConstructor
public class ImageClassificationController {

    private final ImageClassificationService imageClassificationService;

    @PostMapping(path = "/analyze")
    public String predict(@RequestPart("image") MultipartFile image,
                          @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException,
            MalformedModelException,
            IOException {
        return imageClassificationService.predict(image, modePath);
    }

    @PostMapping(path = "/training")
    public String training(@RequestParam(defaultValue = "/home/djl-test/images-test")
                           String datasetRoot,
                           @RequestParam(defaultValue = "/home/djl-test/models") String modePath) throws TranslateException, IOException {
        return imageClassificationService.training(datasetRoot, modePath);
    }

    @GetMapping("/download")
    public ResponseEntity<Resource> downloadFile(@RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
        List<String> imgPathList = new ArrayList<>();
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            // Filter only regular files (excluding directories)
            paths.filter(Files::isRegularFile)
                    .forEach(c-> imgPathList.add(c.toString()));
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
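        if (imgPathList.isEmpty()) {
            // nothing to pick from; Random.nextInt(0) would throw IllegalArgumentException
            return ResponseEntity.notFound().build();
        }
        Random random = new Random();
        String filePath = imgPathList.get(random.nextInt(imgPathList.size()));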
        Path file = Paths.get(filePath);
        Resource resource = new FileSystemResource(file.toFile());

        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }
        HttpHeaders headers = new HttpHeaders();
        headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + file.getFileName().toString() + "\"");
        headers.add(HttpHeaders.CONTENT_TYPE, MediaType.IMAGE_JPEG_VALUE);

        try {
            return ResponseEntity.ok()
                    .headers(headers)
                    .contentLength(resource.contentLength())
                    .body(resource);
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
    }

}

Service

Interface

package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

public interface ImageClassificationService {
    public String predict(MultipartFile image, String modePath) throws IOException, MalformedModelException, TranslateException;
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException;
}

Implementation

package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.et.Models;
import lombok.Cleanup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;


@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {

    // represents number of training samples processed before the model is updated
    private static final int BATCH_SIZE = 32;

    // the number of passes over the complete dataset
    private static final int EPOCHS = 2;

    //the number of classification labels: boots, sandals, shoes, slippers
    @Value("${djl.num-of-output:4}")
    public int numOfOutput;

    @Override
    public String predict(MultipartFile image, String modePath) throws IOException, MalformedModelException, TranslateException {
        @Cleanup
        InputStream is = image.getInputStream();
        Path modelDir = Paths.get(modePath);
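        BufferedImage bi = ImageIO.read(is);
        if (bi == null) {
            // ImageIO returns null when no registered reader supports the uploaded format
            throw new IOException("Unsupported image format: " + image.getOriginalFilename());
        }
        Image img = ImageFactory.getInstance().fromImage(bi);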
        // empty model instance
        try (Model model = Models.getModel(numOfOutput)) {
            // load the model
            model.load(modelDir, Models.MODEL_NAME);
            // define a translator for pre- and post-processing: resize the image to
            // Models.IMAGE_WIDTH x Models.IMAGE_HEIGHT, convert it to a tensor, and apply softmax to the output
            Translator<Image, Classifications> translator =
                    ImageClassificationTranslator.builder()
                            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                            .addTransform(new ToTensor())
                            .optApplySoftmax(true)
                            .build();
            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                log.info("result={}", predictResult.toJson());
                return predictResult.toJson();
            }
        }
    }

    @Override
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException {
        log.info("Image dataset training started...Image dataset address path:{}",datasetRoot);
        // the location to save the model
        Path modelDir = Paths.get(modePath);

        // create ImageFolder dataset from directory
        ImageFolder dataset = initDataset(datasetRoot);
        // Split the dataset set into training dataset and validate dataset
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // set loss function, which seeks to minimize errors
        // loss function evaluates model's predictions against the correct answer (during training)
        // higher numbers are bad - means model performed poorly; indicates more errors; want to
        // minimize errors (loss)
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // setting training parameters (ie hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(numOfOutput); // empty model instance to hold patterns
             Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            // (batch, channels, height, width)
            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);

            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // set model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the model after training for later inference;
            // with the "Epoch" property set above, the file is named <MODEL_NAME>-0002.params
            model.save(modelDir, Models.MODEL_NAME);
            // save labels into model directory
            Models.saveSynset(modelDir, dataset.getSynset());
            log.info("Image dataset training completed......");
            return String.join("\n", dataset.getSynset());
        }
    }

    private ImageFolder initDataset(String datasetRoot)
            throws IOException, TranslateException {
        ImageFolder dataset =
                ImageFolder.builder()
                        // retrieve the data
                        .setRepositoryPath(Paths.get(datasetRoot))
                        .optMaxDepth(10)
                        .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        // random sampling; don't process the data in order
                        .setSampling(BATCH_SIZE, true)
                        .build();
        dataset.prepare();
        return dataset;
    }

    private TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }

}
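The call dataset.randomSplit(8, 2) divides the shuffled samples into a training part and a validation part in an 8:2 ratio; the validation part is what produces the "Validate" accuracy and loss lines in the training log. The plain-Java sketch below only illustrates that idea on sample indices (shuffle, then cut into proportional chunks). It is not DJL's implementation, and the way DJL rounds chunk sizes may differ.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class RatioSplit {

    // Shuffles the indices 0..size-1 and cuts them into consecutive chunks whose sizes
    // are proportional to the given ratios; the last chunk takes any rounding remainder.
    static List<List<Integer>> randomSplit(int size, Random random, int... ratio) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, random);

        int total = 0;
        for (int r : ratio) {
            total += r;
        }

        List<List<Integer>> parts = new ArrayList<>();
        int from = 0;
        for (int k = 0; k < ratio.length; k++) {
            int to = (k == ratio.length - 1) ? size : from + (int) ((long) size * ratio[k] / total);
            parts.add(new ArrayList<>(indices.subList(from, to)));
            from = to;
        }
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = randomSplit(1000, new Random(42), 8, 2);
        System.out.println("train=" + parts.get(0).size() + ", validate=" + parts.get(1).size());
    }
}
```

With 1000 samples, main prints train=800, validate=200 regardless of the seed; the seed only changes which indices land in each part.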

application.yaml

server:
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher

Main application class

package com.et;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class DemoApplication {

   public static void main(String[] args) {
      SpringApplication.run(DemoApplication.class, args);
   }
}

The above is only the key code; for the complete code, see the code repository below.

Code repository

4. Testing

Start the Spring Boot application

Train the model

Use the dataset downloaded earlier
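The /training endpoint takes datasetRoot and modePath as request parameters, defaulting to /home/djl-test/images-test and /home/djl-test/models, so datasetRoot should point at the folder unzipped in section 2. Below is a minimal client sketch using Java 11's java.net.http, assuming the app runs locally on port 8888 as configured in application.yaml; TrainingClient is an illustrative name. Because the controller trains synchronously inside the request, the call only returns once training has finished.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

public class TrainingClient {

    // Builds POST /training?datasetRoot=...&modePath=... ; the parameter names match
    // the @RequestParam names in ImageClassificationController (including "modePath").
    static URI trainingUri(String baseUrl, String datasetRoot, String modelPath) {
        return URI.create(baseUrl + "/training"
                + "?datasetRoot=" + URLEncoder.encode(datasetRoot, StandardCharsets.UTF_8)
                + "&modePath=" + URLEncoder.encode(modelPath, StandardCharsets.UTF_8));
    }

    public static void main(String[] args) throws Exception {
        // Usage: java TrainingClient.java <dataset root> <model dir>
        HttpRequest request = HttpRequest.newBuilder(trainingUri("http://localhost:8888", args[0], args[1]))
                .POST(HttpRequest.BodyPublishers.noBody())
                .build(); // no timeout set: the request blocks until every epoch is done
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        // on success the body is the list of class labels, one per line
        System.out.println(response.statusCode());
        System.out.println(response.body());
    }
}
```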

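The console prints the log below. Without a GPU, training is slow: in this CPU run, the first epoch alone took about an hour and 40 minutes (21:00 to 22:42).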

2024-10-11T21:00:05.407+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] c.e.s.ImageClassificationServiceImpl : Image dataset training started...Image dataset address path:/Users/liuhaihua/ai/ut-zap50k-images-square
2024-10-11T21:00:08.455+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.util.Platform : Ignore mismatching platform from: jar:file:/Users/liuhaihua/.m2/repository/ai/djl/pytorch/pytorch-native-cpu/2.0.1/pytorch-native-cpu-2.0.1-win-x86_64.jar!/native/lib/pytorch.properties
2024-10-11T21:00:09.240+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of inter-op threads is 4
2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of intra-op threads is 4
2024-10-11T21:00:09.287+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Training on: cpu().
2024-10-11T21:00:09.290+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Load PyTorch Engine Version 1.13.1 in 0.044 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
Validating: 100% |████████████████████████████████████████|
2024-10-11T22:42:48.142+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Epoch 1 finished.
2024-10-11T22:42:48.187+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38
2024-10-11T22:42:48.189+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Validate: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.24
Training: 5% |███ | Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22

Predict the image class

Use the model trained in the previous step to make a prediction
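The /analyze endpoint expects a multipart file part named image, plus an optional modePath parameter pointing at the directory where /training saved the model. The sketch below builds that multipart request by hand with java.net.http, which has no built-in multipart body publisher; the class name, the image/jpeg part type, and the localhost:8888 base URL are assumptions for illustration. The response body is the JSON string produced by Classifications.toJson() in the service.

```java
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

public class AnalyzeClient {

    // Builds a multipart/form-data body with one file part named "image",
    // which is what @RequestPart("image") in the controller expects.
    static byte[] multipartBody(String boundary, String fileName, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String partHeader = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"image\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: image/jpeg\r\n\r\n";
        out.writeBytes(partHeader.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    // Uploads the image to /analyze and returns the response body (the classification JSON).
    static String analyze(String baseUrl, Path image, String modelPath) throws Exception {
        String boundary = "djl-" + UUID.randomUUID();
        byte[] body = multipartBody(boundary, image.getFileName().toString(), Files.readAllBytes(image));
        URI uri = URI.create(baseUrl + "/analyze?modePath="
                + URLEncoder.encode(modelPath, StandardCharsets.UTF_8));
        HttpRequest request = HttpRequest.newBuilder(uri)
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        return HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString())
                .body();
    }

    public static void main(String[] args) throws Exception {
        // Usage: java AnalyzeClient.java <image.jpg> <model dir>
        System.out.println(analyze("http://localhost:8888", Path.of(args[0]), args[1]));
    }
}
```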

The returned result shows that the Shoes class has the highest probability, so the image is classified as Shoes.
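Because the translator is built with optApplySoftmax(true), the scores in the returned JSON are probabilities over the trained classes, and the predicted class is simply the one with the highest value. The plain-Java sketch below shows those two steps (softmax over raw scores, then argmax); the scores and the label order are made up for illustration, and in practice the label order is the one returned by /training.

```java
import java.util.List;

public class SoftmaxTopClass {

    // Converts raw scores (logits) into probabilities that sum to 1.
    // Subtracting the maximum score first keeps Math.exp from overflowing.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] exp = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            exp[i] = Math.exp(logits[i] - max);
            sum += exp[i];
        }
        for (int i = 0; i < exp.length; i++) {
            exp[i] /= sum;
        }
        return exp;
    }

    // Index of the highest probability, i.e. the predicted class.
    static int argmax(double[] p) {
        int best = 0;
        for (int i = 1; i < p.length; i++) {
            if (p[i] > p[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // hypothetical scores and label order, for illustration only
        List<String> labels = List.of("Boots", "Sandals", "Shoes", "Slippers");
        double[] probs = softmax(new double[] {1.0, -0.5, 3.2, 0.1});
        int top = argmax(probs);
        System.out.printf("%s (%.3f)%n", labels.get(top), probs[top]);
    }
}
```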

