Java 在机器学习中的应用:基于 DL4J 与 Weka 的完整实战案例

1. 引言

说到机器学习(Machine Learning),大部分人的第一反应是:

Python 才是最常用的语言。

但其实 Java 在机器学习领域一直有自己的生态,特别适合:

  • 企业级场景

  • 高性能服务器

  • 需要与 Java Web/Spring 结合的业务

  • 需要跨平台部署

  • 大数据处理(Hadoop/Spark)

Java 有两个老牌且强大的 ML 框架:

✔ Weka ------ 传统机器学习算法的"瑞士军刀"

✔ DL4J ------ Java 世界的深度学习框架(可训练 CNN/RNN)

这篇文章给你带来 两个完整的 ML 项目示例

  1. Weka 鸢尾花分类(决策树模型)

  2. DL4J 深度学习:MNIST 手写数字识别(CNN 模型)


2. Java 做机器学习是否现实?与 Python 的差异

比较项 Java Python
易用性 稍复杂 简单、库多
深度学习框架 DL4J TensorFlow/PyTorch
性能 JVM 高性能,强并发 C 底层加速
部署 方便与大型系统结合 需要环境管理
社区生态 较小 超级庞大

结论:

Python 适合研究,Java 适合 工业级部署

比如你开发 Spring Boot 系统,需要一个 AI 模型在线预测,Java 直接接入更容易。


3. Java 生态中的两大机器学习框架


3.1 📌 Weka:经典机器学习工具包

  • 日本、澳大利亚大学团队维护

  • 已有 20+ 年历史

  • 提供大量算法:决策树、朴素贝叶斯、SVM、KNN、聚类、特征选择

  • 支持 GUI + Java API

适合:

  • 教学

  • 快速验证算法

  • 嵌入 Java 程序部署


3.2 📌 Deeplearning4j(DL4J)

DL4J 是 JVM 平台最强的深度学习框架:

  • 支持 CNN、RNN、LSTM、GAN

  • 支持 GPU(用 CUDA)

  • 可与 Spark 分布式训练

  • 与 Java/Scala 生态深度集成

适合:

  • 需要部署在 Java 服务中的深度学习模型

  • 大型服务端系统

  • 高性能场景(多线程)


🔥 4. Weka 实战:Java 构建鸢尾花分类器

我们将基于 iris.arff 数据集,使用 J48 决策树算法做分类。


4.1 项目依赖(Maven)

复制代码
<dependency>
    <groupId>nz.ac.waikato.cms.weka</groupId>
    <artifactId>weka-stable</artifactId>
    <version>3.8.6</version>
</dependency>

4.2 加载数据集

Weka 的数据集格式为 .arff,也支持 CSV。

复制代码
DataSource source = new DataSource("iris.arff");
Instances data = source.getDataSet();

// 设置 label 列
data.setClassIndex(data.numAttributes() - 1);

4.3 数据预处理

例如:归一化、标准化等。

复制代码
Normalize normalize = new Normalize();
normalize.setInputFormat(data);
Instances newData = Filter.useFilter(data, normalize);

4.4 构建分类模型(J48 决策树)

复制代码
Classifier model = new J48();
model.buildClassifier(newData);

4.5 模型评估

使用交叉验证:

复制代码
Evaluation eval = new Evaluation(newData);
eval.crossValidateModel(model, newData, 10, new Random(1));

System.out.println(eval.toSummaryString());
System.out.println(eval.toMatrixString());

⭐ 4.6 Weka 完整可运行代码

复制代码
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;

import java.util.Random;

public class IrisClassifier {

    public static void main(String[] args) throws Exception {

        DataSource source = new DataSource("iris.arff");
        Instances data = source.getDataSet();

        // 最后一列是分类标签
        data.setClassIndex(data.numAttributes() - 1);

        // 数据预处理(归一化)
        Normalize normalize = new Normalize();
        normalize.setInputFormat(data);
        Instances newData = Filter.useFilter(data, normalize);

        // 构建决策树模型
        Classifier model = new J48();
        model.buildClassifier(newData);

        // 10 折交叉验证
        Evaluation eval = new Evaluation(newData);
        eval.crossValidateModel(model, newData, 10, new Random(1));

        // 输出结果
        System.out.println(eval.toSummaryString("\n===== 模型评估 =====\n", true));
        System.out.println(eval.toMatrixString());
        System.out.println(model);
    }
}

