1、pom文件
xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<!-- Maven build for a DJL (Deep Java Library) learning project. -->
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>learndjl</artifactId>
<version>1.0</version>
<!-- Java 11 language level; djl.version is the single place to change the DJL release (used by the BOM import below). -->
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<djl.version>0.35.1</djl.version>
</properties>
<!-- Importing the DJL BOM (type pom, scope import) supplies the versions of the ai.djl.* artifacts below, which is why they declare no <version> element. -->
<dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>${djl.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<!-- NOTE(review): commons-cli and commons-io are not used by the YoloDetection example; presumably needed by other examples. -->
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.9.0</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.17.0</version>
</dependency>
<!-- Log4j 2 implementation for the SLF4J 2 API, so org.slf4j Logger calls produce output. -->
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j2-impl</artifactId>
<version>2.24.1</version>
</dependency>
<!-- Core DJL modules; versions come from the imported BOM. -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
</dependency>
<dependency>
<groupId>ai.djl.timeseries</groupId>
<artifactId>timeseries</artifactId>
</dependency>
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
</dependency>
<dependency>
<groupId>ai.djl.audio</groupId>
<artifactId>audio</artifactId>
</dependency>
<!-- Engine model zoos, only on the runtime classpath. NOTE(review): the YoloDetection example selects only OnnxRuntime; MXNet, PyTorch and TensorFlow are presumably kept for other examples and could be dropped otherwise to slim the runtime classpath. -->
<!-- MXNet -->
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-model-zoo</artifactId>
<scope>runtime</scope>
</dependency>
<!-- Pytorch -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<scope>runtime</scope>
</dependency>
<!-- TensorFlow -->
<dependency>
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-model-zoo</artifactId>
<scope>runtime</scope>
</dependency>
<!-- ONNXRuntime -->
<!-- The engine's transitive com.microsoft.onnxruntime:onnxruntime is excluded here and re-declared below with an explicit version. -->
<dependency>
<groupId>ai.djl.onnxruntime</groupId>
<artifactId>onnxruntime-engine</artifactId>
<scope>runtime</scope>
<exclusions>
<exclusion>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- Explicitly pinned ONNX Runtime (default compile scope, unlike the runtime-scoped engine above). NOTE(review): confirm 1.20.0 is compatible with the DJL version set in djl.version. -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.20.0</version>
</dependency>
<!-- Test-only dependency. -->
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>7.10.2</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
2、java代码
java
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference.cv;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.translator.YoloV8TranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
 * Object-detection example that runs a YOLO11n model through DJL's OnnxRuntime engine.
 *
 * <p>The raw model output is decoded by the translator produced by {@link YoloV8TranslatorFactory}.
 * When at least one object is detected, the bounding boxes are drawn onto the input image and the
 * result is saved as a PNG under {@code build/output}.
 */
public final class YoloDetection {

    private static final Logger logger = LoggerFactory.getLogger(YoloDetection.class);

    /** Image analysed by {@link #predict()}. */
    private static final Path DEFAULT_IMAGE = Paths.get("src/test/resources/yolov8_test.jpg");

    /** Directory that receives the annotated image; created (if missing) before each inference. */
    private static final Path OUTPUT_DIR = Paths.get("build/output");

    /** DJL OnnxRuntime model-zoo archive containing the YOLO11n model. */
    private static final String MODEL_URL =
            "https://mlrepo.djl.ai/model/cv/object_detection/ai/djl/onnxruntime/yolo11n/0.0.1/yolo11n.zip";

    private YoloDetection() {}

    /**
     * Runs detection on the bundled sample image and logs the detected objects.
     *
     * @param args ignored
     * @throws IOException if the image cannot be read or the output cannot be written
     * @throws ModelException if the model cannot be loaded
     * @throws TranslateException if inference fails
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        DetectedObjects detection = predict();
        logger.info("{}", detection);
    }

    /**
     * Runs detection on the bundled sample image {@code src/test/resources/yolov8_test.jpg}.
     *
     * @return the objects detected in the sample image
     * @throws IOException if the image cannot be read or the output cannot be written
     * @throws ModelException if the model cannot be loaded
     * @throws TranslateException if inference fails
     */
    public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
        return predict(DEFAULT_IMAGE);
    }

    /**
     * Runs detection on the given image.
     *
     * <p>If anything is detected, the boxes are drawn onto a copy of the image loaded here, and that
     * image is saved to {@code build/output/yolo_detected.png}.
     *
     * @param imgPath path of the image to analyse
     * @return the detected objects (possibly empty)
     * @throws IOException if the image cannot be read or the output cannot be written
     * @throws ModelException if the model cannot be loaded
     * @throws TranslateException if inference fails
     */
    public static DetectedObjects predict(Path imgPath)
            throws IOException, ModelException, TranslateException {
        // Loaded up front so the boxes can be drawn on it later; the predictor itself
        // takes the Path and reads the image again through the translator.
        Image img = ImageFactory.getInstance().fromFile(imgPath);
        Criteria<Path, DetectedObjects> criteria = buildCriteria();

        try (ZooModel<Path, DetectedObjects> model = criteria.loadModel();
                Predictor<Path, DetectedObjects> predictor = model.newPredictor()) {
            Files.createDirectories(OUTPUT_DIR);
            DetectedObjects detection = predictor.predict(imgPath);
            if (detection.getNumberOfObjects() > 0) {
                saveAnnotated(img, detection);
            }
            return detection;
        }
    }

    /** Builds the model-loading criteria: the YOLO11n ONNX model with YOLOv8-style decoding. */
    private static Criteria<Path, DetectedObjects> buildCriteria() {
        return Criteria.builder()
                .setTypes(Path.class, DetectedObjects.class)
                .optModelUrls(MODEL_URL)
                .optEngine("OnnxRuntime")
                .optArgument("width", 640)
                .optArgument("height", 640)
                .optArgument("resize", true)
                .optArgument("toTensor", true)
                .optArgument("applyRatio", true)
                .optArgument("threshold", 0.6f)
                // Performance knob: caps how many candidate boxes are considered
                // (the model produces 8400 candidates, per the original note).
                .optArgument("maxBox", 1000)
                .optTranslatorFactory(new YoloV8TranslatorFactory())
                .optProgress(new ProgressBar())
                .build();
    }

    /** Draws the detections onto {@code img} (mutating it) and writes it as a PNG to the output dir. */
    private static void saveAnnotated(Image img, DetectedObjects detection) throws IOException {
        img.drawBoundingBoxes(detection);
        Path output = OUTPUT_DIR.resolve("yolo_detected.png");
        try (OutputStream os = Files.newOutputStream(output)) {
            img.save(os, "png");
        }
        logger.info("Detected object saved in: {}", output);
    }
}