1. 输入列与输出列
输入列(Input Columns)
| 参数名 | 类型 | 默认值 | 含义 |
|---|---|---|---|
featuresCol |
Vector | "features" |
特征向量列 |
labelCol |
Integer | "label" |
标签列(训练/评估用,也可用于对比) |
输出列(Output Columns)
| 参数名 | 类型 | 默认值 | 含义 |
|---|---|---|---|
predictionCol |
Integer | "prediction" |
预测标签列 |
说明:Flink ML 的 Vector 通常是 DenseVector 或 SparseVector(特征列必须是向量类型)。
2. 参数(Parameters)
KnnModel 需要的参数
| Key | 默认值 | 类型 | 必填 | 含义 |
|---|---|---|---|---|
k |
5 | Integer | 否 | 选择最近邻的数量 |
featuresCol |
"features" |
String | 否 | 特征列名 |
predictionCol |
"prediction" |
String | 否 | 输出预测列名 |
Knn 额外需要的参数
| Key | 默认值 | 类型 | 必填 | 含义 |
|---|---|---|---|---|
labelCol |
"label" |
String | 否 | 标签列名 |
3. Java 示例代码(原理 + 流程)
下面是你贴的示例逻辑做一个"工程化解读":
- 构造训练集
trainTable(features, label) - 构造待预测集
predictTable(features, label)(这里 label 是"期望值/对照值",不是必须列,但方便打印对比) knn.fit(trainTable)生成knnModelknnModel.transform(predictTable)输出结果表,新增prediction列- collect 输出并打印 features、expected、prediction
需要注意的一个坑:label 类型
文档说 labelCol 是 Integer ,但你贴的代码训练数据里是 1.0/2.0/3.0 这种 Double。示例里又用:
java
double expectedResult = (Double) row.getField(knn.getLabelCol());
double predictionResult = (Double) row.getField(knn.getPredictionCol());
这会让人误以为 prediction 也是 Double。为了更"规范且不踩坑",建议你在自己项目里统一 label 为 Integer(或至少保持 train/predict/输出一致)。
下面我给一个"更规范版本"的示例(仅改了 label 为 Integer,并显式设置列名,逻辑不变)。
更规范的示例(建议用这个)
java
import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
public class KnnExample {
public static void main(String[] args) throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
// 训练数据:features(Vector) + label(Integer)
DataStream<Row> trainStream =
env.fromElements(
Row.of(Vectors.dense(2.0, 3.0), 1),
Row.of(Vectors.dense(2.1, 3.1), 1),
Row.of(Vectors.dense(200.1, 300.1), 2),
Row.of(Vectors.dense(200.2, 300.2), 2),
Row.of(Vectors.dense(200.3, 300.3), 2),
Row.of(Vectors.dense(2.8, 3.2), 3),
Row.of(Vectors.dense(300.0, 3.2), 4),
Row.of(Vectors.dense(2.4, 3.2), 5),
Row.of(Vectors.dense(2.5, 3.2), 5)
);
Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");
// 待预测数据:这里保留 label 作为期望值/对照值(可选)
DataStream<Row> predictStream =
env.fromElements(
Row.of(Vectors.dense(4.0, 4.1), 5),
Row.of(Vectors.dense(300.0, 42.0), 2)
);
Table predictTable = tEnv.fromDataStream(predictStream).as("features", "label");
// 创建 Knn,并设置关键参数
Knn knn = new Knn()
.setK(4)
.setFeaturesCol("features")
.setLabelCol("label")
.setPredictionCol("prediction");
// 训练模型
KnnModel knnModel = knn.fit(trainTable);
// 预测
Table outputTable = knnModel.transform(predictTable)[0];
// 打印结果:features + expected(label) + prediction
for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
Row row = it.next();
DenseVector features = (DenseVector) row.getField(knn.getFeaturesCol());
Integer expected = (Integer) row.getField(knn.getLabelCol());
Integer prediction = (Integer) row.getField(knn.getPredictionCol());
System.out.printf(
"Features: %-15s \tExpected: %s \tPrediction: %s\n",
features, expected, prediction);
}
}
}
4. 输出结果怎么看?
输出表一般会包含:
- 原始输入列:
features、(可选)label - 新增输出列:
prediction
你打印时就能看到:
- Features:待预测样本的向量
- Expected:样本原本的标签(如果你在 predictTable 里带了 label)
- Prediction:KNN 预测出来的标签
5. 实战建议(很关键)
1)特征缩放非常重要
KNN 完全依赖"距离"。如果你的特征尺度差异很大(比如一个特征是 0~1,另一个是 0~10000),距离会被大尺度特征主导,结果很容易失真。常见做法是:
- 先用
StandardScaler标准化 - 或做 MinMax 归一化
2)k 的选择是个"偏差-方差"平衡
- k 小:更敏感,容易受噪声影响(方差大)
- k 大:更平滑,但可能把边界抹平(偏差大)
工程上建议从 3/5/7/9 这类奇数开始试(减少投票平局的概率)。
3)训练数据量大时的性能与资源
KNN 的预测开销通常和训练集规模相关(需要找近邻)。训练集很大时:
- 可能需要索引/近似近邻(ANN)思路(具体要看 Flink ML 当前实现能力)
- 或对数据做分桶、预聚类、抽样
- 或考虑换更适合大规模在线推理的模型(线性模型、树模型等)