逻辑回归实战(二):Java + DL4J 实现模型,评估指标(准确率/召回率)计算
------别只看"准确率",你的业务要的是"抓得准"还是"抓得全"?
大家好,我是那个总在模型上线前被风控总监问"能不能把高危用户一个不漏地揪出来?"、又在 KES 表里核对每一条误判样本的老架构。上一期我们把用户流失数据干净利落地存进了电科金仓 KingbaseES(KES),标签定义清晰、时间锚点明确、特征快照隔离。
但数据就位只是开始。
真正的考验是:你能否从这张表里,训练出一个不仅"能跑",而且"业务可用"的分类模型?
很多人一上来就调 model.output(),看到 accuracy=0.92 就收工。
但现实是:在不平衡数据中(比如流失率仅 5%),accuracy 是最大的谎言。
而真相是:评估指标的选择,本质上是业务目标的翻译。
今天我们就用 Java + DL4J(DeepLearning4J),从 KES 中加载用户行为特征,训练逻辑回归模型,并计算准确率、精确率、召回率、F1、AUC 等核心指标------全程基于国产数据库,只为回答那个灵魂问题:
"这个模型,到底能不能用?"
一、为什么不用手写?因为工程需要可扩展性
上一篇我们手写了逻辑回归。那为什么这次用 DL4J?
不是因为懒,而是因为:
- ✅ 自动处理数值稳定性(如 Sigmoid 溢出);
- ✅ 内置优化器(SGD、Adam);
- ✅ 标准化评估工具链;
- ✅ 未来可无缝升级为 MLP。
在企业级 AI 落地中,可维护性 > 代码行数。DL4J 作为 JVM 原生深度学习框架,完美契合国产化技术栈。
💡 注意:DL4J 的
LogisticRegression层本质仍是线性 + Sigmoid,和我们手写数学等价。
二、从 KES 加载特征与标签
复用上期设计的视图 ai_datasets.churn_train_v1:
java
public DataSet loadChurnDataset(Connection conn, String viewName) throws SQLException {
String sql = "SELECT login_count_7d, payment_count_30d, total_amount_30d, " +
"last_login_days, complaint_flag::INT, (plan_type='premium')::INT AS is_premium, " +
"label FROM " + viewName;
List<INDArray> featuresList = new ArrayList<>();
List<INDArray> labelsList = new ArrayList<>();
try (PreparedStatement ps = conn.prepareStatement(sql);
ResultSet rs = ps.executeQuery()) {
while (rs.next()) {
// 特征向量(6维)
double[] x = {
rs.getDouble("login_count_7d"),
rs.getDouble("payment_count_30d"),
rs.getDouble("total_amount_30d"),
rs.getDouble("last_login_days"),
rs.getInt("complaint_flag"),
rs.getInt("is_premium")
};
featuresList.add(Nd4j.create(x));
// 标签:0 或 1
int y = rs.getInt("label");
labelsList.add(Nd4j.scalar(y));
}
}
// 合并为 DataSet
INDArray features = Nd4j.vstack(featuresList.toArray(new INDArray[0]));
INDArray labels = Nd4j.vstack(labelsList.toArray(new INDArray[0]));
return new DataSet(features, labels);
}
🔗 确保使用 KES JDBC 驱动 支持布尔转整型(
::INT)。
三、用 DL4J 构建逻辑回归模型
java
public MultiLayerNetwork buildLogisticModel(int numFeatures) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.updater(new Sgd(0.01)) // 学习率
.l2(1e-4) // L2 正则防过拟合
.list()
.layer(new OutputLayer.Builder()
.nIn(numFeatures)
.nOut(1)
.activation(Activation.SIGMOID) // 关键:Sigmoid 输出概率
.lossFunction(LossFunctions.LossFunction.XENT) // 交叉熵
.build())
.build();
return new MultiLayerNetwork(conf);
}
✅ 这就是 DL4J 中的"逻辑回归":单输出 + Sigmoid + 交叉熵。
四、训练与预测
java
// 加载数据
DataSet trainData = loadChurnDataset(conn, "ai_datasets.churn_train_v1");
DataSet testData = loadChurnDataset(conn, "ai_datasets.churn_test_v1");
// 构建模型
MultiLayerNetwork model = buildLogisticModel(6);
// 训练
for (int epoch = 0; epoch < 50; epoch++) {
model.fit(trainData);
if (epoch % 10 == 0) {
double loss = model.score();
System.out.printf("Epoch %d, Loss: %.4f%n", epoch, loss);
}
}
// 预测概率
INDArray testProbs = model.output(testData.getFeatures()); // shape: [N, 1]
INDArray testLabels = testData.getLabels();
五、评估指标:不止一个数字
5.1 手动计算核心指标(理解本质)
java
public static EvaluationMetrics computeMetrics(INDArray probs, INDArray labels, double threshold) {
int tp = 0, fp = 0, tn = 0, fn = 0;
int n = probs.rows();
for (int i = 0; i < n; i++) {
double prob = probs.getDouble(i, 0);
int pred = prob >= threshold ? 1 : 0;
int actual = (int) labels.getDouble(i, 0);
if (pred == 1 && actual == 1) tp++;
if (pred == 1 && actual == 0) fp++;
if (pred == 0 && actual == 0) tn++;
if (pred == 0 && actual == 1) fn++;
}
double accuracy = (tp + tn) / (double) n;
double precision = tp / (double) Math.max(tp + fp, 1);
double recall = tp / (double) Math.max(tp + fn, 1);
double f1 = 2 * precision * recall / Math.max(precision + recall, 1e-8);
return new EvaluationMetrics(accuracy, precision, recall, f1, tp, fp, tn, fn);
}
典型输出(假设流失率 8%):
Threshold=0.5 → Acc=0.92, Prec=0.68, Rec=0.45
Threshold=0.3 → Acc=0.85, Prec=0.52, Rec=**0.78**
💡 解读:
- 业务要"抓得全"(如反欺诈)→ 选低阈值,提升 Recall;
- 业务要"抓得准"(如人工外呼)→ 选高阈值,提升 Precision。
5.2 用 DL4J 内置工具计算 AUC
java
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.curves.RocCurve;
// 构造 one-hot 标签(DL4J 要求)
INDArray labelsOneHot = Nd4j.zeros(testLabels.rows(), 2);
for (int i = 0; i < testLabels.rows(); i++) {
int label = (int) testLabels.getDouble(i, 0);
labelsOneHot.putScalar(i, label, 1.0);
}
// 拼接概率为 [P(0), P(1)]
INDArray probsFull = Nd4j.zeros(testProbs.rows(), 2);
for (int i = 0; i < testProbs.rows(); i++) {
double p1 = testProbs.getDouble(i, 0);
probsFull.putScalar(i, 0, 1 - p1);
probsFull.putScalar(i, 1, p1);
}
Evaluation eval = new Evaluation();
eval.eval(labelsOneHot, probsFull);
System.out.println("AUC: " + eval.auc());
RocCurve roc = eval.roc();
double bestThreshold = roc.getThresholdAtMaxYoudenJ(); // 最优阈值
System.out.println("Optimal threshold: " + bestThreshold);
✅ AUC 不依赖阈值,衡量模型整体排序能力。
在不平衡数据中,AUC > 0.85 才算可用。
六、与 KES 协同:将评估结果写回数据库
java
public void saveEvaluationToKES(Connection conn, String modelName, EvaluationMetrics metrics, double auc) throws SQLException {
String sql = """
INSERT INTO ai_eval.model_performance (
model_name, dataset, accuracy, precision, recall, f1_score, auc, created_at
) VALUES (?, 'churn_test_v1', ?, ?, ?, ?, ?, NOW())
""";
try (PreparedStatement ps = conn.prepareStatement(sql)) {
ps.setString(1, modelName);
ps.setDouble(2, metrics.accuracy);
ps.setDouble(3, metrics.precision);
ps.setDouble(4, metrics.recall);
ps.setDouble(5, metrics.f1);
ps.setDouble(6, auc);
ps.executeUpdate();
}
}
这样,每次实验都有迹可循,可对比、可复现。
七、工程建议:如何选择阈值?
- 固定业务成本法 :
- 外呼成本 10 元,挽回收益 200 元 → 只要 Precision > 5% 就值得打;
- 最大化 F1 :
- 当 Precision 和 Recall 同等重要时;
- Youden's J 统计量 :
J = Sensitivity + Specificity - 1,DL4J 可自动计算。
💡 在国产化项目中,业务方必须参与阈值决策------这是 AI 落地的关键一步。
结语:评估,是模型走向生产的通行证
在 AI 工程中,训练只是开始,评估才是信任的起点。
今天我们用 Java + DL4J + 电科金仓 KES,完成了一次端到端的逻辑回归实战:
- 从版本化视图加载数据;
- 用 DL4J 构建可扩展模型;
- 计算多维评估指标;
- 将结果写回数据库供决策。
这套流程,不依赖 Python,完全运行在国产 JVM 技术栈上,却具备企业级的可审计性与业务对齐能力。
当你能向业务方展示:"在召回率 75% 的前提下,我们的精确率是 52%,预计每月可挽回 120 万收入"------你就不再是"调参侠",而是值得信赖的 AI 架构师。
下一期,我们会讲:逻辑回归实战(三):模型部署------构建基于 KES 的实时流失预警服务 。
敬请期待。
------ 一位相信"可评估的模型,才是可用的模型"的架构师