Flink ML LinearRegression 用 Table API 训练线性回归并输出预测值

1. LinearRegression 做什么?

它学习一个线性函数(简化表达):

y\^=w⋅x+b\]\[ \\hat{y} = w \\cdot x + b \]\[y\^=w⋅x+b

其中:

  • (x)(x)(x)是特征向量(Vector)
  • (w)(w)(w) 是权重向量(模型参数)
  • (b)(b)(b) 是偏置
  • (y^)(\hat{y})(y^) 是预测值(prediction)

2. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 说明
featuresCol Vector "features" 特征向量
labelCol Integer(示例实际用 Double) "label" 要预测的连续标签
weightCol Double "weight" 样本权重(可选)

你贴的示例里 labelweight 都是 Double,这更符合回归场景。

输出列(Output Columns)

参数名 类型 默认值 说明
predictionCol Integer(实际为 Double) "prediction" 预测值

3. 参数详解(Parameters)

LinearRegressionModel(模型)常用参数:

Key 默认值 说明
featuresCol "features" 特征列名
predictionCol "prediction" 预测列名

LinearRegression(训练器)额外参数:

Key 默认值 说明
labelCol "label" 标签列名
weightCol null 权重列名(可选)
maxIter 20 最大迭代次数
reg 0.0 正则化系数(L2/L1 混合由 elasticNet 控制)
elasticNet 0.0 ElasticNet 参数:0=纯L2,1=纯L1
learningRate 0.1 学习率
globalBatchSize 32 全局 batch 大小
tol 1e-6 收敛阈值(迭代停止条件之一)

工程建议(很实用):

  • 特征尺度差异大:先 StandardScaler 再 LR,收敛更稳
  • 数据量大/噪声大:适当加 reg,防过拟合、也更稳定
  • 收敛慢:调大 maxIter 或调学习率(learningRate)但要谨慎

4. Java 示例逐段解读(fit + transform)

你给的示例是"训练后再对同一份数据做预测",方便验证效果。

4.1 构造输入表:features + label + weight

java 复制代码
DataStream<Row> inputStream =
        env.fromElements(
                Row.of(Vectors.dense(2, 1), 4.0, 1.0),
                Row.of(Vectors.dense(3, 2), 7.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(2, 4), 10.0, 1.0),
                Row.of(Vectors.dense(2, 2), 6.0, 1.0),
                Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                Row.of(Vectors.dense(1, 2), 5.0, 1.0),
                Row.of(Vectors.dense(5, 3), 11.0, 1.0));
Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");
  • features 是二维向量
  • label 是回归目标(Double)
  • weight 全是 1.0(表示每个样本权重一样)

4.2 创建训练器并指定权重列

java 复制代码
LinearRegression lr = new LinearRegression().setWeightCol("weight");

如果你不需要样本权重,完全可以不设 weightCol

4.3 训练模型 + 预测

java 复制代码
LinearRegressionModel lrModel = lr.fit(inputTable);
Table outputTable = lrModel.transform(inputTable)[0];
  • fit():在 Table 上训练,得到模型
  • transform():输出表会多一个 prediction

4.4 读取输出:对比 label 与 prediction

java 复制代码
double expectedResult = (Double) row.getField(lr.getLabelCol());
double predictionResult = (Double) row.getField(lr.getPredictionCol());
System.out.printf("... Expected Result: %s \tPrediction Result: %s\n", expectedResult, predictionResult);

5. 实战用法:最常见的两条链路

链路 A:数值特征 → StandardScaler → LinearRegression

如果你的特征量纲差别大(金额、次数、时长混在一起),强烈推荐:

  • VectorAssembler(拼特征)
  • StandardScaler(标准化)
  • LinearRegression(训练)

链路 B:类别特征(StringIndexer + OneHot)+ 数值特征 → VectorAssembler → LinearRegression

当你的输入既有类别又有数值时:

  • StringIndexer:字符串类别 → index
  • OneHotEncoder:index → 稀疏向量
  • VectorAssembler:数值列 + 类别向量 拼成 features
  • LinearRegression:训练回归

6. 小结

Flink ML 的 LinearRegression 使用非常标准化:

  • 输入:features(Vector) + label(Double) + 可选 weight(Double)
  • 训练:lr.fit(table)
  • 预测:model.transform(table) 输出 prediction
相关推荐
塔能物联运维20 小时前
隧道照明“智能进化”:PLC 通信 + AI 调光守护夜间通行生命线
大数据·人工智能
highly200920 小时前
Gitflow
大数据·elasticsearch·搜索引擎
humors22121 小时前
韩秀云老师谈买黄金
大数据·程序人生
重生之绝世牛码21 小时前
Linux软件安装 —— SSH免密登录
大数据·linux·运维·ssh·软件安装·免密登录
StarChainTech21 小时前
无人机租赁平台:开启智能租赁新时代
大数据·人工智能·微信小程序·小程序·无人机·软件需求
Hello.Reader21 小时前
Flink DynamoDB Connector 用 Streams 做 CDC,用 BatchWriteItem 高吞吐写回
大数据·python·flink
早日退休!!!21 小时前
内存泄露(Memory Leak)核心原理与工程实践报告
大数据·网络
发哥来了1 天前
主流AI视频生成工具商用化能力评测:五大关键维度对比分析
大数据·人工智能·音视频
無森~1 天前
MapReduce
大数据·mapreduce
重生之绝世牛码1 天前
Linux软件安装 —— zookeeper集群安装
大数据·linux·运维·服务器·zookeeper·软件安装