1. LinearRegression 做什么?
它学习一个线性函数(简化表达):
y\^=w⋅x+b\]\[ \\hat{y} = w \\cdot x + b \]\[y\^=w⋅x+b
其中:
- (x)(x)(x)是特征向量(Vector)
- (w)(w)(w) 是权重向量(模型参数)
- (b)(b)(b) 是偏置
- (y^)(\hat{y})(y^) 是预测值(prediction)
2. 输入列与输出列
输入列(Input Columns)
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
featuresCol |
Vector | "features" |
特征向量 |
labelCol |
Integer(示例实际用 Double) | "label" |
要预测的连续标签 |
weightCol |
Double | "weight" |
样本权重(可选) |
你贴的示例里
label和weight都是 Double,这更符合回归场景。
输出列(Output Columns)
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
predictionCol |
Integer(实际为 Double) | "prediction" |
预测值 |
3. 参数详解(Parameters)
LinearRegressionModel(模型)常用参数:
| Key | 默认值 | 说明 |
|---|---|---|
featuresCol |
"features" |
特征列名 |
predictionCol |
"prediction" |
预测列名 |
LinearRegression(训练器)额外参数:
| Key | 默认值 | 说明 |
|---|---|---|
labelCol |
"label" |
标签列名 |
weightCol |
null | 权重列名(可选) |
maxIter |
20 | 最大迭代次数 |
reg |
0.0 | 正则化系数(L2/L1 混合由 elasticNet 控制) |
elasticNet |
0.0 | ElasticNet 参数:0=纯L2,1=纯L1 |
learningRate |
0.1 | 学习率 |
globalBatchSize |
32 | 全局 batch 大小 |
tol |
1e-6 | 收敛阈值(迭代停止条件之一) |
工程建议(很实用):
- 特征尺度差异大:先
StandardScaler再 LR,收敛更稳 - 数据量大/噪声大:适当加
reg,防过拟合、也更稳定 - 收敛慢:调大
maxIter或调学习率(learningRate)但要谨慎
4. Java 示例逐段解读(fit + transform)
你给的示例是"训练后再对同一份数据做预测",方便验证效果。
4.1 构造输入表:features + label + weight
java
DataStream<Row> inputStream =
env.fromElements(
Row.of(Vectors.dense(2, 1), 4.0, 1.0),
Row.of(Vectors.dense(3, 2), 7.0, 1.0),
Row.of(Vectors.dense(4, 3), 10.0, 1.0),
Row.of(Vectors.dense(2, 4), 10.0, 1.0),
Row.of(Vectors.dense(2, 2), 6.0, 1.0),
Row.of(Vectors.dense(4, 3), 10.0, 1.0),
Row.of(Vectors.dense(1, 2), 5.0, 1.0),
Row.of(Vectors.dense(5, 3), 11.0, 1.0));
Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");
features是二维向量label是回归目标(Double)weight全是 1.0(表示每个样本权重一样)
4.2 创建训练器并指定权重列
java
LinearRegression lr = new LinearRegression().setWeightCol("weight");
如果你不需要样本权重,完全可以不设 weightCol。
4.3 训练模型 + 预测
java
LinearRegressionModel lrModel = lr.fit(inputTable);
Table outputTable = lrModel.transform(inputTable)[0];
fit():在 Table 上训练,得到模型transform():输出表会多一个prediction列
4.4 读取输出:对比 label 与 prediction
java
double expectedResult = (Double) row.getField(lr.getLabelCol());
double predictionResult = (Double) row.getField(lr.getPredictionCol());
System.out.printf("... Expected Result: %s \tPrediction Result: %s\n", expectedResult, predictionResult);
5. 实战用法:最常见的两条链路
链路 A:数值特征 → StandardScaler → LinearRegression
如果你的特征量纲差别大(金额、次数、时长混在一起),强烈推荐:
- VectorAssembler(拼特征)
- StandardScaler(标准化)
- LinearRegression(训练)
链路 B:类别特征(StringIndexer + OneHot)+ 数值特征 → VectorAssembler → LinearRegression
当你的输入既有类别又有数值时:
- StringIndexer:字符串类别 → index
- OneHotEncoder:index → 稀疏向量
- VectorAssembler:数值列 + 类别向量 拼成 features
- LinearRegression:训练回归
6. 小结
Flink ML 的 LinearRegression 使用非常标准化:
- 输入:
features(Vector)+label(Double)+ 可选weight(Double) - 训练:
lr.fit(table) - 预测:
model.transform(table)输出prediction