java(kotlin) ai框架djl

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。

PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。

TensorFlow:由Google开发的另一个流行的开源机器学习框架。

DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl-->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.28.0</version>
        </dependency>
  <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.28.0</version>
        </dependency>
        <!--        /djl-->

Java DJL 架构图

plaintext 复制代码
┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘
                │
      ┌─────────▼─────────┐
      │       Engine      │
      └───────┬─┬─────────┘
              │ │
      ┌───────▼─▼─────────┐
      │     NDManager     │
      └───────┬─┬─────────┘
              │ │
    ┌─────────▼─▼───────────┐
    │    Dataset 
    └─────────┬─────────────┘
              │
    ┌─────────▼─────────────┐
    │  Trainer / Predictor  │
    └───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
  • ModelZoo:提供多种预训练模型

    ModelZoo 的功能
    1. 模型发现与下载
      • ModelZoo 提供了一种机制,可以从多种来源(例如模型提供商、在线仓库等)发现和下载预训练模型。
      • 例如,可以从 AWS S3、Hugging Face、TensorFlow Hub 等平台下载模型。
    2. 模型加载
      • ModelZoo 提供了方便的方法来加载模型,用户可以根据需求加载不同类型的模型(例如图像分类模型、对象检测模型、自然语言处理模型等)。
      • 加载模型时,可以指定模型的名称、版本、以及模型的参数配置。
    3. 模型管理
      • ModelZoo 帮助用户管理已下载和加载的模型,可以方便地查看、更新和删除模型。
      • 通过这种方式,可以有效地管理本地的模型资源,避免重复下载和浪费存储空间。

    示例

    kotlin 复制代码
    import ai.djl.Application
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.modality.Classifications
    import ai.djl.modality.cv.Image
    import ai.djl.repository.zoo.Criteria
    import ai.djl.repository.zoo.ModelZoo
    import ai.djl.translate.TranslateException
    
    
    object ModelZooExample {
        @Throws(ModelException::class, TranslateException::class)
        @JvmStatic
        fun main(args: Array<String>) {
            // 定义模型的标准
            val criteria: Criteria<Image, Classifications> = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION) // 应用场景:图像分类
                .setTypes(Image::class.java, Classifications::class.java) // 输入输出类型
                .optFilter("backbone", "resnet50") // 模型过滤条件
                .build()
    
            // 从 ModelZoo 加载模型
            val model: Model = ModelZoo.loadModel(criteria)
    
            // 使用模型进行推理
            // ...
        }
    }

    ModelZoo 的类与接口

    • ModelZoo:核心类,提供模型的下载和加载功能。
    • Criteria:定义模型加载的标准和过滤条件,用于指定所需模型的应用场景、输入输出类型等。
    • ModelLoader:用于实际执行模型的下载和加载操作。
  • Model:表示一个深度学习模型的接口,包含模型的加载、保存和运行等操作。

    ai.djl.ModelZoo

    Key Methods:
    • Model loadModel(Criteria<?, ?> criteria): Loads a model based on the provided criteria.
    • ModelInfo getModel(ModelId modelId): Retrieves information about a specific model using its ModelId.
    • Set<ModelId> listModels(ZooModel<?, ?> model): Lists all models in the zoo that match the given model.

    ai.djl.ModelInfo Interface

    ModelInfo provides metadata about a model, including its name, description, and input/output information.

    Key Methods:
    • String getName(): Returns the name of the model.
    • String getDescription(): Provides a description of the model.
    • Shape getInputShape(): Returns the shape of the input tensor.
    • Shape getOutputShape(): Returns the shape of the output tensor.

    ai.djl.ModelId Class

    ModelId uniquely identifies a model in the model zoo. It includes information about the model's group, name, and version.

    Key Fields:
    • String getGroup(): Gets the group name of the model.
    • String getName(): Gets the name of the model.
    • String getVersion(): Gets the version of the model.

    ai.djl.Application Enum

    Application enumerates different types of applications supported by the model zoo, such as IMAGE_CLASSIFICATION, OBJECT_DETECTION, etc.

    Key Values:
    • CV.IMAGE_CLASSIFICATION
    • CV.OBJECT_DETECTION
    • NLP.TEXT_CLASSIFICATION

    ai.djl.Criteria Class

    Criteria is a builder for creating criteria objects used to filter and load models.

    Key Methods:
    • static Builder<?, ?> builder(): Creates a new builder instance.
    • Criteria<I, O> optApplication(Application application): Sets the application type.
    • Criteria<I, O> optEngine(String engine): Specifies the engine to use (e.g., MXNet, PyTorch)
    example
      import ai.djl.Model
      import ai.djl.ModelException
      import ai.djl.modality.Classifications
      import ai.djl.modality.cv.Image
      import ai.djl.modality.cv.ImageFactory
      import ai.djl.ndarray.NDList
      import ai.djl.translate.TranslateException
      import ai.djl.translate.Translator
      import ai.djl.translate.TranslatorContext
      import java.io.IOException
      import java.nio.file.Paths
    
    
      object DjlExample {
          @JvmStatic
          fun main(args: Array<String>) {
              // 模型路径
              val modelDir = Paths.get("models")
              val modelName = "resnet18"
              try {
                  Model.newInstance(modelName).use { model ->
                      // 加载模型
                      model.load(modelDir)
    
                      // 加载输入图像
                      val img =
                          ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"))
    
                      // 获取预测器
                      val predictor =
                          model.newPredictor(MyTranslator())
    
                      // 执行推理
                      val result = predictor.predict(img)
                      println(result)
                  }
              } catch (e: IOException) {
                  e.printStackTrace()
              } catch (e: ModelException) {
                  e.printStackTrace()
              } catch (e: TranslateException) {
                  e.printStackTrace()
              }
          }
    
          // 自定义 Translator
          private class MyTranslator : Translator<Image?, Classifications?> {
              override fun processInput(ctx: TranslatorContext?, input: Image?): NDList {
                  return NDList(input!!.toNDArray(ctx!!.ndManager))
              }
    
              override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                  val probabilitiesNDArray = list.singletonOrThrow().softmax(1)
                  val labels: List<String> = List(100) { "name$it" }
                  return Classifications(labels, probabilitiesNDArray)
              }
          }
      }
    
