Background
Say we have two city columns, a from-city and a to-city. Our requirement is that the same city must receive the same index after label encoding, regardless of which column it appears in.
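For intuition, the single-machine version of this requirement is usually solved by fitting one encoder on the union of both columns. A minimal sketch using scikit-learn's LabelEncoder (the column names here are illustrative, not from the Spark code below):

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

df = pd.DataFrame({"from_city": ["北京", "上海"], "to_city": ["上海", "广州"]})

# Fit once on the union of both columns so both share one index space.
le = LabelEncoder().fit(pd.concat([df["from_city"], df["to_city"]]))
df["from_city_idx"] = le.transform(df["from_city"])
df["to_city_idx"] = le.transform(df["to_city"])
```

The Spark version below follows the same idea: it stacks the columns with unionByName, fits a single StringIndexer, and persists the mapping so it can be reused at inference time.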
Code
```python
from __future__ import annotations

import json
from typing import Dict, Iterable, List, Optional

from pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml.feature import StringIndexer


class SharedLabelEncoder:
    """
    Shared label encoder: one label->index mapping applied to multiple columns.

    - handle_invalid: "keep" (unknown values map to the unknown index),
      "skip" (unknown values map to None), "error" (raise on unknown values).
    - The unknown index equals len(labels) and is only used when
      handle_invalid="keep".
    """

    def __init__(self, labels: Optional[List[str]] = None, handle_invalid: str = "keep"):
        self.labels: List[str] = labels or []
        self.label_to_index: Dict[str, int] = {v: i for i, v in enumerate(self.labels)}
        self.handle_invalid = handle_invalid

    def fit(self, df, cols: Iterable[str]) -> "SharedLabelEncoder":
        # Stack all input columns into one "value" column, then fit a single
        # StringIndexer on the union so every column shares the same labels.
        stacked = None
        for c in cols:
            # Drop nulls here; transform() handles nulls on its own.
            col_df = df.select(F.col(c).cast(T.StringType()).alias("value")).na.drop()
            stacked = col_df if stacked is None else stacked.unionByName(col_df)
        indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")
        model = indexer.fit(stacked)
        self.labels = list(model.labels)
        self.label_to_index = {v: i for i, v in enumerate(self.labels)}
        return self

    def _build_udf(self, spark: SparkSession):
        # Broadcast the mapping and close over plain locals so the UDF does
        # not have to serialize the whole encoder object.
        mapping = spark.sparkContext.broadcast(self.label_to_index)
        handle_invalid = self.handle_invalid
        unknown_index = len(self.labels)

        def map_value(v: Optional[str]) -> Optional[int]:
            if v is None:
                # Nulls are treated like unknowns under "keep", else stay null.
                return unknown_index if handle_invalid == "keep" else None
            idx = mapping.value.get(v)
            if idx is not None:
                return idx
            if handle_invalid == "keep":
                return unknown_index
            if handle_invalid == "skip":
                return None
            raise ValueError(f"unknown label: {v}")

        return F.udf(map_value, T.IntegerType())

    def transform(self, df, input_cols: Iterable[str], suffix: str = "_idx"):
        udf_map = self._build_udf(df.sparkSession)
        out = df
        for c in input_cols:
            out = out.withColumn(c + suffix, udf_map(F.col(c).cast(T.StringType())))
        return out

    def save(self, path: str):
        obj = {"labels": self.labels, "handle_invalid": self.handle_invalid}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False)

    @staticmethod
    def load(path: str) -> "SharedLabelEncoder":
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        return SharedLabelEncoder(labels=obj.get("labels", []),
                                  handle_invalid=obj.get("handle_invalid", "keep"))


def main():
    spark = SparkSession.builder.appName("shared_label_encoder").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")

    data = [
        (1, "北京", "上海", 1),
        (2, "上海", "北京", 0),
        (3, "广州", "深圳", 1),
        (4, "深圳", "广州", 0),
        (5, "北京", "广州", 1),
        (6, "上海", "深圳", 0),
    ]
    columns = ["id", "origin_city", "dest_city", "label"]
    df = spark.createDataFrame(data, schema=columns)

    # Fit the shared encoder on both city columns.
    encoder = SharedLabelEncoder(handle_invalid="keep").fit(df, ["origin_city", "dest_city"])

    # Transform both columns into the same index space.
    out_df = encoder.transform(df, ["origin_city", "dest_city"])
    print("Encoded result:")
    out_df.show(truncate=False)

    # Save, reload, and reuse the encoder.
    path = "./shared_label_encoder_city.json"
    encoder.save(path)
    encoder2 = SharedLabelEncoder.load(path)
    new_df = spark.createDataFrame([(7, "北京", "杭州", 1)], schema=columns)  # 杭州 is unseen
    out_new = encoder2.transform(new_df, ["origin_city", "dest_city"])
    print("Reloaded encoder, reused on new data:")
    out_new.show(truncate=False)

    spark.stop()


if __name__ == "__main__":
    main()
```
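A note on the design: the UDF broadcasts only the plain label_to_index dict and closes over locals, so Spark never pickles the encoder object itself. For small vocabularies you could also skip the Python UDF entirely with a literal map column; a rough sketch, reusing the `encoder` and `df` names from the example above (for large vocabularies a broadcast join against a labels DataFrame would be the better fit):

```python
from itertools import chain
from pyspark.sql import functions as F

# Literal map built from the fitted mapping. element_at returns null for
# unseen keys, so coalescing with len(labels) reproduces handle_invalid="keep".
mapping_expr = F.create_map(*chain.from_iterable(
    (F.lit(k), F.lit(v)) for k, v in encoder.label_to_index.items()
))
out_df = df
for c in ["origin_city", "dest_city"]:
    out_df = out_df.withColumn(
        c + "_idx",
        F.coalesce(F.element_at(mapping_expr, F.col(c)), F.lit(len(encoder.labels))),
    )
```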
Output
Encoded result:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|1 |北京 |上海 |1 |1 |0 |
|2 |上海 |北京 |0 |0 |1 |
|3 |广州 |深圳 |1 |2 |3 |
|4 |深圳 |广州 |0 |3 |2 |
|5 |北京 |广州 |1 |1 |2 |
|6 |上海 |深圳 |0 |0 |3 |
+---+-----------+---------+-----+---------------+-------------+
Reloaded encoder, reused on new data:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|7 |北京 |杭州 |1 |1 |4 |
+---+-----------+---------+-----+---------------+-------------+
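As expected, the unseen 杭州 gets the unknown index len(labels) = 4; the four training cities each occur three times, so StringIndexer breaks the frequency tie alphabetically, giving 上海=0, 北京=1, 广州=2, 深圳=3. If you would rather stay on Spark-native transformers after reloading, the saved labels can be fed back into StringIndexerModel.from_labels; a minimal sketch assuming the `encoder2` and `new_df` from the example (note the output column is a double, not an integer):

```python
from pyspark.ml.feature import StringIndexerModel

# Rebuild one Spark-native indexer model per column, all sharing the same
# label list, so transform runs without a Python UDF.
indexed = new_df
for c in ["origin_city", "dest_city"]:
    model = StringIndexerModel.from_labels(
        encoder2.labels, inputCol=c, outputCol=c + "_idx", handleInvalid="keep"
    )
    indexed = model.transform(indexed)
```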