Flink ML 二分类评估器 BinaryClassificationEvaluator AUC、PR-AUC、KS 一次搞懂

1. 支持的输入形式

BinaryClassificationEvaluator 的输入表需要包含:

  • labelCol:真实标签(Number,默认列名 "label"
  • rawPredictionCol:原始预测(Vector 或 Number,默认列名 "rawPrediction"
  • weightCol:样本权重(Number,可选,默认 null)

rawPrediction 可以是两种类型

1)double

可以是:

  • 二分类硬预测(0/1)
  • 或者 label=1 的概率(更推荐用概率)

2)vector(长度为 2)

表示两个类别的 raw predictions / scores / probabilities,例如:

  • [score0, score1]
  • [p(label=0), p(label=1)]

工程建议:如果你的模型能输出概率向量(比如 [p0, p1]),尽量用 vector 输入,这样更通用也更符合评估器预期。

2. 输出指标说明

评估器输出表可能包含以下列(取决于你配置了哪些 metricsNames):

  • areaUnderROC:ROC 曲线下面积(ROC-AUC)
  • areaUnderPR:PR 曲线下面积(PR-AUC)
  • areaUnderLorenz:与 Lorenz 曲线相关的度量(文档描述里与 KS 有关联)
  • ks:与 KS / Lorenz 相关的度量(文档描述里与 areaUnderLorenz 有对应关系)

小提示:你贴的描述里 areaUnderLorenzks 的解释看起来有点"对调"的味道(一个写 KS,一个写 Lorenz 的面积)。在工程实践里你可以把它们都算出来,再结合你对 KS/Lorenz 的理解确认哪个列对应哪个含义,避免拿错指标去做阈值或准入判断。

指标怎么选

  • 业务更关注整体排序能力、对阈值不敏感:优先看 ROC-AUC
  • 正负样本极度不均衡、关注正类识别质量:优先看 PR-AUC
  • 风控、评分卡、强分离能力诉求:常用 KS(越大通常分离越强)

3. 关键参数(Parameters)

Key 默认值 含义
labelCol "label" 标签列名
weightCol null 权重列名
rawPredictionCol "rawPrediction" 原始预测列名
metricsNames [AREA_UNDER_ROC, AREA_UNDER_PR] 需要输出的指标列表

你可以通过 setMetricsNames(...) 指定要算哪些指标。

4. Java 示例:计算 PR-AUC、ROC-AUC、KS

你贴的示例是最典型的用法:输入 label + rawPrediction(vector),输出 1 行指标结果。

java 复制代码
import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluator;
import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluatorParams;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

public class BinaryClassificationEvaluatorExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(1.0, Vectors.dense(0.1, 0.9)),
                        Row.of(1.0, Vectors.dense(0.2, 0.8)),
                        Row.of(1.0, Vectors.dense(0.3, 0.7)),
                        Row.of(0.0, Vectors.dense(0.25, 0.75)),
                        Row.of(0.0, Vectors.dense(0.4, 0.6)),
                        Row.of(1.0, Vectors.dense(0.35, 0.65)),
                        Row.of(1.0, Vectors.dense(0.45, 0.55)),
                        Row.of(0.0, Vectors.dense(0.6, 0.4)),
                        Row.of(0.0, Vectors.dense(0.7, 0.3)),
                        Row.of(1.0, Vectors.dense(0.65, 0.35)),
                        Row.of(0.0, Vectors.dense(0.8, 0.2)),
                        Row.of(1.0, Vectors.dense(0.9, 0.1)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("label", "rawPrediction");

        BinaryClassificationEvaluator evaluator =
                new BinaryClassificationEvaluator()
                        .setMetricsNames(
                                BinaryClassificationEvaluatorParams.AREA_UNDER_PR,
                                BinaryClassificationEvaluatorParams.KS,
                                BinaryClassificationEvaluatorParams.AREA_UNDER_ROC);

        Table outputTable = evaluator.transform(inputTable)[0];

        Row evaluationResult = outputTable.execute().collect().next();
        System.out.printf(
                "Area under PR: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.AREA_UNDER_PR));
        System.out.printf(
                "Area under ROC: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC));
        System.out.printf(
                "KS: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.KS));
    }
}

输出表是什么样?

evaluator.transform(inputTable)[0] 返回的 Table 通常只有一行,列名就是你请求的指标名,例如:

  • areaUnderPR
  • areaUnderROC
  • ks
  • areaUnderLorenz(如果你也加上)

你通过 evaluationResult.getField("areaUnderROC") 或常量 key 取值即可。

5. 带权重 weightCol 的用法(更贴近生产)

当你遇到样本不均衡、或者希望某些样本更重要时,可以加 weightCol

  • 输入表:("label", "rawPrediction", "weight")
  • evaluator:.setWeightCol("weight")

示意(结构说明):

java 复制代码
Table inputTable = tEnv.fromDataStream(stream).as("label", "rawPrediction", "weight");

BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
    .setWeightCol("weight")
    .setMetricsNames(
        BinaryClassificationEvaluatorParams.AREA_UNDER_ROC,
        BinaryClassificationEvaluatorParams.AREA_UNDER_PR);

6. 实战注意点

1)rawPrediction 最好用"概率"而不是 0/1

AUC/PR-AUC/KS 本质上依赖排序与阈值扫描。你只给 0/1 会极大损失信息量,指标会变得不稳定或不敏感。

2)label 的取值要明确

通常用 0/1,且要确认"正类"是哪一类(一般 label=1 作为正类)。如果你用 -1/1 或其他编码,建议在进入评估器前统一映射到 0/1。

3)rawPrediction 向量的含义要一致

如果是 [p0, p1],你要确保第二个位置确实对应 label=1 的概率或得分,避免把顺序写反。

4)指标用于"模型比较"比"绝对阈值"更靠谱

尤其是不同数据分布、不同采样策略下,绝对值会波动更大。工程上常见做法是:同分布下对比多模型或多版本的提升/下降趋势。

相关推荐
无人装备硬件开发爱好者2 小时前
AI 辅助程序设计的趋势与范式转移:编码、审核、测试全流程深度解析
大数据·人工智能·架构·核心竞争力重构
Hello.Reader2 小时前
Flink ML K-Means 离线聚类 + 在线增量聚类(mini-batch + decayFactor)
大数据·分类·flink
草莓熊Lotso2 小时前
技术深耕,破局成长:我的2025年度技术创作之路
大数据·开发语言·c++·人工智能·年度总结
Gofarlic_OMS2 小时前
通过MathWorks API实现许可证管理自动化
大数据·数据库·人工智能·adobe·金融·自动化·区块链
星川皆无恙2 小时前
从“盲人摸象“到“全面感知“:多模态学习的进化之路
大数据·人工智能·python·深度学习·学习
艾莉丝努力练剑2 小时前
【Linux进程(六)】程序地址空间深度实证:从内存布局验证到虚拟化理解的基石
大数据·linux·运维·服务器·人工智能·windows·centos
yangmf20402 小时前
INFINI Gateway 助力联想集团 ES 迁移升级
大数据·数据库·elasticsearch·搜索引擎·gateway·全文检索
CLTHREE3 小时前
GitHub Fork到PR全流程操作指南
大数据·elasticsearch·搜索引擎
是阿威啊3 小时前
【maap-analysis】spark离线数仓项目完整的开发流程
大数据·分布式·spark·scala