Background
For example, we have two city columns, a from-city and a to-city. The requirement: after label encoding, the same city must receive the same code regardless of which of the two columns it appears in.
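A naive approach fits one StringIndexer per column, but each indexer orders labels by its own column's value frequencies, so the same city can end up with different codes on the two sides. A minimal sketch of the problem (assuming an existing SparkSession named spark):

from pyspark.ml.feature import StringIndexer

df = spark.createDataFrame(
    [("北京", "上海"), ("上海", "北京"), ("北京", "广州")],
    ["origin_city", "dest_city"])
for c in ["origin_city", "dest_city"]:
    # Each indexer sees only one column, so the label order differs per column:
    # here "北京" gets index 0 in origin_city_idx but 1 in dest_city_idx.
    df = StringIndexer(inputCol=c, outputCol=c + "_idx").fit(df).transform(df)
df.show()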
Code
"""
需求说明
- 两列城市字段(origin_city、dest_city)表达同一语义,需要共享一套 Label 编码映射。
- 使用 PySpark 框架实现,且编码器可复用,并可整合进 Spark ML Pipeline。
"""
from __future__ import annotations
from typing import List, Optional
from pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
class SharedStringIndexer(Estimator, DefaultParamsReadable, DefaultParamsWritable):
    """An Estimator that fits one set of StringIndexer labels across several columns
    and produces a Model that encodes all of those columns with the shared mapping.

    Params
    - inputCols: List[str], the input columns that share one mapping
    - outputCols: List[str], the output column names (same length as inputCols)
    - handleInvalid: keep/skip/error (same semantics as StringIndexer)
    """
inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)
outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)
    def __init__(self, inputCols: Optional[List[str]] = None, outputCols: Optional[List[str]] = None, handleInvalid: str = "keep"):
        # Defaulting the arguments keeps the class no-arg constructible,
        # which DefaultParamsReadable requires when deserializing.
        super().__init__()
        if inputCols is not None and outputCols is not None:
            if len(inputCols) != len(outputCols):
                raise ValueError("inputCols and outputCols must have the same length")
            self._set(inputCols=inputCols, outputCols=outputCols)
        self._set(handleInvalid=handleInvalid)
    def _fit(self, dataset):
        # Stack all input columns into a single "value" column, then fit one
        # StringIndexer so every column shares the same label list.
        stacked = None
        for c in self.getOrDefault(self.inputCols):
            # Drop nulls at fit time; at transform time nulls fall under handleInvalid.
            col_df = dataset.select(F.col(c).cast(T.StringType()).alias("value")).na.drop()
            stacked = col_df if stacked is None else stacked.unionByName(col_df)
        indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")
        model = indexer.fit(stacked)
        labels = list(model.labels)
        return SharedStringIndexerModel().setParams(
            inputCols=self.getOrDefault(self.inputCols),
            outputCols=self.getOrDefault(self.outputCols),
            handleInvalid=self.getOrDefault(self.handleInvalid),
            labels=labels,
        )
class SharedStringIndexerModel(Model, DefaultParamsReadable, DefaultParamsWritable):
    """Transformer that applies the fitted shared label list to several columns,
    producing the same index for the same value in every column.

    labels is stored as a Param so that PipelineModel.save/load can serialize it.
    """
inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)
outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)
labels = Param(Params._dummy(), "labels", "shared label list", typeConverter=TypeConverters.toListString)
    def __init__(self):
        # Must be constructible with no arguments so that deserialization works.
        super().__init__()
def setParams(self, **kwargs):
self._set(**kwargs)
return self
    def _transform(self, dataset):
        label_list = self.getOrDefault(self.labels) or []
        mapping = {v: i for i, v in enumerate(label_list)}
        unknown_index = len(label_list)  # "keep" reserves one extra index for unseen values
        handle_invalid = self.getOrDefault(self.handleInvalid)
        # Broadcast the mapping once instead of shipping a copy inside every task.
        bmap = dataset.sparkSession.sparkContext.broadcast(mapping)

        def map_value(v: Optional[str]) -> Optional[int]:
            idx = bmap.value.get(v) if v is not None else None
            if idx is not None:
                return idx
            # v is null or was never seen during fit
            if handle_invalid == "keep":
                return unknown_index
            if handle_invalid == "skip":
                return None  # note: yields null rather than dropping the row, unlike StringIndexer
            raise ValueError(f"unseen label: {v}")

        enc = F.udf(map_value, T.IntegerType())
        out = dataset
        for src, dst in zip(self.getOrDefault(self.inputCols), self.getOrDefault(self.outputCols)):
            out = out.withColumn(dst, enc(F.col(src).cast(T.StringType())))
        return out
def main():
spark = SparkSession.builder.appName("shared_string_indexer_pipeline").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
data = [
(1, "北京", "上海", 1),
(2, "上海", "北京", 0),
(3, "广州", "深圳", 1),
(4, "深圳", "广州", 0),
(5, "北京", "广州", 1),
(6, "上海", "深圳", 0),
]
columns = ["id", "origin_city", "dest_city", "label"]
df = spark.createDataFrame(data, schema=columns)
df_test = spark.createDataFrame([
(1, "北京111", "上海", 1)], schema=columns)
shared_indexer = SharedStringIndexer(
inputCols=["origin_city", "dest_city"],
outputCols=["origin_city_idx", "dest_city_idx"],
handleInvalid="keep",
)
assembler = VectorAssembler(
inputCols=["origin_city_idx", "dest_city_idx"],
outputCol="features",
)
pipeline = Pipeline(stages=[shared_indexer, assembler])
model = pipeline.fit(df)
out = model.transform(df)
out.select("id", "origin_city", "dest_city", "origin_city_idx", "dest_city_idx", "features").show(truncate=False)
    # Reuse: save/load the whole PipelineModel (shared mapping included).
    # Alternatively, persist only the fitted shared indexer stage via model.stages[0].
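    # A sketch of that stage-only round trip (the path below is hypothetical):
    #   model.stages[0].write().overwrite().save("./shared_indexer_model_only")
    #   reloaded = SharedStringIndexerModel.load("./shared_indexer_model_only")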
model.write().overwrite().save("./shared_indexer_pipeline_model")
    print("Transform new data:")  # handleInvalid="keep", so the unseen city does not raise
model.transform(df_test).show(truncate=False)
print("加载导出后的模型 新数据转换")
loaded_model = PipelineModel.load("./shared_indexer_pipeline_model")
loaded_model.transform(df_test).show(truncate=False)
    # spark.stop()

if __name__ == "__main__":
    main()
Output
+---+-----------+---------+---------------+-------------+---------+
|id |origin_city|dest_city|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+---------------+-------------+---------+
|1 |北京 |上海 |1 |0 |[1.0,0.0]|
|2 |上海 |北京 |0 |1 |[0.0,1.0]|
|3 |广州 |深圳 |2 |3 |[2.0,3.0]|
|4 |深圳 |广州 |3 |2 |[3.0,2.0]|
|5 |北京 |广州 |1 |2 |[1.0,2.0]|
|6 |上海 |深圳 |0 |3 |[0.0,3.0]|
+---+-----------+---------+---------------+-------------+---------+
Transform new data:
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1 |北京111 |上海 |1 |4 |0 |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+
Transform new data with the reloaded model:
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1 |北京111 |上海 |1 |4 |0 |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+
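To double-check that the mapping really is shared, inspect the label list learned by the fitted stage. A minimal sketch against the model built in main() above (labels is the Param defined on SharedStringIndexerModel):

shared_model = model.stages[0]
print(shared_model.getOrDefault(shared_model.labels))
# Prints ['上海', '北京', '广州', '深圳'], matching the indices in the output above;
# the same index applies to origin_city and dest_city alike.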