2. Dataset
  • 常见的数据集类型:

    1. RandomAccessDataset :
      • RandomAccessDataset 是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。
      • 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
    2. IterableDataset :
      • IterableDataset 适用于数据不能随机访问的情况,如流数据或实时生成的数据。
      • 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
    3. RecordDataset :
      • RecordDataset 是基于记录文件(record file)的数据集格式,常用于大规模数据处理。
      • 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。

    DJL 的数据集组件提供的功能包括:

    1. 数据加载和预处理 :
      • 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
      • 提供数据预处理功能,如归一化、数据增强、特征提取等。
    2. 批处理(Batching) :
      • 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
      • 提供灵活的批处理策略,可根据需要进行自定义。
    3. 数据变换(Transformations) :
      • 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
      • 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
    4. 数据加载器(DataLoader) :
      • DataLoader 负责将数据集打包成批次,并在训练过程中按需提供数据。
      • 支持多线程数据加载,提高数据处理效率。
  • Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。

      import ai.djl.Model;
      import ai.djl.ModelException;
      import ai.djl.inference.Predictor;
      import ai.djl.modality.Classifications;
      import ai.djl.modality.cv.Image;
      import ai.djl.modality.cv.ImageFactory;
      import ai.djl.repository.zoo.Criteria;
      import ai.djl.repository.zoo.ModelZoo;
      import ai.djl.translate.TranslateException;
    
      import java.io.IOException;
      import java.nio.file.Paths;
    
      public class DjlExample {
          public static void main(String[] args) throws IOException, ModelException, TranslateException {
              // 加载模型
              Criteria<Image, Classifications> criteria = Criteria.builder()
                      .optEngine("TensorFlow") // 选择引擎
                      .setTypes(Image.class, Classifications.class)
                      .optModelPath(Paths.get("path/to/model"))
                      .build();
              
              try (Model model = ModelZoo.loadModel(criteria);
                   Predictor<Image, Classifications> predictor = model.newPredictor()) {
                  
                  // 加载图像
                  Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));
                  
                  // 进行推理
                  Classifications result = predictor.predict(img);
                  System.out.println(result);
              }
          }
      }
    
      import ai.djl.Application;
      import ai.djl.Model;
      import ai.djl.basicdataset.cv.classification.FashionMnist;
      import ai.djl.engine.Engine;
      import ai.djl.metric.Metrics;
      import ai.djl.ndarray.NDArray;
      import ai.djl.ndarray.NDManager;
      import ai.djl.training.DefaultTrainingConfig;
      import ai.djl.training.EasyTrain;
      import ai.djl.training.Trainer;
      import ai.djl.training.dataset.Batch;
      import ai.djl.training.dataset.Dataset;
      import ai.djl.training.listener.TrainingListener;
      import ai.djl.training.loss.Loss;
      import ai.djl.training.optimizer.Optimizer;
      import ai.djl.training.tracker.Tracker;
      import ai.djl.translate.TranslateException;
      import ai.djl.util.Pair;
    
      import java.io.IOException;
    
      public class DJLDatasetExample {
    
          public static void main(String[] args) throws IOException, TranslateException {
              NDManager manager = NDManager.newBaseManager();
    
              FashionMnist fashionMnist = FashionMnist.builder()
                      .optUsage(Dataset.Usage.TRAIN)
                      .setSampling(32, true) // 32 is the batch size
                      .optLimit(Long.MAX_VALUE) // Use this to limit the number of samples
                      .build();
    
              fashionMnist.prepare();
    
              Model model = Model.newInstance("fashion-mnist-model");
              TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                      .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build())
                      .addTrainingListeners(TrainingListener.Defaults.logging());
    
              try (Trainer trainer = model.newTrainer(config)) {
                  trainer.initialize(new long[]{1, 28, 28}); // Example shape for image data
    
                  Metrics metrics = new Metrics();
                  trainer.setMetrics(metrics);
    
                  for (Batch batch : trainer.iterateDataset(fashionMnist)) {
                      EasyTrain.trainBatch(trainer, batch);
                      trainer.step();
                      batch.close();
                  }
    
                  trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));
              }
          }
      }
    

