逻辑回归实战(二):Java + DL4J 实现模型,评估指标(准确率/召回率)计算

逻辑回归实战(二):Java + DL4J 实现模型,评估指标(准确率/召回率)计算

------别只看"准确率",你的业务要的是"抓得准"还是"抓得全"?

大家好,我是那个总在模型上线前被风控总监问"能不能把高危用户一个不漏地揪出来?"、又在 KES 表里核对每一条误判样本的老架构。上一期我们把用户流失数据干净利落地存进了电科金仓 KingbaseES(KES),标签定义清晰、时间锚点明确、特征快照隔离。

但数据就位只是开始。
真正的考验是:你能否从这张表里,训练出一个不仅"能跑",而且"业务可用"的分类模型?

很多人一上来就调 model.output(),看到 accuracy=0.92 就收工。

但现实是:在不平衡数据中(比如流失率仅 5%),accuracy 是最大的谎言

而真相是:评估指标的选择,本质上是业务目标的翻译

今天我们就用 Java + DL4J(DeepLearning4J),从 KES 中加载用户行为特征,训练逻辑回归模型,并计算准确率、精确率、召回率、F1、AUC 等核心指标------全程基于国产数据库,只为回答那个灵魂问题:

"这个模型,到底能不能用?"


一、为什么不用手写?因为工程需要可扩展性

上一篇我们手写了逻辑回归。那为什么这次用 DL4J?

不是因为懒,而是因为:

  • 自动处理数值稳定性(如 Sigmoid 溢出);
  • 内置优化器(SGD、Adam);
  • 标准化评估工具链
  • 未来可无缝升级为 MLP

在企业级 AI 落地中,可维护性 > 代码行数。DL4J 作为 JVM 原生深度学习框架,完美契合国产化技术栈。

💡 注意:DL4J 的 LogisticRegression 层本质仍是线性 + Sigmoid,和我们手写数学等价。


二、从 KES 加载特征与标签

复用上期设计的视图 ai_datasets.churn_train_v1

java 复制代码
public DataSet loadChurnDataset(Connection conn, String viewName) throws SQLException {
    String sql = "SELECT login_count_7d, payment_count_30d, total_amount_30d, " +
                 "last_login_days, complaint_flag::INT, (plan_type='premium')::INT AS is_premium, " +
                 "label FROM " + viewName;

    List<INDArray> featuresList = new ArrayList<>();
    List<INDArray> labelsList = new ArrayList<>();

    try (PreparedStatement ps = conn.prepareStatement(sql);
         ResultSet rs = ps.executeQuery()) {

        while (rs.next()) {
            // 特征向量(6维)
            double[] x = {
                rs.getDouble("login_count_7d"),
                rs.getDouble("payment_count_30d"),
                rs.getDouble("total_amount_30d"),
                rs.getDouble("last_login_days"),
                rs.getInt("complaint_flag"),
                rs.getInt("is_premium")
            };
            featuresList.add(Nd4j.create(x));

            // 标签:0 或 1
            int y = rs.getInt("label");
            labelsList.add(Nd4j.scalar(y));
        }
    }

    // 合并为 DataSet
    INDArray features = Nd4j.vstack(featuresList.toArray(new INDArray[0]));
    INDArray labels = Nd4j.vstack(labelsList.toArray(new INDArray[0]));

    return new DataSet(features, labels);
}

🔗 确保使用 KES JDBC 驱动 支持布尔转整型(::INT)。


三、用 DL4J 构建逻辑回归模型

java 复制代码
public MultiLayerNetwork buildLogisticModel(int numFeatures) {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(123)
        .updater(new Sgd(0.01))          // 学习率
        .l2(1e-4)                        // L2 正则防过拟合
        .list()
        .layer(new OutputLayer.Builder()
            .nIn(numFeatures)
            .nOut(1)
            .activation(Activation.SIGMOID)  // 关键:Sigmoid 输出概率
            .lossFunction(LossFunctions.LossFunction.XENT) // 交叉熵
            .build())
        .build();

    return new MultiLayerNetwork(conf);
}

✅ 这就是 DL4J 中的"逻辑回归":单输出 + Sigmoid + 交叉熵。


四、训练与预测

java 复制代码
// 加载数据
DataSet trainData = loadChurnDataset(conn, "ai_datasets.churn_train_v1");
DataSet testData = loadChurnDataset(conn, "ai_datasets.churn_test_v1");

// 构建模型
MultiLayerNetwork model = buildLogisticModel(6);

// 训练
for (int epoch = 0; epoch < 50; epoch++) {
    model.fit(trainData);
    if (epoch % 10 == 0) {
        double loss = model.score();
        System.out.printf("Epoch %d, Loss: %.4f%n", epoch, loss);
    }
}

// 预测概率
INDArray testProbs = model.output(testData.getFeatures()); // shape: [N, 1]
INDArray testLabels = testData.getLabels();

五、评估指标:不止一个数字

5.1 手动计算核心指标(理解本质)

java 复制代码
public static EvaluationMetrics computeMetrics(INDArray probs, INDArray labels, double threshold) {
    int tp = 0, fp = 0, tn = 0, fn = 0;
    int n = probs.rows();

    for (int i = 0; i < n; i++) {
        double prob = probs.getDouble(i, 0);
        int pred = prob >= threshold ? 1 : 0;
        int actual = (int) labels.getDouble(i, 0);

        if (pred == 1 && actual == 1) tp++;
        if (pred == 1 && actual == 0) fp++;
        if (pred == 0 && actual == 0) tn++;
        if (pred == 0 && actual == 1) fn++;
    }

    double accuracy = (tp + tn) / (double) n;
    double precision = tp / (double) Math.max(tp + fp, 1);
    double recall = tp / (double) Math.max(tp + fn, 1);
    double f1 = 2 * precision * recall / Math.max(precision + recall, 1e-8);

    return new EvaluationMetrics(accuracy, precision, recall, f1, tp, fp, tn, fn);
}

