线性回归原理(二):梯度下降算法,Java实现单变量/多变量拟合

线性回归原理(二):梯度下降算法 ------ Java 实现单变量/多变量拟合

------当闭式解不够用时,我们如何"一步步逼近最优"

大家好,我是那个总在训练日志里看 loss 曲线、又在国产服务器上调学习率的老架构。上一期我们手算了最小二乘法,一步到位求出最优参数。

但现实是:当数据量大到百万级,或者模型复杂到非线性,闭式解要么算不动,要么根本不存在

这时候,我们就需要另一种武器:梯度下降(Gradient Descent)

很多人以为梯度下降只是"调参技巧",但真相是:它是现代深度学习的基石,也是我们在资源受限的国产环境中实现可扩展 AI 的关键路径

今天我们就从数学直觉出发,用纯 Java 实现单变量和多变量线性回归的梯度下降,并让数据源完全来自电科金仓 KingbaseES(KES)------不依赖任何 ML 框架,只用 JDK + KES JDBC + 基础线性代数


一、为什么需要梯度下降?

回顾上一期的闭式解:

β = (XᵀX)⁻¹ Xᵀ y

这个公式看起来优雅,但它有两个致命问题:

  1. 计算复杂度高:矩阵求逆是 O(n³),当特征数 > 10,000 时,内存和时间都扛不住;
  2. 无法扩展:一旦换成逻辑回归、神经网络,就不再有闭式解。

而梯度下降的核心思想极其朴素:

沿着损失函数下降最快的方向,一小步一小步走,直到谷底

它不求"一步到位",但求"可扩展、可迭代、可分布式"------这正是工程落地的精髓。


二、数学直觉:梯度是什么?

损失函数(MSE):

J(w, b) = (1/2m) Σ(yᵢ − (w·xᵢ + b))²

对 w 和 b 求偏导(即梯度):

  • ∂J/∂w = -(1/m) Σ xᵢ (yᵢ − ŷᵢ)
  • ∂J/∂b = -(1/m) Σ (yᵢ − ŷᵢ)

梯度告诉我们:当前点上,往哪个方向走能让 loss 降得最快

于是更新规则:

w := w − α · ∂J/∂w
b := b − α · ∂J/∂b

其中 α 是学习率(learning rate)------步子迈多大。

✅ 关键洞察:

梯度下降不是魔法,而是带反馈的试错机制


三、Java 实现:单变量线性回归(从 KES 读取)

假设我们仍用 house_price(area, price) 表。

步骤 1:加载数据

java 复制代码
public static class RegressionData {
    public final double[] x; // 特征(可能多维)
    public final double y;   // 标签
    public RegressionData(double[] x, double y) {
        this.x = x; this.y = y;
    }
}

public List<RegressionData> loadFromKES(Connection conn, String sql) throws SQLException {
    List<RegressionData> data = new ArrayList<>();
    try (PreparedStatement ps = conn.prepareStatement(sql);
         ResultSet rs = ps.executeQuery()) {
        while (rs.next()) {
            // 单变量:x = [area]
            double[] x = { rs.getDouble("area") };
            double y = rs.getDouble("price");
            data.add(new RegressionData(x, y));
        }
    }
    return data;
}

🔗 驱动下载:https://www.kingbase.com.cn/download.html#drive


步骤 2:实现梯度下降

java 复制代码
public static double[] fitWithGradientDescent(
    List<RegressionData> data,
    double learningRate,
    int epochs
) {
    int n = data.size();
    int features = data.get(0).x.length;
    
    // 初始化参数:w (features), b (bias)
    double[] w = new double[features];
    double b = 0.0;

    for (int epoch = 0; epoch < epochs; epoch++) {
        double dwSum = 0.0, dbSum = 0.0;

        // 计算梯度(全量 batch)
        for (RegressionData sample : data) {
            double prediction = 0.0;
            for (int i = 0; i < features; i++) {
                prediction += w[i] * sample.x[i];
            }
            prediction += b;

            double error = prediction - sample.y;
            for (int i = 0; i < features; i++) {
                dwSum += error * sample.x[i];
            }
            dbSum += error;
        }

        // 平均梯度
        double dw = dwSum / n;
        double db = dbSum / n;

        // 更新参数
        for (int i = 0; i < features; i++) {
            w[i] -= learningRate * dw;
        }
        b -= learningRate * db;

        // 每 100 轮打印 loss
        if (epoch % 100 == 0) {
            double loss = computeLoss(data, w, b);
            System.out.printf("Epoch %d, Loss: %.6f%n", epoch, loss);
        }
    }

    // 合并 w 和 b 返回 [w0, w1, ..., b]
    double[] params = new double[features + 1];
    System.arraycopy(w, 0, params, 0, features);
    params[features] = b;
    return params;
}

private static double computeLoss(List<RegressionData> data, double[] w, double b) {
    double total = 0.0;
    for (RegressionData s : data) {
        double pred = b;
        for (int i = 0; i < w.length; i++) {
            pred += w[i] * s.x[i];
        }
        total += Math.pow(pred - s.y, 2);
    }
    return total / (2 * data.size());
}

步骤 3:运行训练

