spark pipeline 转换n个字段,如何对某个字段反向转换

eg:

f1做onehot f2做labelEncoder f3做归一化. 输入模型推理结果仅仅是f2. 如何对f2做反向转换获取到原始数据.

代码

python 复制代码
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, StringIndexerModel, VectorAssembler, MinMaxScaler, IndexToString
from pyspark.ml.functions import vector_to_array


def main():
    # 1) 启动 Spark(本地示例)
    spark = (
        SparkSession.builder.appName("pyspark_pipeline_example")
        .master("local[*]")
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("ERROR")

    # 2) 构造示例数据:
    #    - category: 需要做 LabelEncoder(StringIndexer)
    #    - value:    需要做数值归一化(MinMaxScaler)
    data = [
        ("A", 1.0),
        ("B", 2.0),
        ("A", 3.0),
        ("C", 5.0),
        (None, 10.0),  # 含空值,演示 handleInvalid="keep"
    ]
    df = spark.createDataFrame(data, ["category", "value"])
    print('原始数据:')
    df.show(truncate=False)

    # 3) 定义 Pipeline 各阶段
    # StringIndexer 做"标签编码",将字符串类目映射到数值索引
    indexer = StringIndexer(
        inputCol="category",
        outputCol="category_idx",
        handleInvalid="keep",  # 未见/空值统一映射到一个索引
    )

    # 数值特征先装配为向量,再做 Min-Max 归一化到 [0,1]
    assembler = VectorAssembler(inputCols=["value"], outputCol="value_vec")
    scaler = MinMaxScaler(inputCol="value_vec", outputCol="value_scaled_vec")

    pipeline = Pipeline(stages=[indexer, assembler, scaler])

    # 4) 拟合并转换
    model = pipeline.fit(df)
    out = model.transform(df)
    # 将 1 维向量转回标量便于查看
    out = out.withColumn("value_scaled", vector_to_array(F.col("value_scaled_vec"))[0])

    print("编码/归一化后的结果:")
    out.select("category", "category_idx", "value", "value_scaled").show(truncate=False)

    # 5) 仅对一列做"反向转换"(把 category_idx -> 原始字符串)
    #    不依赖 stages 的下标,优先从列的 metadata 读取 labels;若缺失再根据输出列名定位对应的 StringIndexerModel。

    def resolve_labels_from_metadata(dataframe, indexed_col: str):
        md = dataframe.schema[indexed_col].metadata
        # StringIndexer 会在输出列写入 ml_attr.vals
        if isinstance(md, dict):
            ml_attr = md.get("ml_attr") or {}
            vals = ml_attr.get("vals")
            if vals:
                return list(vals)
        # 某些 Spark 版本 metadata 不是纯 dict,也尝试通用访问
        try:
            ml_attr = md["ml_attr"]
            vals = ml_attr["vals"]
            if vals:
                return list(vals)
        except Exception:
            pass
        return None

    labels = resolve_labels_from_metadata(out, "category_idx")
    if labels is None:
        # 退化方案:在 pipeline 内按类型与输出列名查找对应的 StringIndexerModel
        for st in model.stages:
            if isinstance(st, StringIndexerModel) and st.getOutputCol() == "category_idx":
                labels = list(st.labels)
                break
    if labels is None:
        raise RuntimeError("无法解析 category_idx 的 labels(既无 metadata,也未在 pipeline 中找到对应的 StringIndexerModel)")

    idx_to_str = IndexToString(inputCol="category_idx", outputCol="category_inv", labels=labels)
    reversed_df = idx_to_str.transform(out)

    print("仅对 category_idx 做反向转换(一列):")
    reversed_df.select("category_idx", "category_inv").show(truncate=False)

    # spark.stop()


if __name__ == "__main__":
    main()

结果

复制代码
原始数据:
+--------+-----+
|category|value|
+--------+-----+
|A       |1.0  |
|B       |2.0  |
|A       |3.0  |
|C       |5.0  |
|NULL    |10.0 |
+--------+-----+

编码/归一化后的结果:
+--------+------------+-----+------------------+
|category|category_idx|value|value_scaled      |
+--------+------------+-----+------------------+
|A       |0.0         |1.0  |0.0               |
|B       |1.0         |2.0  |0.1111111111111111|
|A       |0.0         |3.0  |0.2222222222222222|
|C       |2.0         |5.0  |0.4444444444444444|
|NULL    |3.0         |10.0 |1.0               |
+--------+------------+-----+------------------+

仅对 category_idx 做反向转换(一列):
+------------+------------+
|category_idx|category_inv|
+------------+------------+
|0.0         |A           |
|1.0         |B           |
|0.0         |A           |
|2.0         |C           |
|3.0         |__unknown   |
+------------+------------+
相关推荐
Jackyzhe9 分钟前
Flink源码阅读:JobManager的HA机制
大数据·flink
鲨莎分不晴11 分钟前
大数据基石深度解析:系统性读懂 Hadoop 与 ZooKeeper
大数据·hadoop·zookeeper
Sylvan Ding7 小时前
度量空间数据管理与分析系统——大数据泛构课程作业-2025~2026学年. 毛睿
大数据·深圳大学·大数据泛构·度量空间数据管理与分析系统·毛睿·北京理工大学珠海校区
面向Google编程9 小时前
Flink源码阅读:JobManager的HA机制
大数据·flink
Tony Bai10 小时前
【分布式系统】03 复制(上):“权威中心”的秩序 —— 主从架构、一致性与权衡
大数据·数据库·分布式·架构
汽车仪器仪表相关领域11 小时前
全自动化精准检测,赋能高效年检——NHD-6108全自动远、近光检测仪项目实战分享
大数据·人工智能·功能测试·算法·安全·自动化·压力测试
大厂技术总监下海12 小时前
根治LLM胡说八道!用 Elasticsearch 构建 RAG,给你一个“有据可查”的AI
大数据·elasticsearch·开源
石像鬼₧魂石13 小时前
22端口(OpenSSH 4.7p1)渗透测试完整复习流程(含实战排错)
大数据·网络·学习·安全·ubuntu
TDengine (老段)14 小时前
TDengine Python 连接器进阶指南
大数据·数据库·python·物联网·时序数据库·tdengine·涛思数据
数据猿16 小时前
【金猿CIO展】如康集团CIO 赵鋆洲:数智重塑“顶牛”——如康集团如何用大数据烹饪万亿肉食产业的未来
大数据