运行即可直接看到准确率 94%~98%。


🔥 5. DL4J 深度学习:手写数字识别(MNIST)

本案例构建一个 卷积神经网络 CNN,识别 0~9 手写数字。


5.1 Maven 依赖

复制代码
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-M2.1</version>
</dependency>

<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-M2.1</version>
</dependency>

5.2 神经网络结构

我们使用典型结构:

  • 卷积层

  • 池化层

  • 全连接层

  • Softmax 输出层


5.3 加载 MNIST 数据

DL4J 自带:

复制代码
MnistDataSetIterator mnistTrain = new MnistDataSetIterator(64, true, 12345);
MnistDataSetIterator mnistTest = new MnistDataSetIterator(64, false, 12345);

5.4 构建 CNN 模型

复制代码
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
    .updater(new Adam(0.001))
    .list()
    .layer(new ConvolutionLayer.Builder(5, 5)
            .nIn(1)
            .stride(1, 1)
            .nOut(20)
            .activation(Activation.RELU)
            .build())
    .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
    .layer(new DenseLayer.Builder()
            .activation(Activation.RELU)
            .nOut(500)
            .build())
    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nOut(10)
            .activation(Activation.SOFTMAX)
            .build())
    .build();

5.5 模型训练

复制代码
MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();

for (int i = 0; i < 5; i++) {
    model.fit(mnistTrain);
    System.out.println("Epoch " + i + " 完成");
}

5.6 ⭐ 完整程序(可直接运行)

复制代码
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class MnistCNN {

    public static void main(String[] args) throws Exception {

        DataSetIterator mnistTrain = new MnistDataSetIterator(64, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(64, false, 12345);

        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .updater(new Adam(0.001))
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(1)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.RELU)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder()
                        .nOut(500)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(10)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init();

        System.out.println("开始训练模型...");
        for (int i = 0; i < 5; i++) {
            model.fit(mnistTrain);
            System.out.println("Epoch " + i + " 完成");
        }

        System.out.println("开始评估...");
        Evaluation eval = model.evaluate(mnistTest);
        System.out.println(eval.stats());
    }
}

训练成功后准确率可达 98%~99%


6. Java 机器学习项目构建最佳实践

  • 使用 Maven/Gradle 进行依赖管理

  • 尽量使用 DL4J + ND4J 来处理大规模矩阵计算

  • 训练模型尽量在 Python 完成,部署在 Java

  • GPU 加速需安装 CUDA

  • 日志使用 SLF4J

  • 配置文件使用 YAML 更清晰


7. Java × AI 的未来趋势

Java 在 AI 领域正在转向:

✔ 模型推理,而不是训练

✔ 与大数据系统结合(Spark + DL4J)

✔ 与企业级系统整合(Spring Boot)

Java 更适合:

  • 推理在线服务

  • 热部署系统

  • 金融、电商、医疗系统中的 AI 模块


8. 总结

本文完整展示了:

  • Java 做机器学习的可行性

  • Weka 进行传统机器学习

  • DL4J 训练 CNN 神经网络

  • 两个完整 Java 机器学习项目代码

相关推荐
q***23922 小时前
nginx简单命令启动,关闭等
java·服务器·nginx
拾忆,想起2 小时前
Dubbo负载均衡全解析:五种策略详解与实战指南
java·运维·微服务·架构·负载均衡·dubbo·哈希算法
shayudiandian2 小时前
【Java】关键字 native
java
合作小小程序员小小店2 小时前
桌面开发,在线%幼儿教育考试管理%系统,基于eclipse,java,swing,mysql数据库
java·数据库·sql·mysql·eclipse·jdk
江塘2 小时前
机器学习-决策树多种生成方法讲解及实战代码讲解(C++/Python实现)
c++·python·决策树·机器学习
明洞日记2 小时前
【设计模式手册005】单例模式 - 唯一实例的优雅实现
java·单例模式·设计模式
二川bro2 小时前
第48节:WebAssembly加速与C++物理引擎编译
java·c++·wasm
⑩-3 小时前
苍穹外卖Day(8)(9)
java·spring boot·mybatis
IUGEI3 小时前
Websocket、HTTP/2、HTTP/3原理解析
java·网络·后端·websocket·网络协议·http·https