DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:
MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。
maven
<!-- djl-->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.28.0</version>
</dependency>
<!-- /djl-->
Java DJL 架构图
plaintext
┌──────────────────────────────┐
│ ModelZoo │
├──────────────────────────────┤
│ Model │
└───────────────┬──────────────┘
│
┌─────────▼─────────┐
│ Engine │
└───────┬─┬─────────┘
│ │
┌───────▼─▼─────────┐
│ NDManager │
└───────┬─┬─────────┘
│ │
┌─────────▼─▼───────────┐
│ Dataset
└─────────┬─────────────┘
│
┌─────────▼─────────────┐
│ Trainer / Predictor │
└───────────────────────┘
主要组件详细描述
1. ModelZoo 和 Model
-
ModelZoo:提供多种预训练模型
ModelZoo
的功能- 模型发现与下载 :
ModelZoo
提供了一种机制,可以从多种来源(例如模型提供商、在线仓库等)发现和下载预训练模型。- 例如,可以从 AWS S3、Hugging Face、TensorFlow Hub 等平台下载模型。
- 模型加载 :
ModelZoo
提供了方便的方法来加载模型,用户可以根据需求加载不同类型的模型(例如图像分类模型、对象检测模型、自然语言处理模型等)。- 加载模型时,可以指定模型的名称、版本、以及模型的参数配置。
- 模型管理 :
ModelZoo
帮助用户管理已下载和加载的模型,可以方便地查看、更新和删除模型。- 通过这种方式,可以有效地管理本地的模型资源,避免重复下载和浪费存储空间。
示例
kotlinimport ai.djl.Application import ai.djl.Model import ai.djl.ModelException import ai.djl.modality.Classifications import ai.djl.modality.cv.Image import ai.djl.repository.zoo.Criteria import ai.djl.repository.zoo.ModelZoo import ai.djl.translate.TranslateException object ModelZooExample { @Throws(ModelException::class, TranslateException::class) @JvmStatic fun main(args: Array<String>) { // 定义模型的标准 val criteria: Criteria<Image, Classifications> = Criteria.builder() .optApplication(Application.CV.IMAGE_CLASSIFICATION) // 应用场景:图像分类 .setTypes(Image::class.java, Classifications::class.java) // 输入输出类型 .optFilter("backbone", "resnet50") // 模型过滤条件 .build() // 从 ModelZoo 加载模型 val model: Model = ModelZoo.loadModel(criteria) // 使用模型进行推理 // ... } }
ModelZoo
的类与接口ModelZoo
:核心类,提供模型的下载和加载功能。Criteria
:定义模型加载的标准和过滤条件,用于指定所需模型的应用场景、输入输出类型等。ModelLoader
:用于实际执行模型的下载和加载操作。
- 模型发现与下载 :
-
Model:表示一个深度学习模型的接口,包含模型的加载、保存和运行等操作。
ai.djl.ModelZoo
Key Methods:
Model loadModel(Criteria<?, ?> criteria)
: Loads a model based on the provided criteria.ModelInfo getModel(ModelId modelId)
: Retrieves information about a specific model using itsModelId
.Set<ModelId> listModels(ZooModel<?, ?> model)
: Lists all models in the zoo that match the given model.
ai.djl.ModelInfo
InterfaceModelInfo
provides metadata about a model, including its name, description, and input/output information.Key Methods:
String getName()
: Returns the name of the model.String getDescription()
: Provides a description of the model.Shape getInputShape()
: Returns the shape of the input tensor.Shape getOutputShape()
: Returns the shape of the output tensor.
ai.djl.ModelId
ClassModelId
uniquely identifies a model in the model zoo. It includes information about the model's group, name, and version.Key Fields:
String getGroup()
: Gets the group name of the model.String getName()
: Gets the name of the model.String getVersion()
: Gets the version of the model.
ai.djl.Application
EnumApplication
enumerates different types of applications supported by the model zoo, such as IMAGE_CLASSIFICATION, OBJECT_DETECTION, etc.Key Values:
CV.IMAGE_CLASSIFICATION
CV.OBJECT_DETECTION
NLP.TEXT_CLASSIFICATION
ai.djl.Criteria
ClassCriteria
is a builder for creating criteria objects used to filter and load models.Key Methods:
static Builder<?, ?> builder()
: Creates a new builder instance.Criteria<I, O> optApplication(Application application)
: Sets the application type.Criteria<I, O> optEngine(String engine)
: Specifies the engine to use (e.g., MXNet, PyTorch)
example
import ai.djl.Model import ai.djl.ModelException import ai.djl.modality.Classifications import ai.djl.modality.cv.Image import ai.djl.modality.cv.ImageFactory import ai.djl.ndarray.NDList import ai.djl.translate.TranslateException import ai.djl.translate.Translator import ai.djl.translate.TranslatorContext import java.io.IOException import java.nio.file.Paths object DjlExample { @JvmStatic fun main(args: Array<String>) { // 模型路径 val modelDir = Paths.get("models") val modelName = "resnet18" try { Model.newInstance(modelName).use { model -> // 加载模型 model.load(modelDir) // 加载输入图像 val img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg")) // 获取预测器 val predictor = model.newPredictor(MyTranslator()) // 执行推理 val result = predictor.predict(img) println(result) } } catch (e: IOException) { e.printStackTrace() } catch (e: ModelException) { e.printStackTrace() } catch (e: TranslateException) { e.printStackTrace() } } // 自定义 Translator private class MyTranslator : Translator<Image?, Classifications?> { override fun processInput(ctx: TranslatorContext?, input: Image?): NDList { return NDList(input!!.toNDArray(ctx!!.ndManager)) } override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications { val probabilitiesNDArray = list.singletonOrThrow().softmax(1) val labels: List<String> = List(100) { "name$it" } return Classifications(labels, probabilitiesNDArray) } } }
2. Dataset
-
常见的数据集类型:
- RandomAccessDataset :
RandomAccessDataset
是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。- 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
- IterableDataset :
IterableDataset
适用于数据不能随机访问的情况,如流数据或实时生成的数据。- 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
- RecordDataset :
RecordDataset
是基于记录文件(record file)的数据集格式,常用于大规模数据处理。- 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。
DJL 的数据集组件提供的功能包括:
- 数据加载和预处理 :
- 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
- 提供数据预处理功能,如归一化、数据增强、特征提取等。
- 批处理(Batching) :
- 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
- 提供灵活的批处理策略,可根据需要进行自定义。
- 数据变换(Transformations) :
- 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
- 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
- 数据加载器(DataLoader) :
DataLoader
负责将数据集打包成批次,并在训练过程中按需提供数据。- 支持多线程数据加载,提高数据处理效率。
- RandomAccessDataset :
-
Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。
import ai.djl.Model; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelZoo; import ai.djl.translate.TranslateException; import java.io.IOException; import java.nio.file.Paths; public class DjlExample { public static void main(String[] args) throws IOException, ModelException, TranslateException { // 加载模型 Criteria<Image, Classifications> criteria = Criteria.builder() .optEngine("TensorFlow") // 选择引擎 .setTypes(Image.class, Classifications.class) .optModelPath(Paths.get("path/to/model")) .build(); try (Model model = ModelZoo.loadModel(criteria); Predictor<Image, Classifications> predictor = model.newPredictor()) { // 加载图像 Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg")); // 进行推理 Classifications result = predictor.predict(img); System.out.println(result); } } }
import ai.djl.Application; import ai.djl.Model; import ai.djl.basicdataset.cv.classification.FashionMnist; import ai.djl.engine.Engine; import ai.djl.metric.Metrics; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; import ai.djl.training.Trainer; import ai.djl.training.dataset.Batch; import ai.djl.training.dataset.Dataset; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; import ai.djl.training.tracker.Tracker; import ai.djl.translate.TranslateException; import ai.djl.util.Pair; import java.io.IOException; public class DJLDatasetExample { public static void main(String[] args) throws IOException, TranslateException { NDManager manager = NDManager.newBaseManager(); FashionMnist fashionMnist = FashionMnist.builder() .optUsage(Dataset.Usage.TRAIN) .setSampling(32, true) // 32 is the batch size .optLimit(Long.MAX_VALUE) // Use this to limit the number of samples .build(); fashionMnist.prepare(); Model model = Model.newInstance("fashion-mnist-model"); TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build()) .addTrainingListeners(TrainingListener.Defaults.logging()); try (Trainer trainer = model.newTrainer(config)) { trainer.initialize(new long[]{1, 28, 28}); // Example shape for image data Metrics metrics = new Metrics(); trainer.setMetrics(metrics); for (Batch batch : trainer.iterateDataset(fashionMnist)) { EasyTrain.trainBatch(trainer, batch); trainer.step(); batch.close(); } trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer)); } } }
3. Engine 和 NDManager
-
Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。
-
NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。
Using DJL Engine
javaimport ai.djl.Model import ai.djl.ModelException import ai.djl.ndarray.NDArray import ai.djl.ndarray.NDList import ai.djl.ndarray.types.Shape import ai.djl.translate.Batchifier import ai.djl.translate.TranslateException import ai.djl.translate.Translator import ai.djl.translate.TranslatorContext import java.io.IOException import java.nio.file.Paths object DJLEngineExample { @Throws(ModelException::class, TranslateException::class, IOException::class) @JvmStatic fun main(args: Array<String>) { // Initialize the model val model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine // Load a pre-trained model model.load(Paths.get("path/to/your/model")) // Ensure the path is correct // Define a translator for data preprocessing and postprocessing val translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> { override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList { val manager = ctx.ndManager val array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessary return NDList(array) } override fun processOutput(ctx: TranslatorContext, list: NDList): Float { // Assuming the output is a single scalar value return list[0].getFloat() // Use getFloat() to get the scalar value } override fun getBatchifier(): Batchifier? { return null // Or implement batching if needed } } model.newPredictor(translator).use { predictor -> val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shape val output = predictor.predict(input) println("Prediction: $output") } } }
Overview of NDManager
Key Features of NDManager:
- Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
- Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
- Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
Using NDManager
javaimport ai.djl.ndarray.NDManager object NDManagerExample { @JvmStatic fun main(args: Array<String>) { NDManager.newBaseManager().use { manager -> val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f)) println("Array: $array") // Perform operations val result = array.add(2.0f) println("Result: $result") } // No need to explicitly free the memory, it's handled by the NDManager } }
4. Trainer 和 Predictor
Trainer 类
提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。
主要功能包括:
- 模型的训练和验证
- 管理优化器和损失函数
- 提供易于使用的训练循环
代码演示
以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:
java
import ai.djl.Model
import ai.djl.basicdataset.cv.classification.FashionMnist
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.TrainingConfig
import ai.djl.training.dataset.Dataset
import ai.djl.training.dataset.RandomAccessDataset
import ai.djl.training.listener.LoggingTrainingListener
import ai.djl.training.listener.TrainingListener
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
import ai.djl.training.tracker.FixedPerVarTracker
import ai.djl.training.util.ProgressBar
import ai.djl.translate.TranslateException
import java.io.IOException
import java.nio.file.Paths
object DjlTrainerDemo {
@Throws(IOException::class, TranslateException::class)
@JvmStatic
fun main(args: Array<String>) {
// Load dataset
val trainDataset: RandomAccessDataset =
FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()
trainDataset.prepare(ProgressBar())
// Define model
val model = Model.newInstance("mlp")
model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
// Define training configuration
val config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optOptimizer(
Optimizer.sgd()
.setLearningRateTracker(
FixedPerVarTracker.builder()
.setDefaultValue(0.01f)
.build()
).build()
)
.addTrainingListeners(LoggingTrainingListener())
model.newTrainer(config).use { trainer ->
trainer.initialize(Shape(1, (28 * 28).toLong()))
for (epoch in 0..9) {
for (batch in trainer.iterateDataset(trainDataset)) {
trainer.step()
batch.close()
}
trainer.notifyListeners { listener: TrainingListener ->
listener.onEpoch(trainer)
}
}
model.save(Paths.get("model"), "mlp")
}
}
}
Predictor 类
用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。
主要功能包括:
- 加载模型进行推理
- 处理输入和输出数据的转换
代码演示
java
import ai.djl.Model
import ai.djl.modality.Classifications
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths
object DjlPredictorDemo {
@Throws(IOException::class, TranslateException::class)
@JvmStatic
fun main(args: Array<String>) {
// Load model
val model = Model.newInstance("mlp")
model.load(Paths.get("model"), "mlp")
// Define Translator
val translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {
override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {
return NDList(input.reshape(Shape(1, (28 * 28).toLong())))
}
override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
// Assuming the output NDArray is the first element in NDList
val probabilities = list.singletonOrThrow()
return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels
}
override fun getBatchifier(): Batchifier {
return Batchifier.STACK
}
}
model.newPredictor(translator).use { predictor ->
val manager = NDManager.newBaseManager()
val array = manager.ones(Shape(1, (28 * 28).toLong()))
val classifications = predictor.predict(array)
println(classifications)
}
}
}