Flink ML LinearRegression 用 Table API 训练线性回归并输出预测值

1. LinearRegression 做什么?

它学习一个线性函数(简化表达):

y\^=w⋅x+b\]\[ \\hat{y} = w \\cdot x + b \]\[y\^=w⋅x+b

其中:

  • (x)(x)(x)是特征向量(Vector)
  • (w)(w)(w) 是权重向量(模型参数)
  • (b)(b)(b) 是偏置
  • (y^)(\hat{y})(y^) 是预测值(prediction)

2. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 说明
featuresCol Vector "features" 特征向量
labelCol Integer(示例实际用 Double) "label" 要预测的连续标签
weightCol Double "weight" 样本权重(可选)

你贴的示例里 labelweight 都是 Double,这更符合回归场景。

输出列(Output Columns)

参数名 类型 默认值 说明
predictionCol Integer(实际为 Double) "prediction" 预测值

3. 参数详解(Parameters)

LinearRegressionModel(模型)常用参数:

Key 默认值 说明
featuresCol "features" 特征列名
predictionCol "prediction" 预测列名

LinearRegression(训练器)额外参数:

Key 默认值 说明
labelCol "label" 标签列名
weightCol null 权重列名(可选)
maxIter 20 最大迭代次数
reg 0.0 正则化系数(L2/L1 混合由 elasticNet 控制)
elasticNet 0.0 ElasticNet 参数:0=纯L2,1=纯L1
learningRate 0.1 学习率
globalBatchSize 32 全局 batch 大小
tol 1e-6 收敛阈值(迭代停止条件之一)

工程建议(很实用):

  • 特征尺度差异大:先 StandardScaler 再 LR,收敛更稳
  • 数据量大/噪声大:适当加 reg,防过拟合、也更稳定
  • 收敛慢:调大 maxIter 或调学习率(learningRate)但要谨慎

4. Java 示例逐段解读(fit + transform)

你给的示例是"训练后再对同一份数据做预测",方便验证效果。

4.1 构造输入表:features + label + weight

java 复制代码
DataStream<Row> inputStream =
        env.fromElements(
                Row.of(Vectors.dense(2, 1), 4.0, 1.0),
                Row.of(Vectors.dense(3, 2), 7.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(2, 4), 10.0, 1.0),
                Row.of(Vectors.dense(2, 2), 6.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(1, 2), 5.0, 1.0),
                Row.of(Vectors.dense(5, 3), 11.0, 1.0));
Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");
  • features 是二维向量
  • label 是回归目标(Double)
  • weight 全是 1.0(表示每个样本权重一样)

4.2 创建训练器并指定权重列

java 复制代码
LinearRegression lr = new LinearRegression().setWeightCol("weight");

如果你不需要样本权重,完全可以不设 weightCol

4.3 训练模型 + 预测

java 复制代码
LinearRegressionModel lrModel = lr.fit(inputTable);
Table outputTable = lrModel.transform(inputTable)[0];
  • fit():在 Table 上训练,得到模型
  • transform():输出表会多一个 prediction

4.4 读取输出:对比 label 与 prediction

java 复制代码
double expectedResult = (Double) row.getField(lr.getLabelCol());
double predictionResult = (Double) row.getField(lr.getPredictionCol());
System.out.printf("... Expected Result: %s \tPrediction Result: %s\n", expectedResult, predictionResult);

5. 实战用法:最常见的两条链路

链路 A:数值特征 → StandardScaler → LinearRegression

如果你的特征量纲差别大(金额、次数、时长混在一起),强烈推荐:

  • VectorAssembler(拼特征)
  • StandardScaler(标准化)
  • LinearRegression(训练)

链路 B:类别特征(StringIndexer + OneHot)+ 数值特征 → VectorAssembler → LinearRegression

当你的输入既有类别又有数值时:

  • StringIndexer:字符串类别 → index
  • OneHotEncoder:index → 稀疏向量
  • VectorAssembler:数值列 + 类别向量 拼成 features
  • LinearRegression:训练回归

6. 小结

Flink ML 的 LinearRegression 使用非常标准化:

  • 输入:features(Vector) + label(Double) + 可选 weight(Double)
  • 训练:lr.fit(table)
  • 预测:model.transform(table) 输出 prediction
相关推荐
财迅通Ai1 天前
6000万吨产能承压 卫星化学迎来战略窗口期
大数据·人工智能·物联网·卫星化学
武子康1 天前
大数据-263 实时数仓-Canal 增量订阅与消费原理:MySQL Binlog 数据同步实践
大数据·hadoop·后端
LJ97951111 天前
媒体发布新武器:Infoseek融媒体平台使用指南
大数据·人工智能
科技小花1 天前
AI重塑数据治理:2026年核心方案评估与场景适配
大数据·人工智能·云原生·ai原生
方向研究1 天前
存储芯片生产
大数据
代码青铜1 天前
如何用 Zion 实现 AI 图片分析与电商文案自动生成流程
大数据·人工智能
gaoshengdainzi1 天前
GB/T23448-2019卫生洁具软管专用检测设备全套解决方案
大数据·卫生洁具软管检测设备·软管试验机
茶靡花开04151 天前
什么是DMS经销商管理系统?经销商管理系统哪个好?
大数据·人工智能
Gofarlic_OMS1 天前
HyperWorks用户仿真行为分析与许可证资源分点配置
java·大数据·运维·服务器·人工智能
fire-flyer1 天前
ClickHouse系列(二):MergeTree 家族详解
大数据·数据库·clickhouse