Flink ML StandardScaler 标准化(去均值 + 除以标准差)让特征“同量纲”更好学

1. StandardScaler 做什么?

对向量特征的每个维度 (x) 做标准化:

  • 先减去均值:(x - \mu)(可选)
  • 再除以标准差:((x - \mu) / \sigma)(可选)

在 Flink ML 里,通过两个开关控制:

  • withMean:是否减均值(默认 false)
  • withStd:是否除以标准差(默认 true)

2. 输入列与输出列

输入列(Input Columns)

参数名 类型 默认值 说明
inputCol Vector "input" 待标准化的特征向量

输出列(Output Columns)

参数名 类型 默认值 说明
outputCol Vector "output" 标准化后的向量

3. 参数详解(Parameters)

Key 默认值 类型 说明
inputCol "input" String 输入列名
outputCol "output" String 输出列名
withMean false Boolean 是否先减去均值(中心化)
withStd true Boolean 是否按标准差缩放到单位方差

withMean 什么时候开?

  • 你希望特征以 0 为中心、并且数据不是稀疏 one-hot/高维稀疏向量:可以开
  • 如果你的特征是稀疏向量(例如 OneHotEncoder 输出),一般不建议开(中心化会破坏稀疏性、带来不必要开销)

4. Java 示例解读(fit + transform)

标准用法永远是两步:

1)在训练数据上 fit() 学到每个维度的统计量(均值、方差/标准差)

2)用同一个 StandardScalerModel 对训练/预测数据 transform(),保证线上线下一致

你给的示例在同一份 inputTable 上 fit + transform,演示效果更直观。

java 复制代码
import org.apache.flink.ml.feature.standardscaler.StandardScaler;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class StandardScalerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(-2.5, 9, 1)),
                        Row.of(Vectors.dense(1.4, -5, 1)),
                        Row.of(Vectors.dense(2, -1, -2)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("input");

        StandardScaler standardScaler = new StandardScaler(); // 默认 withMean=false, withStd=true

        StandardScalerModel model = standardScaler.fit(inputTable);

        Table outputTable = model.transform(inputTable)[0];

        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue = (DenseVector) row.getField(standardScaler.getInputCol());
            DenseVector outputValue = (DenseVector) row.getField(standardScaler.getOutputCol());
            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

这段代码做了什么:

  • fit():扫描 inputTable,计算每个维度的标准差(以及如果 withMean=true 则计算均值)
  • transform():对每条向量逐维标准化,新增输出列 output

5. 实战建议

1)StandardScaler 比 MinMaxScaler 更抗异常值一点,但也不是"免疫"

StandardScaler用均值/方差,极端值仍会影响统计量,只是通常比 min/max 更稳定。

如果异常值特别多,建议先做截断/清洗,再标准化。

2)不要在预测数据上重新 fit

必须做到:

  • 训练阶段:fit(train)
  • 预测阶段:transform(predict) 用同一个 model
    否则线上每批数据的缩放尺度都变,模型输出会漂。

3)常见组合

  • VectorAssembler → StandardScaler → KMeans/KNN/LinearSVC/LogisticRegression
  • Bucketizer/OneHotEncoder 这种离散稀疏特征链路,一般不需要 withMean
相关推荐
Chef_Chen15 小时前
数据科学每日总结--Day44--机器学习
人工智能·机器学习
Master_oid18 小时前
机器学习29:增强式学习(Deep Reinforcement Learning)④
人工智能·学习·机器学习
ballball~~18 小时前
拉普拉斯金字塔
算法·机器学习
Cemtery11618 小时前
Day26 常见的降维算法
人工智能·python·算法·机器学习
weixin_4469340321 小时前
统计学中“in sample test”与“out of sample”有何区别?
人工智能·python·深度学习·机器学习·计算机视觉
wubba lubba dub dub7501 天前
第三十三周 学习周报
学习·算法·机器学习
猫天意1 天前
【深度学习小课堂】| torch | 升维打击还是原位拼接?深度解码 PyTorch 中 stack 与 cat 的几何奥义
开发语言·人工智能·pytorch·深度学习·神经网络·yolo·机器学习
cyyt1 天前
深度学习周报(1.12~1.18)
人工智能·算法·机器学习
独自破碎E1 天前
【回溯+剪枝】字符串的排列
算法·机器学习·剪枝
OJAC1111 天前
当DeepSeek V4遇见近屿智能:一场AI进化的叙事正在展开
人工智能·深度学习·机器学习