1. 引言
说到机器学习(Machine Learning),大部分人的第一反应是:
Python 才是最常用的语言。
但其实 Java 在机器学习领域一直有自己的生态,特别适合:
-
企业级场景
-
高性能服务器
-
需要与 Java Web/Spring 结合的业务
-
需要跨平台部署
-
大数据处理(Hadoop/Spark)
Java 有两个老牌且强大的 ML 框架:
✔ Weka ------ 传统机器学习算法的"瑞士军刀"
✔ DL4J ------ Java 世界的深度学习框架(可训练 CNN/RNN)
这篇文章给你带来 两个完整的 ML 项目示例:
-
Weka 鸢尾花分类(决策树模型)
-
DL4J 深度学习:MNIST 手写数字识别(CNN 模型)
2. Java 做机器学习是否现实?与 Python 的差异
| 比较项 | Java | Python |
|---|---|---|
| 易用性 | 稍复杂 | 简单、库多 |
| 深度学习框架 | DL4J | TensorFlow/PyTorch |
| 性能 | JVM 高性能,强并发 | C 底层加速 |
| 部署 | 方便与大型系统结合 | 需要环境管理 |
| 社区生态 | 较小 | 超级庞大 |
结论:
Python 适合研究,Java 适合 工业级部署。
比如你开发 Spring Boot 系统,需要一个 AI 模型在线预测,Java 直接接入更容易。
3. Java 生态中的两大机器学习框架
3.1 📌 Weka:经典机器学习工具包
-
日本、澳大利亚大学团队维护
-
已有 20+ 年历史
-
提供大量算法:决策树、朴素贝叶斯、SVM、KNN、聚类、特征选择
-
支持 GUI + Java API
适合:
-
教学
-
快速验证算法
-
嵌入 Java 程序部署
3.2 📌 Deeplearning4j(DL4J)
DL4J 是 JVM 平台最强的深度学习框架:
-
支持 CNN、RNN、LSTM、GAN
-
支持 GPU(用 CUDA)
-
可与 Spark 分布式训练
-
与 Java/Scala 生态深度集成
适合:
-
需要部署在 Java 服务中的深度学习模型
-
大型服务端系统
-
高性能场景(多线程)
🔥 4. Weka 实战:Java 构建鸢尾花分类器
我们将基于 iris.arff 数据集,使用 J48 决策树算法做分类。
4.1 项目依赖(Maven)
<dependency> <groupId>nz.ac.waikato.cms.weka</groupId> <artifactId>weka-stable</artifactId> <version>3.8.6</version> </dependency>
4.2 加载数据集
Weka 的数据集格式为 .arff,也支持 CSV。
DataSource source = new DataSource("iris.arff"); Instances data = source.getDataSet(); // 设置 label 列 data.setClassIndex(data.numAttributes() - 1);
4.3 数据预处理
例如:归一化、标准化等。
Normalize normalize = new Normalize(); normalize.setInputFormat(data); Instances newData = Filter.useFilter(data, normalize);
4.4 构建分类模型(J48 决策树)
Classifier model = new J48(); model.buildClassifier(newData);
4.5 模型评估
使用交叉验证:
Evaluation eval = new Evaluation(newData); eval.crossValidateModel(model, newData, 10, new Random(1)); System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString());
⭐ 4.6 Weka 完整可运行代码
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import java.util.Random;
public class IrisClassifier {
public static void main(String[] args) throws Exception {
DataSource source = new DataSource("iris.arff");
Instances data = source.getDataSet();
// 最后一列是分类标签
data.setClassIndex(data.numAttributes() - 1);
// 数据预处理(归一化)
Normalize normalize = new Normalize();
normalize.setInputFormat(data);
Instances newData = Filter.useFilter(data, normalize);
// 构建决策树模型
Classifier model = new J48();
model.buildClassifier(newData);
// 10 折交叉验证
Evaluation eval = new Evaluation(newData);
eval.crossValidateModel(model, newData, 10, new Random(1));
// 输出结果
System.out.println(eval.toSummaryString("\n===== 模型评估 =====\n", true));
System.out.println(eval.toMatrixString());
System.out.println(model);
}
}
运行即可直接看到准确率 94%~98%。
🔥 5. DL4J 深度学习:手写数字识别(MNIST)
本案例构建一个 卷积神经网络 CNN,识别 0~9 手写数字。
5.1 Maven 依赖
<dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-core</artifactId> <version>1.0.0-M2.1</version> </dependency> <dependency> <groupId>org.nd4j</groupId> <artifactId>nd4j-native-platform</artifactId> <version>1.0.0-M2.1</version> </dependency>
5.2 神经网络结构
我们使用典型结构:
-
卷积层
-
池化层
-
全连接层
-
Softmax 输出层
5.3 加载 MNIST 数据
DL4J 自带:
MnistDataSetIterator mnistTrain = new MnistDataSetIterator(64, true, 12345); MnistDataSetIterator mnistTest = new MnistDataSetIterator(64, false, 12345);
5.4 构建 CNN 模型
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() .updater(new Adam(0.001)) .list() .layer(new ConvolutionLayer.Builder(5, 5) .nIn(1) .stride(1, 1) .nOut(20) .activation(Activation.RELU) .build()) .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build()) .layer(new DenseLayer.Builder() .activation(Activation.RELU) .nOut(500) .build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(10) .activation(Activation.SOFTMAX) .build()) .build();
5.5 模型训练
MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); for (int i = 0; i < 5; i++) { model.fit(mnistTrain); System.out.println("Epoch " + i + " 完成"); }
5.6 ⭐ 完整程序(可直接运行)
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class MnistCNN {
public static void main(String[] args) throws Exception {
DataSetIterator mnistTrain = new MnistDataSetIterator(64, true, 12345);
DataSetIterator mnistTest = new MnistDataSetIterator(64, false, 12345);
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
.updater(new Adam(0.001))
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.list()
.layer(new ConvolutionLayer.Builder(5, 5)
.nIn(1)
.stride(1, 1)
.nOut(20)
.activation(Activation.RELU)
.build())
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(new DenseLayer.Builder()
.nOut(500)
.activation(Activation.RELU)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(10)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();
System.out.println("开始训练模型...");
for (int i = 0; i < 5; i++) {
model.fit(mnistTrain);
System.out.println("Epoch " + i + " 完成");
}
System.out.println("开始评估...");
Evaluation eval = model.evaluate(mnistTest);
System.out.println(eval.stats());
}
}
训练成功后准确率可达 98%~99%。
6. Java 机器学习项目构建最佳实践
-
使用 Maven/Gradle 进行依赖管理
-
尽量使用 DL4J + ND4J 来处理大规模矩阵计算
-
训练模型尽量在 Python 完成,部署在 Java
-
GPU 加速需安装 CUDA
-
日志使用 SLF4J
-
配置文件使用 YAML 更清晰
7. Java × AI 的未来趋势
Java 在 AI 领域正在转向:
✔ 模型推理,而不是训练
✔ 与大数据系统结合(Spark + DL4J)
✔ 与企业级系统整合(Spring Boot)
Java 更适合:
-
推理在线服务
-
热部署系统
-
金融、电商、医疗系统中的 AI 模块
8. 总结
本文完整展示了:
-
Java 做机器学习的可行性
-
Weka 进行传统机器学习
-
DL4J 训练 CNN 神经网络
-
两个完整 Java 机器学习项目代码