3. Engine 和 NDManager
  • Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。

  • NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。

    Using DJL Engine
    java 复制代码
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Paths
    
    object DJLEngineExample {
        @Throws(ModelException::class, TranslateException::class, IOException::class)
        @JvmStatic
        fun main(args: Array<String>) {
            // Initialize the model
            val model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine
    
            // Load a pre-trained model
            model.load(Paths.get("path/to/your/model")) // Ensure the path is correct
    
            // Define a translator for data preprocessing and postprocessing
            val translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> {
                override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {
                    val manager = ctx.ndManager
                    val array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessary
                    return NDList(array)
                }
    
                override fun processOutput(ctx: TranslatorContext, list: NDList): Float {
                    // Assuming the output is a single scalar value
                    return list[0].getFloat() // Use getFloat() to get the scalar value
                }
    
                override fun getBatchifier(): Batchifier? {
                    return null // Or implement batching if needed
                }
            }
    
            model.newPredictor(translator).use { predictor ->
                val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shape
                val output = predictor.predict(input)
                println("Prediction: $output")
            }
        }
    }
    Overview of NDManager
    Key Features of NDManager:
    1. Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
    2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
    3. Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
    Using NDManager
    java 复制代码
    import ai.djl.ndarray.NDManager
    
    
    object NDManagerExample {
        @JvmStatic
        fun main(args: Array<String>) {
            NDManager.newBaseManager().use { manager ->
                val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))
                println("Array: $array")
    
                // Perform operations
                val result = array.add(2.0f)
                println("Result: $result")
            }
            // No need to explicitly free the memory, it's handled by the NDManager
        }
    }
4. Trainer 和 Predictor
Trainer 类

提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。

主要功能包括:

  • 模型的训练和验证
  • 管理优化器和损失函数
  • 提供易于使用的训练循环
