模型持久化(二):从 KingbaseES 加载模型,实现离线预测

模型持久化(二):从 KingbaseES 加载模型,实现离线预测

------别再让模型"睡在库房",它该上生产线了

大家好,我是那个总在凌晨被叫醒、因为线上预测服务挂了,又不得不手动从 KES 里捞出模型临时跑批的老架构。上一期我们把训练好的随机森林序列化存进了电科金仓 KingbaseES 的 BYTEA 字段,解决了"模型去哪儿"的问题。

但存进去只是第一步。
真正的价值,在于"用起来"

今天我们就干一件更硬核的事:从 KES 中加载已持久化的模型,对海量历史数据做离线批量预测,并将结果写回数据库。全程用纯 Java 实现,不依赖 Spark、Flink 或任何外部计算引擎,只为回答那个灵魂拷问:

"你的模型,到底能不能扛住真实业务的吞吐?"


一、为什么需要离线预测?

在国产化项目中,我们常陷入两个极端:

  • 要么只做实时 API 预测(单条请求);
  • 要么把全量数据导出到 Python 脚本跑批。

但现实场景往往是:

  • 每天凌晨对 100 万用户做风险重评
  • 每周对历史交易做反欺诈回溯
  • 每月生成监管报送的模型打分表

这些任务:

  • 数据量大(GB 级);
  • 对延迟不敏感(小时级完成即可);
  • 要求结果可审计、可追溯。

最好的方案,就是直接在 KES 内完成"模型加载 → 批量预测 → 结果写回"闭环


二、KES 表结构准备

假设我们有两张表:

2.1 模型注册表(复用上期)

sql 复制代码
-- ai_models.model_registry (含 model_blob)

2.2 待预测的用户表

sql 复制代码
CREATE TABLE ai_features.user_profiles (
    user_id       BIGINT PRIMARY KEY,
    age           INT,
    income        REAL,
    debt_ratio    REAL,
    credit_score  INT,
    payment_hist  VARCHAR(10),
    -- ... 其他特征
    risk_score    REAL,          -- ← 预测结果(待填充)
    model_version VARCHAR(20)    -- ← 使用的模型版本
);

目标:用最新 active 模型,为所有 risk_score IS NULL 的用户打分


三、Java 实现:高效批量预测

3.1 分页读取特征数据

避免 OOM,我们按批次读取:

java 复制代码
public List<UserFeature> loadBatch(Connection conn, long offset, int batchSize) 
        throws SQLException {
    String sql = """
        SELECT user_id, age, income, debt_ratio, credit_score, payment_hist
        FROM ai_features.user_profiles
        WHERE risk_score IS NULL
        ORDER BY user_id
        LIMIT ? OFFSET ?
        """;
    
    List<UserFeature> batch = new ArrayList<>();
    try (PreparedStatement ps = conn.prepareStatement(sql)) {
        ps.setInt(1, batchSize);
        ps.setLong(2, offset);
        ResultSet rs = ps.executeQuery();
        while (rs.next()) {
            batch.add(new UserFeature(
                rs.getLong("user_id"),
                Map.of(
                    "age", new FeatureValue("age", rs.getInt("age")),
                    "income", new FeatureValue("income", rs.getDouble("income")),
                    "debt_ratio", new FeatureValue("debt_ratio", rs.getDouble("debt_ratio")),
                    "credit_score", new FeatureValue("credit_score", rs.getInt("credit_score")),
                    "payment_hist", new FeatureValue("payment_hist", rs.getString("payment_hist"))
                )
            ));
        }
    }
    return batch;
}

3.2 从 KES 加载模型(复用上期逻辑)

java 复制代码
public RandomForest loadActiveModel(Connection conn, String modelName) {
    // 复用 ModelLoader.loadActiveModelFromKES()
}

3.3 批量预测 + 批量写回

java 复制代码
public void runOfflinePrediction(Connection conn, String modelName, int batchSize) 
        throws SQLException {
    
    // 1. 加载模型
    RandomForest model = loadActiveModel(conn, modelName);
    String modelVersion = getModelVersion(conn, modelName); // 查询 version 字段
    
    // 2. 获取待处理总数
    long total = countPendingUsers(conn);
    System.out.printf("Starting offline prediction for %d users...%n", total);
    
    // 3. 分批处理
    for (long offset = 0; offset < total; offset += batchSize) {
        List<UserFeature> batch = loadBatch(conn, offset, batchSize);
        
        // 批量预测(并行可选)
        List<PredictionResult> results = batch.parallelStream()
            .map(user -> {
                double prob = model.predictProbability(user.features);
                return new PredictionResult(user.userId, prob, modelVersion);
            })
            .collect(Collectors.toList());
        
        // 批量更新
        updateRiskScores(conn, results);
        
        System.out.printf("Processed %d / %d%n", Math.min(offset + batchSize, total), total);
    }
}

