Flink ML LinearRegression 用 Table API 训练线性回归并输出预测值

1. LinearRegression 做什么?

它学习一个线性函数(简化表达):

y\^=w⋅x+b\]\[ \\hat{y} = w \\cdot x + b \]\[y\^=w⋅x+b

其中:

  • (x)(x)(x)是特征向量(Vector)
  • (w)(w)(w) 是权重向量(模型参数)
  • (b)(b)(b) 是偏置
  • (y^)(\hat{y})(y^) 是预测值(prediction)

2. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 说明
featuresCol Vector "features" 特征向量
labelCol Integer(示例实际用 Double) "label" 要预测的连续标签
weightCol Double "weight" 样本权重(可选)

你贴的示例里 labelweight 都是 Double,这更符合回归场景。

输出列(Output Columns)

参数名 类型 默认值 说明
predictionCol Integer(实际为 Double) "prediction" 预测值

3. 参数详解(Parameters)

LinearRegressionModel(模型)常用参数:

Key 默认值 说明
featuresCol "features" 特征列名
predictionCol "prediction" 预测列名

LinearRegression(训练器)额外参数:

Key 默认值 说明
labelCol "label" 标签列名
weightCol null 权重列名(可选)
maxIter 20 最大迭代次数
reg 0.0 正则化系数(L2/L1 混合由 elasticNet 控制)
elasticNet 0.0 ElasticNet 参数:0=纯L2,1=纯L1
learningRate 0.1 学习率
globalBatchSize 32 全局 batch 大小
tol 1e-6 收敛阈值(迭代停止条件之一)

工程建议(很实用):

  • 特征尺度差异大:先 StandardScaler 再 LR,收敛更稳
  • 数据量大/噪声大:适当加 reg,防过拟合、也更稳定
  • 收敛慢:调大 maxIter 或调学习率(learningRate)但要谨慎

4. Java 示例逐段解读(fit + transform)

你给的示例是"训练后再对同一份数据做预测",方便验证效果。

4.1 构造输入表:features + label + weight

java 复制代码
DataStream<Row> inputStream =
        env.fromElements(
                Row.of(Vectors.dense(2, 1), 4.0, 1.0),
                Row.of(Vectors.dense(3, 2), 7.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(2, 4), 10.0, 1.0),
                Row.of(Vectors.dense(2, 2), 6.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(1, 2), 5.0, 1.0),
                Row.of(Vectors.dense(5, 3), 11.0, 1.0));
Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");
  • features 是二维向量
  • label 是回归目标(Double)
  • weight 全是 1.0(表示每个样本权重一样)

4.2 创建训练器并指定权重列

java 复制代码
LinearRegression lr = new LinearRegression().setWeightCol("weight");

如果你不需要样本权重,完全可以不设 weightCol

4.3 训练模型 + 预测

java 复制代码
LinearRegressionModel lrModel = lr.fit(inputTable);
Table outputTable = lrModel.transform(inputTable)[0];
  • fit():在 Table 上训练,得到模型
  • transform():输出表会多一个 prediction

4.4 读取输出:对比 label 与 prediction

java 复制代码
double expectedResult = (Double) row.getField(lr.getLabelCol());
double predictionResult = (Double) row.getField(lr.getPredictionCol());
System.out.printf("... Expected Result: %s \tPrediction Result: %s\n", expectedResult, predictionResult);

5. 实战用法:最常见的两条链路

链路 A:数值特征 → StandardScaler → LinearRegression

如果你的特征量纲差别大(金额、次数、时长混在一起),强烈推荐:

  • VectorAssembler(拼特征)
  • StandardScaler(标准化)
  • LinearRegression(训练)

链路 B:类别特征(StringIndexer + OneHot)+ 数值特征 → VectorAssembler → LinearRegression

当你的输入既有类别又有数值时:

  • StringIndexer:字符串类别 → index
  • OneHotEncoder:index → 稀疏向量
  • VectorAssembler:数值列 + 类别向量 拼成 features
  • LinearRegression:训练回归

6. 小结

Flink ML 的 LinearRegression 使用非常标准化:

  • 输入:features(Vector) + label(Double) + 可选 weight(Double)
  • 训练:lr.fit(table)
  • 预测:model.transform(table) 输出 prediction
相关推荐
菩提祖师_2 小时前
基于大数据背景下智能手机营销对策研究
大数据·智能手机·软件工程
武子康2 小时前
Java-218 RocketMQ Java API 实战:同步/异步 Producer 与 Pull/Push Consumer
java·大数据·分布式·消息队列·rocketmq·java-rocketmq·mq
Hello.Reader2 小时前
Flink ML StandardScaler 标准化(去均值 + 除以标准差)让特征“同量纲”更好学
机器学习·均值算法·flink
艾莉丝努力练剑2 小时前
艾莉丝努力练剑的2025年度总结
java·大数据·linux·开发语言·c++·人工智能·python
雨大王5122 小时前
智能体模型如何革新汽车制造?解析应用场景与典型案例
大数据·人工智能
拓端研究室4 小时前
2026年医药行业展望报告:创新、出海、AI医疗与商业化|附220+份报告PDF、数据、可视化模板汇总下载
大数据·人工智能
virtual_k1smet11 小时前
梧桐·鸿鹄- 大数据assistant-level
大数据·笔记
ggabb11 小时前
海南封关:锚定中国制造2025,破解产业转移生死局
大数据·人工智能
aigcapi15 小时前
[深度观察] RAG 架构重塑流量分发:2025 年 GEO 优化技术路径与头部服务商选型指南
大数据·人工智能·架构