典型输出(假设流失率 8%):

复制代码
Threshold=0.5 → Acc=0.92, Prec=0.68, Rec=0.45  
Threshold=0.3 → Acc=0.85, Prec=0.52, Rec=**0.78**

💡 解读:

  • 业务要"抓得全"(如反欺诈)→ 选低阈值,提升 Recall
  • 业务要"抓得准"(如人工外呼)→ 选高阈值,提升 Precision

5.2 用 DL4J 内置工具计算 AUC

java 复制代码
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.curves.RocCurve;

// 构造 one-hot 标签(DL4J 要求)
INDArray labelsOneHot = Nd4j.zeros(testLabels.rows(), 2);
for (int i = 0; i < testLabels.rows(); i++) {
    int label = (int) testLabels.getDouble(i, 0);
    labelsOneHot.putScalar(i, label, 1.0);
}

// 拼接概率为 [P(0), P(1)]
INDArray probsFull = Nd4j.zeros(testProbs.rows(), 2);
for (int i = 0; i < testProbs.rows(); i++) {
    double p1 = testProbs.getDouble(i, 0);
    probsFull.putScalar(i, 0, 1 - p1);
    probsFull.putScalar(i, 1, p1);
}

Evaluation eval = new Evaluation();
eval.eval(labelsOneHot, probsFull);
System.out.println("AUC: " + eval.auc());

RocCurve roc = eval.roc();
double bestThreshold = roc.getThresholdAtMaxYoudenJ(); // 最优阈值
System.out.println("Optimal threshold: " + bestThreshold);

✅ AUC 不依赖阈值,衡量模型整体排序能力。

在不平衡数据中,AUC > 0.85 才算可用


六、与 KES 协同:将评估结果写回数据库

java 复制代码
public void saveEvaluationToKES(Connection conn, String modelName, EvaluationMetrics metrics, double auc) throws SQLException {
    String sql = """
        INSERT INTO ai_eval.model_performance (
            model_name, dataset, accuracy, precision, recall, f1_score, auc, created_at
        ) VALUES (?, 'churn_test_v1', ?, ?, ?, ?, ?, NOW())
        """;
    
    try (PreparedStatement ps = conn.prepareStatement(sql)) {
        ps.setString(1, modelName);
        ps.setDouble(2, metrics.accuracy);
        ps.setDouble(3, metrics.precision);
        ps.setDouble(4, metrics.recall);
        ps.setDouble(5, metrics.f1);
        ps.setDouble(6, auc);
        ps.executeUpdate();
    }
}

这样,每次实验都有迹可循,可对比、可复现


七、工程建议:如何选择阈值?

  1. 固定业务成本法
    • 外呼成本 10 元,挽回收益 200 元 → 只要 Precision > 5% 就值得打;
  2. 最大化 F1
    • 当 Precision 和 Recall 同等重要时;
  3. Youden's J 统计量
    • J = Sensitivity + Specificity - 1,DL4J 可自动计算。

💡 在国产化项目中,业务方必须参与阈值决策------这是 AI 落地的关键一步。


结语:评估,是模型走向生产的通行证

在 AI 工程中,训练只是开始,评估才是信任的起点

今天我们用 Java + DL4J + 电科金仓 KES,完成了一次端到端的逻辑回归实战:

  • 从版本化视图加载数据;
  • 用 DL4J 构建可扩展模型;
  • 计算多维评估指标;
  • 将结果写回数据库供决策。

这套流程,不依赖 Python,完全运行在国产 JVM 技术栈上,却具备企业级的可审计性与业务对齐能力。

当你能向业务方展示:"在召回率 75% 的前提下,我们的精确率是 52%,预计每月可挽回 120 万收入"------你就不再是"调参侠",而是值得信赖的 AI 架构师

下一期,我们会讲:逻辑回归实战(三):模型部署------构建基于 KES 的实时流失预警服务

敬请期待。

------ 一位相信"可评估的模型,才是可用的模型"的架构师

相关推荐
忘梓.13 分钟前
解锁动态规划的奥秘:从零到精通的创新思维解析(10)
c++·算法·动态规划·代理模式
foolish..16 分钟前
动态规划笔记
笔记·算法·动态规划
消失的dk16 分钟前
算法---动态规划
算法·动态规划
羑悻的小杀马特16 分钟前
【动态规划篇】欣赏概率论与镜像法融合下,别出心裁探索解答括号序列问题
c++·算法·蓝桥杯·动态规划·镜像·洛谷·空隙法
绍兴贝贝17 分钟前
代码随想录算法训练营第四十六天|LC647.回文子串|LC516.最长回文子序列|动态规划总结
数据结构·人工智能·python·算法·动态规划·力扣
JaJian.17 分钟前
Java后端服务假死问题排查与解决
java·开发语言
愚润求学18 分钟前
【动态规划】二维的背包问题、似包非包、卡特兰数
c++·算法·leetcode·动态规划
救赎小恶魔19 分钟前
C++算法(5)
java·c++·算法
叫我一声阿雷吧19 分钟前
【信奥赛基础】动态规划:小学生也能懂的必考算法入门
算法·动态规划
lipiaoshuigood37 分钟前
MySQL 数据出海之数据同步方案
数据库·mysql