3.4 批量更新结果(使用 PreparedStatement 批处理)

java 复制代码
private void updateRiskScores(Connection conn, List<PredictionResult> results) 
        throws SQLException {
    String sql = "UPDATE ai_features.user_profiles SET risk_score = ?, model_version = ? WHERE user_id = ?";
    
    try (PreparedStatement ps = conn.prepareStatement(sql)) {
        for (PredictionResult r : results) {
            ps.setDouble(1, r.riskScore);
            ps.setString(2, r.modelVersion);
            ps.setLong(3, r.userId);
            ps.addBatch();
        }
        ps.executeBatch(); // ← 关键:批量提交
    }
}

💡 性能提示

  • batchSize 建议 1000~5000;
  • 开启 JDBC 批处理(KES 驱动默认支持);
  • 在事务中执行,失败可回滚。

四、端到端启动脚本

java 复制代码
public static void main(String[] args) throws Exception {
    // 1. 连接 KES
    Connection conn = DriverManager.getConnection(
        "jdbc:kingbase8://localhost:54321/ai_prod",
        "ai_user", "secure_password"
    );
    
    // 2. 启动离线预测
    OfflinePredictor predictor = new OfflinePredictor();
    predictor.runOfflinePrediction(conn, "loan_risk_rf", 2000);
    
    // 3. 提交事务(或自动提交)
    conn.close();
    
    System.out.println("✅ Offline prediction completed.");
}

运行输出:

复制代码
Starting offline prediction for 1,248,562 users...
Processed 2000 / 1248562
Processed 4000 / 1248562
...
Processed 1248562 / 1248562
✅ Offline prediction completed.

🔗 使用 电科金仓 JDBC 驱动 确保批处理高效执行。


五、工程优化建议

5.1 并行加速

  • 单机多线程:parallelStream() 已足够;
  • 分布式扩展:按 user_id % N 分片,多进程并行读不同 offset。

5.2 内存控制

  • 特征对象轻量化(避免冗余字段);
  • 批次大小动态调整(监控 GC 日志)。

5.3 容错与断点续传

sql 复制代码
-- 增加状态字段
ALTER TABLE ai_features.user_profiles ADD COLUMN pred_status VARCHAR(10) DEFAULT 'pending';
-- 更新时标记 'completed'

下次启动时只处理 pred_status = 'pending' 的记录。


六、为什么这适合国产化场景?

  1. 零外部依赖:不引入 Spark/Flink,降低技术栈复杂度;
  2. 资源可控:JVM 内存、CPU 可精确限制,适配国产服务器;
  3. 安全合规:数据不出库,符合金融/政务数据本地化要求;
  4. 运维简单:一个 JAR 包 + 定时任务(如 crontab)即可调度。

而这套能力,正建立在 电科金仓 KES 提供的高吞吐 OLTP + 轻量批处理能力之上------它不仅是数据库,更是 AI 批处理引擎。


结语:让模型从"资产"变成"生产力"

在很多项目里,模型训练完就束之高阁,成了 PPT 里的数字。

但真正的 AI 工程化,是让模型每天默默处理百万级请求,为业务创造真实价值

当你能用一段 Java 代码,从 KES 中加载模型,对全量用户做风险评分,并将结果原子写回------你就完成了从"AI 实验"到"AI 生产"的关键一跃。

因为你知道:模型的价值,不在训练时的 AUC,而在上线后的每一条预测

------ 一位相信"模型,不该只活在 notebook 里"的架构师

相关推荐
硅谷秋水1 小时前
多智体机器人系统(MARS)挑战的进展与创新
深度学习·机器学习·计算机视觉·语言模型·机器人·人机交互
癫狂的兔子2 小时前
【Python】【机器学习】K-MEANS算法
算法·机器学习·kmeans
Ama_tor2 小时前
Navicat学习01|初步应用实践
数据库·navicat
山岚的运维笔记2 小时前
SQL Server笔记 -- 第65章:迁移 第66章:表值参数
数据库·笔记·sql·microsoft·sqlserver
番茄去哪了3 小时前
苍穹外卖day05----店铺营业状态设置
java·数据库·ide·redis·git·maven·mybatis
算法黑哥3 小时前
Sharpness-Aware Minimization (SAM,锐度感知最小化)是让损失曲面变平坦,还是引导参数至平坦区域
深度学习·神经网络·机器学习
索木木4 小时前
大模型训练CP切分(与TP、SP结合)
人工智能·深度学习·机器学习·大模型·训练·cp·切分
暮色妖娆丶4 小时前
Spring 源码分析 事务管理的实现原理(下)
数据库·spring boot·spring
暮色妖娆丶4 小时前
Spring 源码分析 事务管理的实现原理(上)
数据库·spring boot·spring