Spark ML: Sharing a Label Encoder Across Multiple Columns in a Pipeline

Background

For example, take two columns: a "from" city and a "to" city.

The requirement is that a given city must receive the same encoded value after label encoding, no matter which of the two columns it appears in.
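To see why the straightforward route falls short: fitting a separate StringIndexer per column orders labels by per-column frequency, so the same city can end up with different indices in the two columns. A minimal sketch of the problem, assuming a DataFrame df with the two city columns (like the demo data below):

python

from pyspark.ml.feature import StringIndexer

# One independent indexer per column: each builds its own frequency-ordered label list.
origin_model = StringIndexer(inputCol="origin_city", outputCol="origin_city_idx").fit(df)
dest_model = StringIndexer(inputCol="dest_city", outputCol="dest_city_idx").fit(df)

# The two label lists are generally ordered differently, so the same city can receive
# different indices in the two columns -- exactly the inconsistency we need to avoid.
print(origin_model.labels)
print(dest_model.labels)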

Code

python
"""
需求说明
- 两列城市字段(origin_city、dest_city)表达同一语义,需要共享一套 Label 编码映射。
- 使用 PySpark 框架实现,且编码器可复用,并可整合进 Spark ML Pipeline。
"""

from __future__ import annotations

from typing import Dict, Iterable, List, Optional, Tuple

from pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable


class SharedStringIndexer(Estimator, DefaultParamsReadable, DefaultParamsWritable):
    """一个 Estimator:基于多列拟合一次 StringIndexer 的 labels,并产出可对多列统一编码的 Model。

    Params
    - inputCols: List[str] 要共享映射的输入列
    - outputCols: List[str] 对应输出列名(与 inputCols 等长)
    - handleInvalid: keep/skip/error(含义同 StringIndexer)
    """

    inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)
    outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)
    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)

    def __init__(self, inputCols: List[str], outputCols: List[str], handleInvalid: str = "keep"):
        super().__init__()
        if len(inputCols) != len(outputCols):
            raise ValueError("inputCols and outputCols must have the same length")
        self._set(inputCols=inputCols, outputCols=outputCols, handleInvalid=handleInvalid)

    def _fit(self, dataset):
        # Stack all input columns into a single "value" column (nulls become "" here), then fit StringIndexer once to obtain the shared label list
        stacked = None
        for c in self.getOrDefault(self.inputCols):
            col_df = dataset.select(F.col(c).cast(T.StringType()).alias("value")).na.fill({"value": ""})
            stacked = col_df if stacked is None else stacked.unionByName(col_df)
        indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")
        model = indexer.fit(stacked)
        labels = list(model.labels)
        return SharedStringIndexerModel().setParams(
            inputCols=self.getOrDefault(self.inputCols),
            outputCols=self.getOrDefault(self.outputCols),
            handleInvalid=self.getOrDefault(self.handleInvalid),
            labels=labels,
        )


class SharedStringIndexerModel(Model, DefaultParamsReadable, DefaultParamsWritable):
    """Transformer:将拟合得到的 labels 作为共享映射,对多列输出统一索引。

    为了能够被 PipelineModel.save/load 序列化,labels 作为一个 Param 保存。
    """

    inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)
    outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)
    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)
    labels = Param(Params._dummy(), "labels", "shared label list", typeConverter=TypeConverters.toListString)

    def __init__(self):
        # Must be constructible without arguments to support deserialization
        super().__init__()

    def setParams(self, **kwargs):
        self._set(**kwargs)
        return self

    def _transform(self, dataset):
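        # Broadcast the shared mapping to the executors; unseen values map to index
        # len(labels), matching StringIndexer's handleInvalid="keep" convention of an extra index.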
        label_list = self.getOrDefault(self.labels) or []
        mapping = {v: i for i, v in enumerate(label_list)}
        unknown_index = len(label_list)
        handle_invalid = self.getOrDefault(self.handleInvalid)
        bmap = dataset.sparkSession.sparkContext.broadcast(mapping)

        def map_value(v: Optional[str]) -> Optional[int]:
            # Nulls are treated like unseen labels. Note that "skip" here yields null
            # instead of dropping the row, unlike StringIndexer's skip.
            idx = bmap.value.get(v) if v is not None else None
            if idx is not None:
                return idx
            if handle_invalid == "keep":
                return unknown_index
            if handle_invalid == "skip":
                return None
            raise ValueError(f"Unknown label: {v}")

        enc = F.udf(map_value, T.IntegerType())
        out = dataset
        for src, dst in zip(self.getOrDefault(self.inputCols), self.getOrDefault(self.outputCols)):
            out = out.withColumn(dst, enc(F.col(src).cast(T.StringType())))
        return out


def main():
    spark = SparkSession.builder.appName("shared_string_indexer_pipeline").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")

    data = [
        (1, "北京", "上海", 1),
        (2, "上海", "北京", 0),
        (3, "广州", "深圳", 1),
        (4, "深圳", "广州", 0),
        (5, "北京", "广州", 1),
        (6, "上海", "深圳", 0),
    ]
    columns = ["id", "origin_city", "dest_city", "label"]
    df = spark.createDataFrame(data, schema=columns)
    df_test = spark.createDataFrame([
        (1, "北京111", "上海", 1)], schema=columns)

    shared_indexer = SharedStringIndexer(
        inputCols=["origin_city", "dest_city"],
        outputCols=["origin_city_idx", "dest_city_idx"],
        handleInvalid="keep",
    )

    assembler = VectorAssembler(
        inputCols=["origin_city_idx", "dest_city_idx"],
        outputCol="features",
    )

    pipeline = Pipeline(stages=[shared_indexer, assembler])
    model = pipeline.fit(df)
    out = model.transform(df)
    out.select("id", "origin_city", "dest_city", "origin_city_idx", "dest_city_idx", "features").show(truncate=False)

    # Reuse: save/load the whole PipelineModel (the shared mapping is included).
    # The fitted shared indexer can also be saved on its own via model.stages[0]; see the sketch after this listing.
    model.write().overwrite().save("./shared_indexer_pipeline_model")
    
    print("Transforming new data:")   # handleInvalid="keep", so the unseen city value does not raise an error
    model.transform(df_test).show(truncate=False)
    
    print("加载导出后的模型 新数据转换")
    loaded_model = PipelineModel.load("./shared_indexer_pipeline_model")
    loaded_model.transform(df_test).show(truncate=False)

    # spark.stop()


if __name__ == "__main__":
    main()
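As mentioned in the reuse comment above, the shared mapping can also be handled without the full pipeline. A minimal sketch, continuing from the model fitted in main() (the save path is just an example):

python

# The first stage of the fitted PipelineModel is the SharedStringIndexerModel.
shared_model = model.stages[0]

# The shared label list: the position of each city is its code; unseen values get
# index len(labels) because handleInvalid="keep".
print(shared_model.getOrDefault(shared_model.labels))

# Save and reload only this stage (DefaultParamsWritable/Readable make it serializable);
# the class must be importable when the model is loaded back.
shared_model.write().overwrite().save("./shared_indexer_model_only")
reloaded = SharedStringIndexerModel.load("./shared_indexer_model_only")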

Output

+---+-----------+---------+---------------+-------------+---------+
|id |origin_city|dest_city|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+---------------+-------------+---------+
|1  |北京       |上海     |1              |0            |[1.0,0.0]|
|2  |上海       |北京     |0              |1            |[0.0,1.0]|
|3  |广州       |深圳     |2              |3            |[2.0,3.0]|
|4  |深圳       |广州     |3              |2            |[3.0,2.0]|
|5  |北京       |广州     |1              |2            |[1.0,2.0]|
|6  |上海       |深圳     |0              |3            |[0.0,3.0]|
+---+-----------+---------+---------------+-------------+---------+

Transforming new data:
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1  |北京111    |上海     |1    |4              |0            |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+

Transforming new data with the reloaded model
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1  |北京111    |上海     |1    |4              |0            |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+
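For comparison, a similar effect is possible with built-in pieces only: fit one StringIndexer on the stacked column, then build one StringIndexerModel per column from the same label list with StringIndexerModel.from_labels. A rough sketch reusing the training df from above; what this route lacks is a single reusable Estimator stage, which is what the custom class adds:

python

from pyspark.ml.feature import StringIndexer, StringIndexerModel
from pyspark.sql import functions as F

# Stack both city columns and fit a single StringIndexer to obtain one shared label list.
stacked = df.select(F.col("origin_city").alias("value")).unionByName(
    df.select(F.col("dest_city").alias("value")))
shared_labels = StringIndexer(inputCol="value", outputCol="value_idx").fit(stacked).labels

# One StringIndexerModel per column, built from the same labels, so the mapping is identical.
origin_model = StringIndexerModel.from_labels(
    shared_labels, inputCol="origin_city", outputCol="origin_city_idx", handleInvalid="keep")
dest_model = StringIndexerModel.from_labels(
    shared_labels, inputCol="dest_city", outputCol="dest_city_idx", handleInvalid="keep")

encoded = dest_model.transform(origin_model.transform(df))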