Flink ML KNN 入门基于 Table API 的近邻分类

1. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 含义
featuresCol Vector "features" 特征向量列
labelCol Integer "label" 标签列(训练/评估用,也可用于对比)

输出列(Output Columns)

参数名 类型 默认值 含义
predictionCol Integer "prediction" 预测标签列

说明:Flink ML 的 Vector 通常是 DenseVectorSparseVector(特征列必须是向量类型)。

2. 参数(Parameters)

KnnModel 需要的参数

Key 默认值 类型 必填 含义
k 5 Integer 选择最近邻的数量
featuresCol "features" String 特征列名
predictionCol "prediction" String 输出预测列名

Knn 额外需要的参数

Key 默认值 类型 必填 含义
labelCol "label" String 标签列名

3. Java 示例代码(原理 + 流程)

下面是你贴的示例逻辑做一个"工程化解读":

  1. 构造训练集 trainTable(features, label)
  2. 构造待预测集 predictTable(features, label)(这里 label 是"期望值/对照值",不是必须列,但方便打印对比)
  3. knn.fit(trainTable) 生成 knnModel
  4. knnModel.transform(predictTable) 输出结果表,新增 prediction
  5. collect 输出并打印 features、expected、prediction

需要注意的一个坑:label 类型

文档说 labelColInteger ,但你贴的代码训练数据里是 1.0/2.0/3.0 这种 Double。示例里又用:

java 复制代码
double expectedResult = (Double) row.getField(knn.getLabelCol());
double predictionResult = (Double) row.getField(knn.getPredictionCol());

这会让人误以为 prediction 也是 Double。为了更"规范且不踩坑",建议你在自己项目里统一 label 为 Integer(或至少保持 train/predict/输出一致)。

下面我给一个"更规范版本"的示例(仅改了 label 为 Integer,并显式设置列名,逻辑不变)。

更规范的示例(建议用这个)

java 复制代码
import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class KnnExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // 训练数据:features(Vector) + label(Integer)
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(2.0, 3.0), 1),
                        Row.of(Vectors.dense(2.1, 3.1), 1),
                        Row.of(Vectors.dense(200.1, 300.1), 2),
                        Row.of(Vectors.dense(200.2, 300.2), 2),
                        Row.of(Vectors.dense(200.3, 300.3), 2),
                        Row.of(Vectors.dense(2.8, 3.2), 3),
                        Row.of(Vectors.dense(300.0, 3.2), 4),
                        Row.of(Vectors.dense(2.4, 3.2), 5),
                        Row.of(Vectors.dense(2.5, 3.2), 5)
                );
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        // 待预测数据:这里保留 label 作为期望值/对照值(可选)
        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(4.0, 4.1), 5),
                        Row.of(Vectors.dense(300.0, 42.0), 2)
                );
        Table predictTable = tEnv.fromDataStream(predictStream).as("features", "label");

        // 创建 Knn,并设置关键参数
        Knn knn = new Knn()
                .setK(4)
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setPredictionCol("prediction");

        // 训练模型
        KnnModel knnModel = knn.fit(trainTable);

        // 预测
        Table outputTable = knnModel.transform(predictTable)[0];

        // 打印结果:features + expected(label) + prediction
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(knn.getFeaturesCol());
            Integer expected = (Integer) row.getField(knn.getLabelCol());
            Integer prediction = (Integer) row.getField(knn.getPredictionCol());

            System.out.printf(
                    "Features: %-15s \tExpected: %s \tPrediction: %s\n",
                    features, expected, prediction);
        }
    }
}

4. 输出结果怎么看?

输出表一般会包含:

  • 原始输入列:features、(可选)label
  • 新增输出列:prediction

你打印时就能看到:

  • Features:待预测样本的向量
  • Expected:样本原本的标签(如果你在 predictTable 里带了 label)
  • Prediction:KNN 预测出来的标签

5. 实战建议(很关键)

1)特征缩放非常重要

KNN 完全依赖"距离"。如果你的特征尺度差异很大(比如一个特征是 0~1,另一个是 0~10000),距离会被大尺度特征主导,结果很容易失真。常见做法是:

  • 先用 StandardScaler 标准化
  • 或做 MinMax 归一化

2)k 的选择是个"偏差-方差"平衡

  • k 小:更敏感,容易受噪声影响(方差大)
  • k 大:更平滑,但可能把边界抹平(偏差大)

工程上建议从 3/5/7/9 这类奇数开始试(减少投票平局的概率)。

3)训练数据量大时的性能与资源

KNN 的预测开销通常和训练集规模相关(需要找近邻)。训练集很大时:

  • 可能需要索引/近似近邻(ANN)思路(具体要看 Flink ML 当前实现能力)
  • 或对数据做分桶、预聚类、抽样
  • 或考虑换更适合大规模在线推理的模型(线性模型、树模型等)
相关推荐
FL16238631293 小时前
轴承表面缺陷检测数据集VOC+YOLO格式2064张8类别
人工智能·yolo·机器学习
Hello.Reader3 小时前
Flink ML 线性 SVM(Linear SVC)入门输入输出列、训练参数与 Java 示例解读
java·支持向量机·flink
小鸡吃米…4 小时前
机器学习——生态系统
人工智能·机器学习
行走的bug...5 小时前
利用计算机辅助数学运算
人工智能·算法·机器学习
晨光32115 小时前
Day43 训练和测试的规范写法
python·深度学习·机器学习
智算菩萨5 小时前
【Python机器学习】K-Means 聚类:数据分组与用户画像的完整技术指南
人工智能·python·机器学习
cyyt6 小时前
深度学习周报(12.22~12.28)
人工智能·算法·机器学习
智算菩萨6 小时前
【Python机器学习】回归模型评估指标深度解析:MAE、MSE、RMSE与R²的理论与实践
python·机器学习·回归
Cherry的跨界思维7 小时前
【AI测试全栈:认知升级】2、AI核心概念与全栈技术栈全景
人工智能·深度学习·机器学习·语言模型·ai测试·ai全栈·测试全栈