Flink ML KNN 入门基于 Table API 的近邻分类

1. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 含义
featuresCol Vector "features" 特征向量列
labelCol Integer "label" 标签列(训练/评估用,也可用于对比)

输出列(Output Columns)

参数名 类型 默认值 含义
predictionCol Integer "prediction" 预测标签列

说明:Flink ML 的 Vector 通常是 DenseVectorSparseVector(特征列必须是向量类型)。

2. 参数(Parameters)

KnnModel 需要的参数

Key 默认值 类型 必填 含义
k 5 Integer 选择最近邻的数量
featuresCol "features" String 特征列名
predictionCol "prediction" String 输出预测列名

Knn 额外需要的参数

Key 默认值 类型 必填 含义
labelCol "label" String 标签列名

3. Java 示例代码(原理 + 流程)

下面是你贴的示例逻辑做一个"工程化解读":

  1. 构造训练集 trainTable(features, label)
  2. 构造待预测集 predictTable(features, label)(这里 label 是"期望值/对照值",不是必须列,但方便打印对比)
  3. knn.fit(trainTable) 生成 knnModel
  4. knnModel.transform(predictTable) 输出结果表,新增 prediction
  5. collect 输出并打印 features、expected、prediction

需要注意的一个坑:label 类型

文档说 labelColInteger ,但你贴的代码训练数据里是 1.0/2.0/3.0 这种 Double。示例里又用:

java 复制代码
double expectedResult = (Double) row.getField(knn.getLabelCol());
double predictionResult = (Double) row.getField(knn.getPredictionCol());

这会让人误以为 prediction 也是 Double。为了更"规范且不踩坑",建议你在自己项目里统一 label 为 Integer(或至少保持 train/predict/输出一致)。

下面我给一个"更规范版本"的示例(仅改了 label 为 Integer,并显式设置列名,逻辑不变)。

更规范的示例(建议用这个)

java 复制代码
import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class KnnExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // 训练数据:features(Vector) + label(Integer)
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(2.0, 3.0), 1),
                        Row.of(Vectors.dense(2.1, 3.1), 1),
                        Row.of(Vectors.dense(200.1, 300.1), 2),
                        Row.of(Vectors.dense(200.2, 300.2), 2),
                        Row.of(Vectors.dense(200.3, 300.3), 2),
                        Row.of(Vectors.dense(2.8, 3.2), 3),
                        Row.of(Vectors.dense(300.0, 3.2), 4),
                        Row.of(Vectors.dense(2.4, 3.2), 5),
                        Row.of(Vectors.dense(2.5, 3.2), 5)
                );
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        // 待预测数据:这里保留 label 作为期望值/对照值(可选)
        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(4.0, 4.1), 5),
                        Row.of(Vectors.dense(300.0, 42.0), 2)
                );
        Table predictTable = tEnv.fromDataStream(predictStream).as("features", "label");

        // 创建 Knn,并设置关键参数
        Knn knn = new Knn()
                .setK(4)
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setPredictionCol("prediction");

        // 训练模型
        KnnModel knnModel = knn.fit(trainTable);

        // 预测
        Table outputTable = knnModel.transform(predictTable)[0];

        // 打印结果:features + expected(label) + prediction
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(knn.getFeaturesCol());
            Integer expected = (Integer) row.getField(knn.getLabelCol());
            Integer prediction = (Integer) row.getField(knn.getPredictionCol());

            System.out.printf(
                    "Features: %-15s \tExpected: %s \tPrediction: %s\n",
                    features, expected, prediction);
        }
    }
}

4. 输出结果怎么看?

输出表一般会包含:

  • 原始输入列:features、(可选)label
  • 新增输出列:prediction

你打印时就能看到:

  • Features:待预测样本的向量
  • Expected:样本原本的标签(如果你在 predictTable 里带了 label)
  • Prediction:KNN 预测出来的标签

5. 实战建议(很关键)

1)特征缩放非常重要

KNN 完全依赖"距离"。如果你的特征尺度差异很大(比如一个特征是 0~1,另一个是 0~10000),距离会被大尺度特征主导,结果很容易失真。常见做法是:

  • 先用 StandardScaler 标准化
  • 或做 MinMax 归一化

2)k 的选择是个"偏差-方差"平衡

  • k 小:更敏感,容易受噪声影响(方差大)
  • k 大:更平滑,但可能把边界抹平(偏差大)

工程上建议从 3/5/7/9 这类奇数开始试(减少投票平局的概率)。

3)训练数据量大时的性能与资源

KNN 的预测开销通常和训练集规模相关(需要找近邻)。训练集很大时:

  • 可能需要索引/近似近邻(ANN)思路(具体要看 Flink ML 当前实现能力)
  • 或对数据做分桶、预聚类、抽样
  • 或考虑换更适合大规模在线推理的模型(线性模型、树模型等)
相关推荐
小饼干超人4 小时前
详解向量数据库中的PQ算法(Product Quantization)
人工智能·算法·机器学习
少林码僧5 小时前
2.30 传统行业预测神器:为什么GBDT系列算法在企业中最受欢迎
开发语言·人工智能·算法·机器学习·ai·数据分析
zm-v-159304339865 小时前
最新AI-Python自然科学领域机器学习与深度学习技术
人工智能·python·机器学习
郝学胜-神的一滴5 小时前
何友院士《人工智能发展前沿》全景解读:从理论基石到产业变革
人工智能·python·深度学习·算法·机器学习
BHXDML5 小时前
第五章:支持向量机
算法·机器学习·支持向量机
wfeqhfxz25887826 小时前
基于YOLO12-A2C2f-DFFN-DYT-Mona的铁件部件状态识别与分类系统_1
人工智能·分类·数据挖掘
2501_941507946 小时前
脊柱结构异常检测与分类:基于Cascade-RCNN和HRNetV2p-W32模型的改进方案
人工智能·分类·数据挖掘
数据与后端架构提升之路6 小时前
实战:手搓一个“BEV 级”自动驾驶训练加速平台 —— 当 RTX 4090 遇上多模态数据
人工智能·机器学习·自动驾驶
HyperAI超神经6 小时前
【vLLM 学习】Rlhf Utils
人工智能·深度学习·学习·机器学习·ai编程·vllm
2501_941837267 小时前
基于YOLOv8的19种鱼类目标检测与分类系统——鱼类市场物种识别研究
yolo·目标检测·分类