Flink ML LinearRegression 用 Table API 训练线性回归并输出预测值

1. LinearRegression 做什么?

它学习一个线性函数(简化表达):

y\^=w⋅x+b\]\[ \\hat{y} = w \\cdot x + b \]\[y\^=w⋅x+b

其中:

  • (x)(x)(x)是特征向量(Vector)
  • (w)(w)(w) 是权重向量(模型参数)
  • (b)(b)(b) 是偏置
  • (y^)(\hat{y})(y^) 是预测值(prediction)

2. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 说明
featuresCol Vector "features" 特征向量
labelCol Integer(示例实际用 Double) "label" 要预测的连续标签
weightCol Double "weight" 样本权重(可选)

你贴的示例里 labelweight 都是 Double,这更符合回归场景。

输出列(Output Columns)

参数名 类型 默认值 说明
predictionCol Integer(实际为 Double) "prediction" 预测值

3. 参数详解(Parameters)

LinearRegressionModel(模型)常用参数:

Key 默认值 说明
featuresCol "features" 特征列名
predictionCol "prediction" 预测列名

LinearRegression(训练器)额外参数:

Key 默认值 说明
labelCol "label" 标签列名
weightCol null 权重列名(可选)
maxIter 20 最大迭代次数
reg 0.0 正则化系数(L2/L1 混合由 elasticNet 控制)
elasticNet 0.0 ElasticNet 参数:0=纯L2,1=纯L1
learningRate 0.1 学习率
globalBatchSize 32 全局 batch 大小
tol 1e-6 收敛阈值(迭代停止条件之一)

工程建议(很实用):

  • 特征尺度差异大:先 StandardScaler 再 LR,收敛更稳
  • 数据量大/噪声大:适当加 reg,防过拟合、也更稳定
  • 收敛慢:调大 maxIter 或调学习率(learningRate)但要谨慎

4. Java 示例逐段解读(fit + transform)

你给的示例是"训练后再对同一份数据做预测",方便验证效果。

4.1 构造输入表:features + label + weight

java 复制代码
DataStream<Row> inputStream =
        env.fromElements(
                Row.of(Vectors.dense(2, 1), 4.0, 1.0),
                Row.of(Vectors.dense(3, 2), 7.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(2, 4), 10.0, 1.0),
                Row.of(Vectors.dense(2, 2), 6.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(1, 2), 5.0, 1.0),
                Row.of(Vectors.dense(5, 3), 11.0, 1.0));
Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");
  • features 是二维向量
  • label 是回归目标(Double)
  • weight 全是 1.0(表示每个样本权重一样)

4.2 创建训练器并指定权重列

java 复制代码
LinearRegression lr = new LinearRegression().setWeightCol("weight");

如果你不需要样本权重,完全可以不设 weightCol

4.3 训练模型 + 预测

java 复制代码
LinearRegressionModel lrModel = lr.fit(inputTable);
Table outputTable = lrModel.transform(inputTable)[0];
  • fit():在 Table 上训练,得到模型
  • transform():输出表会多一个 prediction

4.4 读取输出:对比 label 与 prediction

java 复制代码
double expectedResult = (Double) row.getField(lr.getLabelCol());
double predictionResult = (Double) row.getField(lr.getPredictionCol());
System.out.printf("... Expected Result: %s \tPrediction Result: %s\n", expectedResult, predictionResult);

5. 实战用法:最常见的两条链路

链路 A:数值特征 → StandardScaler → LinearRegression

如果你的特征量纲差别大(金额、次数、时长混在一起),强烈推荐:

  • VectorAssembler(拼特征)
  • StandardScaler(标准化)
  • LinearRegression(训练)

链路 B:类别特征(StringIndexer + OneHot)+ 数值特征 → VectorAssembler → LinearRegression

当你的输入既有类别又有数值时:

  • StringIndexer:字符串类别 → index
  • OneHotEncoder:index → 稀疏向量
  • VectorAssembler:数值列 + 类别向量 拼成 features
  • LinearRegression:训练回归

6. 小结

Flink ML 的 LinearRegression 使用非常标准化:

  • 输入:features(Vector) + label(Double) + 可选 weight(Double)
  • 训练:lr.fit(table)
  • 预测:model.transform(table) 输出 prediction
相关推荐
@insist1232 小时前
信息安全工程师考点精讲:身份认证核心原理与分类体系(上篇)
大数据·网络·分类·信息安全工程师·软件水平考试
天辛大师2 小时前
AI助力旅游扩大化,五一旅游公园通游年票普惠研究
大数据·启发式算法·旅游
WordPress学习笔记2 小时前
镌刻中式美学的高端WordPress主题
大数据·人工智能·wordpress
数智化精益手记局3 小时前
拆解物料管理erp系统的核心功能,看物料管理erp系统如何解决库存积压与缺料难题
大数据·网络·人工智能·安全·信息可视化·精益工程
Elastic 中国社区官方博客5 小时前
使用 Observability Migration Platform 将 Datadog 和 Grafana 的仪表板与告警迁移到 Kibana
大数据·elasticsearch·搜索引擎·信息可视化·全文检索·grafana·datalog
jkyy20145 小时前
AI运动数字化:以技术重塑场景,健康有益赋能全域运动健康管理
大数据·人工智能·健康医疗
金融小师妹5 小时前
4月30日多因子共振节点:鲍威尔“收官效应”与权力结构重塑的预期重构
大数据·人工智能·重构·逻辑回归
2601_949925185 小时前
AI Agent如何重构跨境物流的决策?
大数据·人工智能·重构·ai agent·geo优化·物流科技
xiaoduo AI6 小时前
客服机器人问题解决率怎么统计?Agent系统自动判断是否解决,比人工回访准?
大数据·人工智能·机器人
小五兄弟7 小时前
YouTube 肖像检测扩展背后:短剧出海版权保护的技术实现与实战策略
大数据·人工智能