1. OneHotEncoder 做什么?
One-hot 编码把一个类别索引(例如 2)映射成一个向量:
- 类别集合大小为 N
- 输出向量长度为 N(或 N-1,取决于 dropLast)
- 只有对应类别的位置为 1,其余为 0
- Flink ML 输出一般是 SparseVector(稀疏向量),更节省内存
举个例子:假设类别数 3(0/1/2)
- 输入 0 → [1,0,0]
- 输入 1 → [0,1,0]
- 输入 2 → [0,0,1]
如果 dropLast=true(默认):
- 只输出前 N-1 个维度,最后一个类别被"隐式表示"(全 0)
- 输入 2(最后类)→ [0,0](长度为 2 的全 0 向量)
这叫"去掉冗余维度",常用于线性模型防止共线性(dummy variable trap)。
2. 输入列与输出列
输入列(Input Columns)
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
inputCols |
Integer | null | 类别索引列(可多列) |
输出列(Output Columns)
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
outputCols |
Vector | null | one-hot 后的向量列(可多列) |
3. 参数(Parameters)详解
| Key | 默认值 | 必填 | 说明 |
|---|---|---|---|
inputCols |
null | ✅ | 输入列名数组 |
outputCols |
null | ✅ | 输出列名数组 |
handleInvalid |
ERROR_INVALID |
否 | 遇到非法值如何处理:ERROR_INVALID/SKIP_INVALID |
dropLast |
true | 否 | 是否丢弃最后一类(输出维度减少 1) |
3.1 dropLast 该怎么选?
dropLast=true(默认):
更适合 LR 这类线性模型(避免多重共线性),也省一维dropLast=false:
适合你希望保留完整可解释 one-hot(比如下游做统计、可视化、或者某些非线性模型也无所谓)
3.2 handleInvalid 的工程意义
ERROR_INVALID:脏数据直接失败(更严格)SKIP_INVALID:跳过包含非法类别索引的行(更稳,线上更常见)
4. Java 示例代码解读(fit + transform)
你贴的示例非常标准:先用训练数据拟合出类别空间(决定 one-hot 维度),再对预测数据输出 one-hot 向量。
4.1 训练数据(决定类别数)
java
DataStream<Row> trainStream =
env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0));
Table trainTable = tEnv.fromDataStream(trainStream).as("input");
训练集出现了 0/1/2 三个类别,所以类别数 N=3。
若 dropLast=true,输出维度 = 2;若 dropLast=false,输出维度 = 3。
4.2 预测数据(做编码)
java
DataStream<Row> predictStream = env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0));
Table predictTable = tEnv.fromDataStream(predictStream).as("input");
4.3 创建 OneHotEncoder 并训练/预测
java
OneHotEncoder oneHotEncoder =
new OneHotEncoder().setInputCols("input").setOutputCols("output");
OneHotEncoderModel model = oneHotEncoder.fit(trainTable);
Table outputTable = model.transform(predictTable)[0];
4.4 输出读取(SparseVector)
java
Double inputValue = (Double) row.getField(oneHotEncoder.getInputCols()[0]);
SparseVector outputValue = (SparseVector) row.getField(oneHotEncoder.getOutputCols()[0]);
System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
输出是 SparseVector,典型打印可能类似:
- input=0 → (size=2, indices=[0], values=[1.0])
- input=1 → (size=2, indices=[1], values=[1.0])
- input=2 → (size=2, indices=[], values=[]) (当 dropLast=true 时)
5. 实战用法:OneHotEncoder 常见组合链路
组合 1:Bucketizer → OneHotEncoder → LogisticRegression
连续特征先分桶再 one-hot,是 CTR/风控等场景的经典套路:
- Bucketizer 把连续数值映射到桶索引(离散)
- OneHotEncoder 把桶索引变成稀疏向量
- LR 用稀疏向量训练分类模型
组合 2:离散类别索引(LabelEncoder)→ OneHotEncoder
如果原始是字符串类别(比如 "Beijing"/"Shanghai"),通常流程是:
- 先用编码器把字符串映射成 index(0/1/2/...)
- 再 OneHotEncoder 变成向量
6. 注意事项(很容易踩坑)
1)训练集类别空间决定输出维度
预测时出现训练集中没见过的类别,如果 handleInvalid=ERROR_INVALID 可能会直接失败。线上更建议:
- 提前统一类别字典
- 或使用 SKIP_INVALID,并监控跳过比例
2)输入列类型要统一
文档说 inputCols 是 Integer,但示例用 Double(0.0/1.0/2.0)。
工程里建议你用 Integer/Long 更清晰,减少类型转换风险。
3)dropLast 的影响要明确
dropLast=true 时,最后一类会被编码成全 0 向量,这在调试/可解释性上容易误解(看起来像"缺失"),所以如果你更强调解释,可考虑 dropLast=false。
7. 小结
OneHotEncoder 是 Flink ML 特征工程里最常用的离散特征转换器之一:
- 输入:一个或多个类别索引列
- 输出:每列对应一个 one-hot 稀疏向量列
dropLast控制是否减少一个维度(默认 true)handleInvalid控制脏数据处理策略(线上建议更稳的策略)