线性回归原理(二):梯度下降算法 ------ Java 实现单变量/多变量拟合
------当闭式解不够用时,我们如何"一步步逼近最优"
大家好,我是那个总在训练日志里看 loss 曲线、又在国产服务器上调学习率的老架构。上一期我们手算了最小二乘法,一步到位求出最优参数。
但现实是:当数据量大到百万级,或者模型复杂到非线性,闭式解要么算不动,要么根本不存在。
这时候,我们就需要另一种武器:梯度下降(Gradient Descent)。
很多人以为梯度下降只是"调参技巧",但真相是:它是现代深度学习的基石,也是我们在资源受限的国产环境中实现可扩展 AI 的关键路径。
今天我们就从数学直觉出发,用纯 Java 实现单变量和多变量线性回归的梯度下降,并让数据源完全来自电科金仓 KingbaseES(KES)------不依赖任何 ML 框架,只用 JDK + KES JDBC + 基础线性代数。
一、为什么需要梯度下降?
回顾上一期的闭式解:
β = (XᵀX)⁻¹ Xᵀ y
这个公式看起来优雅,但它有两个致命问题:
- 计算复杂度高:矩阵求逆是 O(n³),当特征数 > 10,000 时,内存和时间都扛不住;
- 无法扩展:一旦换成逻辑回归、神经网络,就不再有闭式解。
而梯度下降的核心思想极其朴素:
沿着损失函数下降最快的方向,一小步一小步走,直到谷底。
它不求"一步到位",但求"可扩展、可迭代、可分布式"------这正是工程落地的精髓。
二、数学直觉:梯度是什么?
损失函数(MSE):
J(w, b) = (1/2m) Σ(yᵢ − (w·xᵢ + b))²
对 w 和 b 求偏导(即梯度):
- ∂J/∂w = -(1/m) Σ xᵢ (yᵢ − ŷᵢ)
- ∂J/∂b = -(1/m) Σ (yᵢ − ŷᵢ)
梯度告诉我们:当前点上,往哪个方向走能让 loss 降得最快。
于是更新规则:
w := w − α · ∂J/∂w
b := b − α · ∂J/∂b
其中 α 是学习率(learning rate)------步子迈多大。
✅ 关键洞察:
梯度下降不是魔法,而是带反馈的试错机制。
三、Java 实现:单变量线性回归(从 KES 读取)
假设我们仍用 house_price(area, price) 表。
步骤 1:加载数据
java
public static class RegressionData {
public final double[] x; // 特征(可能多维)
public final double y; // 标签
public RegressionData(double[] x, double y) {
this.x = x; this.y = y;
}
}
public List<RegressionData> loadFromKES(Connection conn, String sql) throws SQLException {
List<RegressionData> data = new ArrayList<>();
try (PreparedStatement ps = conn.prepareStatement(sql);
ResultSet rs = ps.executeQuery()) {
while (rs.next()) {
// 单变量:x = [area]
double[] x = { rs.getDouble("area") };
double y = rs.getDouble("price");
data.add(new RegressionData(x, y));
}
}
return data;
}
步骤 2:实现梯度下降
java
public static double[] fitWithGradientDescent(
List<RegressionData> data,
double learningRate,
int epochs
) {
int n = data.size();
int features = data.get(0).x.length;
// 初始化参数:w (features), b (bias)
double[] w = new double[features];
double b = 0.0;
for (int epoch = 0; epoch < epochs; epoch++) {
double dwSum = 0.0, dbSum = 0.0;
// 计算梯度(全量 batch)
for (RegressionData sample : data) {
double prediction = 0.0;
for (int i = 0; i < features; i++) {
prediction += w[i] * sample.x[i];
}
prediction += b;
double error = prediction - sample.y;
for (int i = 0; i < features; i++) {
dwSum += error * sample.x[i];
}
dbSum += error;
}
// 平均梯度
double dw = dwSum / n;
double db = dbSum / n;
// 更新参数
for (int i = 0; i < features; i++) {
w[i] -= learningRate * dw;
}
b -= learningRate * db;
// 每 100 轮打印 loss
if (epoch % 100 == 0) {
double loss = computeLoss(data, w, b);
System.out.printf("Epoch %d, Loss: %.6f%n", epoch, loss);
}
}
// 合并 w 和 b 返回 [w0, w1, ..., b]
double[] params = new double[features + 1];
System.arraycopy(w, 0, params, 0, features);
params[features] = b;
return params;
}
private static double computeLoss(List<RegressionData> data, double[] w, double b) {
double total = 0.0;
for (RegressionData s : data) {
double pred = b;
for (int i = 0; i < w.length; i++) {
pred += w[i] * s.x[i];
}
total += Math.pow(pred - s.y, 2);
}
return total / (2 * data.size());
}
步骤 3:运行训练
java
String sql = "SELECT area, price FROM ai_features.house_price";
try (Connection conn = KESDataSource.getConnection()) {
List<RegressionData> data = loadFromKES(conn, sql);
// 归一化(重要!)
normalizeFeatures(data); // 见下文
double[] params = fitWithGradientDescent(data, 0.01, 1000);
System.out.printf("Final: w=%.4f, b=%.4f%n", params[0], params[1]);
}
💡 必须做特征归一化!否则梯度下降会震荡甚至发散。
java
public static void normalizeFeatures(List<RegressionData> data) {
int n = data.size();
int dim = data.get(0).x.length;
double[] mean = new double[dim];
double[] std = new double[dim];
// 计算均值 & 标准差
for (RegressionData d : data) {
for (int i = 0; i < dim; i++) {
mean[i] += d.x[i];
}
}
for (int i = 0; i < dim; i++) mean[i] /= n;
for (RegressionData d : data) {
for (int i = 0; i < dim; i++) {
std[i] += Math.pow(d.x[i] - mean[i], 2);
}
}
for (int i = 0; i < dim; i++) std[i] = Math.sqrt(std[i] / n);
// 标准化
for (RegressionData d : data) {
for (int i = 0; i < dim; i++) {
d.x[i] = (d.x[i] - mean[i]) / (std[i] == 0 ? 1 : std[i]);
}
}
}
四、扩展到多变量:无缝支持
现在,假设 KES 表多了两个字段:
sql
CREATE TABLE ai_features.house_price_v2 (
area REAL,
bedrooms INT,
age INT,
price REAL
);
只需修改 SQL 和特征提取:
java
// 在 loadFromKES 中:
double[] x = {
rs.getDouble("area"),
rs.getDouble("bedrooms"),
rs.getDouble("age")
};
其余代码完全不用改 !fitWithGradientDescent 自动处理任意维度。
典型输出:
Epoch 0, Loss: 12500.000000
Epoch 100, Loss: 842.345678
Epoch 200, Loss: 210.123456
...
Final: w=[3.21, 5.67, -0.89], b=12.34
五、与 KES 协同:构建可复现的训练流水线
我们将整个流程封装为可调度任务:
java
public void trainHousePriceModel(String version) {
try (Connection conn = KESDataSource.getConnection()) {
// 1. 从指定版本表读取
String table = "ai_features.house_price_" + version;
List<RegressionData> data = loadFromKES(conn, "SELECT * FROM " + table);
// 2. 归一化 + 训练
normalizeFeatures(data);
double[] params = fitWithGradientDescent(data, 0.01, 2000);
// 3. 写回模型注册表
saveModelToKES(conn, "house_price_linear_gd", version, params);
}
}
这样,每次数据更新,只需跑一次任务,新模型自动生效。
六、工程调优建议
-
学习率选择:
- 太大 → loss 震荡不收敛;
- 太小 → 收敛太慢;
- 建议从 0.01 开始,观察 loss 曲线调整。
-
批量 vs 全量 :
当前实现是全批量梯度下降(BGD) ,适合中小数据集;
若数据超大,可改用随机梯度下降(SGD) 或 小批量(Mini-batch)。
-
早停机制 :
当 loss 连续 10 轮下降 < 1e-6,可提前终止。
-
KES 并行读取 :
对大表启用
SET enable_parallel_query = on,加速数据加载。
结语:迭代,是工程智能的起点
梯度下降教会我们一个深刻的工程哲学:不必一步完美,但要持续逼近。
在国产化 AI 落地中,我们往往没有 GPU 集群,也没有 TB 级标注数据。
但我们有电科金仓的 KES 提供稳定数据底座,有 Java 提供可控执行环境,有梯度下降提供可扩展的优化路径。
当你能在飞腾服务器上,用几百行 Java 代码,从 KES 中完成一次多变量线性回归训练,并将结果用于生产预测------你就真正拥有了自主、可控、可演进的 AI 能力。
------ 一位相信"智能,始于每一次微小的迭代"的架构师