代码演示

以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:

java 复制代码
import ai.djl.Model
import ai.djl.basicdataset.cv.classification.FashionMnist
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.TrainingConfig
import ai.djl.training.dataset.Dataset
import ai.djl.training.dataset.RandomAccessDataset
import ai.djl.training.listener.LoggingTrainingListener
import ai.djl.training.listener.TrainingListener
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
import ai.djl.training.tracker.FixedPerVarTracker
import ai.djl.training.util.ProgressBar
import ai.djl.translate.TranslateException
import java.io.IOException
import java.nio.file.Paths

object DjlTrainerDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Load dataset
        val trainDataset: RandomAccessDataset =
            FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()
        trainDataset.prepare(ProgressBar())

        // Define model
        val model = Model.newInstance("mlp")
        model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))

        // Define training configuration
        val config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .optOptimizer(
                Optimizer.sgd()
                    .setLearningRateTracker(
                        FixedPerVarTracker.builder()
                            .setDefaultValue(0.01f)
                            .build()
                    ).build()
            )
            .addTrainingListeners(LoggingTrainingListener())

        model.newTrainer(config).use { trainer ->
            trainer.initialize(Shape(1, (28 * 28).toLong()))
            for (epoch in 0..9) {
                for (batch in trainer.iterateDataset(trainDataset)) {
                    trainer.step()
                    batch.close()
                }
                trainer.notifyListeners { listener: TrainingListener ->
                    listener.onEpoch(trainer)
                }
            }
            model.save(Paths.get("model"), "mlp")
        }
    }
}
Predictor 类

用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。

主要功能包括:

  • 加载模型进行推理
  • 处理输入和输出数据的转换
代码演示
java 复制代码
import ai.djl.Model
import ai.djl.modality.Classifications
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlPredictorDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Load model
        val model = Model.newInstance("mlp")
        model.load(Paths.get("model"), "mlp")

        // Define Translator
        val translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {
            override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {
                return NDList(input.reshape(Shape(1, (28 * 28).toLong())))
            }

            override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                // Assuming the output NDArray is the first element in NDList
                val probabilities = list.singletonOrThrow()
                return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels
            }

            override fun getBatchifier(): Batchifier {
                return Batchifier.STACK
            }
        }

        model.newPredictor(translator).use { predictor ->
            val manager = NDManager.newBaseManager()
            val array = manager.ones(Shape(1, (28 * 28).toLong()))
            val classifications = predictor.predict(array)
            println(classifications)
        }
    }
}
相关推荐
使者大牙3 分钟前
【LLM学习笔记】第四篇:模型压缩方法——量化、剪枝、蒸馏、分解
人工智能·深度学习·算法·机器学习
Matlab程序猿小助手4 分钟前
【MATLAB源码-第222期】基于matlab的改进蚁群算法三维栅格地图路径规划,加入精英蚁群策略。包括起点终点,障碍物,着火点,楼梯。
开发语言·人工智能·算法·matlab·机器人·无人机
岛屿旅人6 分钟前
2025-2026财年美国CISA国际战略规划(下)
网络·人工智能·安全·web安全·网络安全
卧式纯绿10 分钟前
自动驾驶3D目标检测综述(三)
人工智能·python·深度学习·目标检测·3d·cnn·自动驾驶
尘浮生15 分钟前
Java项目实战II基于SpringBoot的客户关系管理系统(开发文档+数据库+源码)
java·开发语言·数据库·spring boot·后端·微信小程序·小程序
2401_8576100317 分钟前
企业OA系统:Spring Boot技术实现与管理
java·spring boot·后端
飞滕人生TYF20 分钟前
java 集合 菱形内定义封装类 而非基本数据类型 原因解释 详解
java
繁依Fanyi22 分钟前
在 Spring Boot 中实现多种方式登录(用户名、手机号、邮箱等)的不正经指南
java·spring boot·后端
夕阳产业——JAVA,入坑=失业25 分钟前
泛型擦除是什么?
java·开发语言
Gungnirss1 小时前
前后端分离,后端拦截器无法获得前端请求的token
java·前端·token