Flink ML OneHotEncoder 把类别索引变成稀疏 one-hot 向量

1. OneHotEncoder 做什么?

One-hot 编码把一个类别索引(例如 2)映射成一个向量:

  • 类别集合大小为 N
  • 输出向量长度为 N(或 N-1,取决于 dropLast)
  • 只有对应类别的位置为 1,其余为 0
  • Flink ML 输出一般是 SparseVector(稀疏向量),更节省内存

举个例子:假设类别数 3(0/1/2)

  • 输入 0 → [1,0,0]
  • 输入 1 → [0,1,0]
  • 输入 2 → [0,0,1]

如果 dropLast=true(默认):

  • 只输出前 N-1 个维度,最后一个类别被"隐式表示"(全 0)
  • 输入 2(最后类)→ [0,0](长度为 2 的全 0 向量)

这叫"去掉冗余维度",常用于线性模型防止共线性(dummy variable trap)。

2. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 说明
inputCols Integer null 类别索引列(可多列)

输出列(Output Columns)

参数名 类型 默认值 说明
outputCols Vector null one-hot 后的向量列(可多列)

3. 参数(Parameters)详解

Key 默认值 必填 说明
inputCols null 输入列名数组
outputCols null 输出列名数组
handleInvalid ERROR_INVALID 遇到非法值如何处理:ERROR_INVALID/SKIP_INVALID
dropLast true 是否丢弃最后一类(输出维度减少 1)

3.1 dropLast 该怎么选?

  • dropLast=true(默认):
    更适合 LR 这类线性模型(避免多重共线性),也省一维
  • dropLast=false
    适合你希望保留完整可解释 one-hot(比如下游做统计、可视化、或者某些非线性模型也无所谓)

3.2 handleInvalid 的工程意义

  • ERROR_INVALID:脏数据直接失败(更严格)
  • SKIP_INVALID:跳过包含非法类别索引的行(更稳,线上更常见)

4. Java 示例代码解读(fit + transform)

你贴的示例非常标准:先用训练数据拟合出类别空间(决定 one-hot 维度),再对预测数据输出 one-hot 向量。

4.1 训练数据(决定类别数)

java 复制代码
DataStream<Row> trainStream =
        env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0));
Table trainTable = tEnv.fromDataStream(trainStream).as("input");

训练集出现了 0/1/2 三个类别,所以类别数 N=3。

若 dropLast=true,输出维度 = 2;若 dropLast=false,输出维度 = 3。

4.2 预测数据(做编码)

java 复制代码
DataStream<Row> predictStream = env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0));
Table predictTable = tEnv.fromDataStream(predictStream).as("input");

4.3 创建 OneHotEncoder 并训练/预测

java 复制代码
OneHotEncoder oneHotEncoder =
        new OneHotEncoder().setInputCols("input").setOutputCols("output");

OneHotEncoderModel model = oneHotEncoder.fit(trainTable);
Table outputTable = model.transform(predictTable)[0];

4.4 输出读取(SparseVector)

java 复制代码
Double inputValue = (Double) row.getField(oneHotEncoder.getInputCols()[0]);
SparseVector outputValue = (SparseVector) row.getField(oneHotEncoder.getOutputCols()[0]);
System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);

输出是 SparseVector,典型打印可能类似:

  • input=0 → (size=2, indices=[0], values=[1.0])
  • input=1 → (size=2, indices=[1], values=[1.0])
  • input=2 → (size=2, indices=[], values=[]) (当 dropLast=true 时)

5. 实战用法:OneHotEncoder 常见组合链路

组合 1:Bucketizer → OneHotEncoder → LogisticRegression

连续特征先分桶再 one-hot,是 CTR/风控等场景的经典套路:

  • Bucketizer 把连续数值映射到桶索引(离散)
  • OneHotEncoder 把桶索引变成稀疏向量
  • LR 用稀疏向量训练分类模型

组合 2:离散类别索引(LabelEncoder)→ OneHotEncoder

如果原始是字符串类别(比如 "Beijing"/"Shanghai"),通常流程是:

  • 先用编码器把字符串映射成 index(0/1/2/...)
  • 再 OneHotEncoder 变成向量

6. 注意事项(很容易踩坑)

1)训练集类别空间决定输出维度

预测时出现训练集中没见过的类别,如果 handleInvalid=ERROR_INVALID 可能会直接失败。线上更建议:

  • 提前统一类别字典
  • 或使用 SKIP_INVALID,并监控跳过比例

2)输入列类型要统一

文档说 inputCols 是 Integer,但示例用 Double(0.0/1.0/2.0)。

工程里建议你用 Integer/Long 更清晰,减少类型转换风险。

3)dropLast 的影响要明确

dropLast=true 时,最后一类会被编码成全 0 向量,这在调试/可解释性上容易误解(看起来像"缺失"),所以如果你更强调解释,可考虑 dropLast=false。

7. 小结

OneHotEncoder 是 Flink ML 特征工程里最常用的离散特征转换器之一:

  • 输入:一个或多个类别索引列
  • 输出:每列对应一个 one-hot 稀疏向量列
  • dropLast 控制是否减少一个维度(默认 true)
  • handleInvalid 控制脏数据处理策略(线上建议更稳的策略)
相关推荐
我想吃烤肉肉2 小时前
关于Python的垃圾回收
python
xcLeigh2 小时前
金融数据实时行情API使用教程:如何跨市场查询多品种的实时行情数据
python·websocket·金融·股票·api·期货·外汇
让学习成为一种生活方式2 小时前
如何根据过滤的pep序列进一步过滤gff3文件--python015
开发语言·人工智能·python
阿正的梦工坊2 小时前
WebArena:一个真实的网页环境,用于构建更强大的自主智能体
人工智能·深度学习·机器学习·大模型·llm
薛不痒2 小时前
机器学习算法之SVM
算法·机器学习·支持向量机
qijiabao41132 小时前
深度学习|可变形卷积DCNv3编译安装
人工智能·python·深度学习·机器学习·cuda
小鸡吃米…2 小时前
机器学习所需技能
人工智能·机器学习
m5655bj2 小时前
通过 Python 提取 PDF 表格数据
服务器·python·pdf
玄同7652 小时前
面向对象编程 vs 其他编程范式:LLM 开发该选哪种?
大数据·开发语言·前端·人工智能·python·自然语言处理·知识图谱