java 复制代码
String sql = "SELECT area, price FROM ai_features.house_price";
try (Connection conn = KESDataSource.getConnection()) {
    List<RegressionData> data = loadFromKES(conn, sql);
    
    // 归一化(重要!)
    normalizeFeatures(data); // 见下文

    double[] params = fitWithGradientDescent(data, 0.01, 1000);
    System.out.printf("Final: w=%.4f, b=%.4f%n", params[0], params[1]);
}

💡 必须做特征归一化!否则梯度下降会震荡甚至发散。

java 复制代码
public static void normalizeFeatures(List<RegressionData> data) {
    int n = data.size();
    int dim = data.get(0).x.length;
    double[] mean = new double[dim];
    double[] std = new double[dim];

    // 计算均值 & 标准差
    for (RegressionData d : data) {
        for (int i = 0; i < dim; i++) {
            mean[i] += d.x[i];
        }
    }
    for (int i = 0; i < dim; i++) mean[i] /= n;

    for (RegressionData d : data) {
        for (int i = 0; i < dim; i++) {
            std[i] += Math.pow(d.x[i] - mean[i], 2);
        }
    }
    for (int i = 0; i < dim; i++) std[i] = Math.sqrt(std[i] / n);

    // 标准化
    for (RegressionData d : data) {
        for (int i = 0; i < dim; i++) {
            d.x[i] = (d.x[i] - mean[i]) / (std[i] == 0 ? 1 : std[i]);
        }
    }
}

四、扩展到多变量:无缝支持

现在,假设 KES 表多了两个字段:

sql 复制代码
CREATE TABLE ai_features.house_price_v2 (
    area        REAL,
    bedrooms    INT,
    age         INT,
    price       REAL
);

只需修改 SQL 和特征提取:

java 复制代码
// 在 loadFromKES 中:
double[] x = {
    rs.getDouble("area"),
    rs.getDouble("bedrooms"),
    rs.getDouble("age")
};

其余代码完全不用改fitWithGradientDescent 自动处理任意维度。

典型输出:

复制代码
Epoch 0, Loss: 12500.000000
Epoch 100, Loss: 842.345678
Epoch 200, Loss: 210.123456
...
Final: w=[3.21, 5.67, -0.89], b=12.34

五、与 KES 协同:构建可复现的训练流水线

我们将整个流程封装为可调度任务:

java 复制代码
public void trainHousePriceModel(String version) {
    try (Connection conn = KESDataSource.getConnection()) {
        // 1. 从指定版本表读取
        String table = "ai_features.house_price_" + version;
        List<RegressionData> data = loadFromKES(conn, "SELECT * FROM " + table);
        
        // 2. 归一化 + 训练
        normalizeFeatures(data);
        double[] params = fitWithGradientDescent(data, 0.01, 2000);
        
        // 3. 写回模型注册表
        saveModelToKES(conn, "house_price_linear_gd", version, params);
    }
}

这样,每次数据更新,只需跑一次任务,新模型自动生效


六、工程调优建议

  1. 学习率选择

    • 太大 → loss 震荡不收敛;
    • 太小 → 收敛太慢;
    • 建议从 0.01 开始,观察 loss 曲线调整。
  2. 批量 vs 全量

    当前实现是全批量梯度下降(BGD) ,适合中小数据集;

    若数据超大,可改用随机梯度下降(SGD)小批量(Mini-batch)

  3. 早停机制

    当 loss 连续 10 轮下降 < 1e-6,可提前终止。

  4. KES 并行读取

    对大表启用 SET enable_parallel_query = on,加速数据加载。


结语:迭代,是工程智能的起点

梯度下降教会我们一个深刻的工程哲学:不必一步完美,但要持续逼近

在国产化 AI 落地中,我们往往没有 GPU 集群,也没有 TB 级标注数据。

但我们有电科金仓的 KES 提供稳定数据底座,有 Java 提供可控执行环境,有梯度下降提供可扩展的优化路径。

当你能在飞腾服务器上,用几百行 Java 代码,从 KES 中完成一次多变量线性回归训练,并将结果用于生产预测------你就真正拥有了自主、可控、可演进的 AI 能力

------ 一位相信"智能,始于每一次微小的迭代"的架构师

相关推荐
宸津-代码粉碎机2 小时前
用MySQL玩转数据可视化
数据库·mysql·信息可视化
步步为营DotNet2 小时前
深度探索.NET 中ILogger:构建稳健日志系统的核心组件
数据库·.net
loading小马2 小时前
Mybatis-Plus超级实用的多种功能用法
java·spring boot·后端·maven·mybatis
licheng99672 小时前
工具、测试与部署
jvm·数据库·python
春日见2 小时前
Docker如何基于脚本拉取镜像,配置环境,尝试编译
运维·驱动开发·算法·docker·容器
毕设源码-邱学长2 小时前
【开题答辩全过程】以 南工计算机等级网站为例,包含答辩的问题和答案
java
红队it2 小时前
【数据分析+机器学习】基于机器学习的招聘数据分析可视化预测推荐系统(完整系统源码+数据库+开发笔记+详细部署教程)✅
数据库·机器学习·数据分析
NE_STOP2 小时前
spring boot3--自动配置与手动配置
java
小北方城市网2 小时前
Spring Cloud Gateway 生产级微内核架构设计与可插拔过滤器开发
java·大数据·linux·运维·spring boot·redis·分布式