sparkml 多列共享labelEncoder

背景描述

比如两列 from城市 to城市

我们的需求是两侧同一个城市必须labelEncoder后编码相同.

代码

python 复制代码
from __future__ import annotations

from typing import Dict, Iterable, List, Optional

from pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml.feature import StringIndexer


class SharedLabelEncoder:
    """
    共享标签编码器:对多列使用同一套 label->index 映射。

    - handle_invalid: "keep"(未知值编码为未知索引)、"skip"(返回 None)、"error"(抛错)
    - unknown 索引默认等于 len(labels),仅在 handle_invalid="keep" 时使用。
    """

    def __init__(self, labels: Optional[List[str]] = None, handle_invalid: str = "keep"):
        self.labels: List[str] = labels or []
        self.label_to_index: Dict[str, int] = {v: i for i, v in enumerate(self.labels)}
        self.handle_invalid = handle_invalid

    def fit(self, df, cols: Iterable[str]) -> "SharedLabelEncoder":
        # 将多列堆叠为单列 value 后,用 StringIndexer 拟合一次,得到统一 labels
        stacked = None
        for c in cols:
            col_df = df.select(F.col(c).cast(T.StringType()).alias("value")).na.fill({"value": ""})
            stacked = col_df if stacked is None else stacked.unionByName(col_df)
        indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")
        model = indexer.fit(stacked)
        self.labels = list(model.labels)
        self.label_to_index = {v: i for i, v in enumerate(self.labels)}
        return self

    def _build_udf(self, spark: SparkSession):
        m_b = spark.sparkContext.broadcast(self.label_to_index)
        unknown_index = len(self.labels)

        def map_value(v: Optional[str]) -> Optional[int]:
            if v is None:
                return None if self.handle_invalid == "skip" else unknown_index if self.handle_invalid == "keep" else None
            idx = m_b.value.get(v)
            if idx is not None:
                return idx
            if self.handle_invalid == "keep":
                return unknown_index
            if self.handle_invalid == "skip":
                return None
            raise ValueError(f"未知标签: {v}")

        return F.udf(map_value, T.IntegerType())

    def transform(self, df, input_cols: Iterable[str], suffix: str = "_idx"):
        udf_map = self._build_udf(df.sparkSession)
        out = df
        for c in input_cols:
            out = out.withColumn(c + suffix, udf_map(F.col(c).cast(T.StringType())))
        return out

    def save(self, path: str):
        import json
        obj = {"labels": self.labels, "handle_invalid": self.handle_invalid}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False)

    @staticmethod
    def load(path: str) -> "SharedLabelEncoder":
        import json
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        return SharedLabelEncoder(labels=obj.get("labels", []), handle_invalid=obj.get("handle_invalid", "keep"))


def main():
    spark = SparkSession.builder.appName("shared_label_encoder").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")

    data = [
        (1, "北京", "上海", 1),
        (2, "上海", "北京", 0),
        (3, "广州", "深圳", 1),
        (4, "深圳", "广州", 0),
        (5, "北京", "广州", 1),
        (6, "上海", "深圳", 0),
    ]
    columns = ["id", "origin_city", "dest_city", "label"]
    df = spark.createDataFrame(data, schema=columns)

    # 拟合共享编码器(基于两列)
    encoder = SharedLabelEncoder(handle_invalid="keep").fit(df, ["origin_city", "dest_city"])

    # 变换两列到相同索引空间
    out_df = encoder.transform(df, ["origin_city", "dest_city"])
    print("编码结果:")
    out_df.show(truncate=False)

    # 保存/加载并复用
    path = "./shared_label_encoder_city.json"
    encoder.save(path)
    encoder2 = SharedLabelEncoder.load(path)

    new_df = spark.createDataFrame([(7, "北京", "杭州", 1)], schema=columns)  # 杭州为新值
    out_new = encoder2.transform(new_df, ["origin_city", "dest_city"])
    print("加载导出后的encoder并复用:")
    out_new.show(truncate=False)

    

main()

输出

复制代码
编码结果:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|1  |北京       |上海     |1    |1              |0            |
|2  |上海       |北京     |0    |0              |1            |
|3  |广州       |深圳     |1    |2              |3            |
|4  |深圳       |广州     |0    |3              |2            |
|5  |北京       |广州     |1    |1              |2            |
|6  |上海       |深圳     |0    |0              |3            |
+---+-----------+---------+-----+---------------+-------------+

加载导出后的encoder并复用:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|7  |北京       |杭州     |1    |1              |4            |
+---+-----------+---------+-----+---------------+-------------+
相关推荐
漂流瓶jz5 小时前
Webpack中各种devtool配置的含义与SourceMap生成逻辑
前端·javascript·webpack
这是个栗子5 小时前
【问题解决】用pnpm创建的 Vue3项目找不到 .eslintrc.js文件 及 后续的eslint配置的解决办法
javascript·vue.js·pnpm·eslint
zy happy6 小时前
RuoyiApp 在vuex,state存储nickname vue2
前端·javascript·小程序·uni-app·vue·ruoyi
Nan_Shu_6147 小时前
学习:JavaScript(5)
开发语言·javascript·学习
533_7 小时前
[vue3] h函数,阻止事件冒泡
javascript·vue.js·elementui
通往曙光的路上7 小时前
day22_用户授权 头像上传
javascript·vue.js·ecmascript
meichaoWen7 小时前
【Vue】Vue框架的基础知识强化
前端·javascript·vue.js
西西学代码7 小时前
Flutter---DragTarget(颜色拖拽选择器)
前端·javascript·flutter
阿蓝灬8 小时前
React中的stopPropagation和preventDefault
前端·javascript·react.js
天天向上10248 小时前
vue3 抽取el-dialog子组件
前端·javascript·vue.js