学习笔记:在PySpark中使用UDF

最近经常使用PySpark进行数据处理,在面对复杂逻辑的时候需要编写自定义函数(UDF:User-defined Functions)。经过学习后总结如下:

在pyspark中使用自定义函数有三种方式

  • 传统的udf函数
  • Pandas UDF
  • Pandas Function API

传统的UDF函数 和在其他计算引擎中类似,使用函数进行逐行处理

Pandas UDF 利用Apache Arrow和pandas实现的高性能向量化函数可以做到比传统的UDF函数性能提升

Pandas Function API 和Pandas UDF类似,可以高性能进行整个 DataFrame或分组的数据转换

此外还有UDTF(User-defined Table Functions)来把单行转化成多行,但利用上面的Pandas Function API或者Pandas UDF就可以做到,所以并没有去学习

以上作为非专业的数据分析师的初步学习,仅供参考。下面详细介绍下

传统 Python UDF(逐行处理)

创建方式

方式一:使用 udf()函数(最常用)

python 复制代码
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, IntegerType

# 定义普通Python函数
def reverse_string(s):
    return s[::-1] if s else None

# 创建UDF并指定返回类型
reverse_udf = F.udf(reverse_string, StringType())

方式二:使用 @udf装饰器(Pythonic风格)

python 复制代码
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

@udf(IntegerType())
def squared(x):
    return x * x if x is not None else None

方式三:使用 Lambda 表达式(简单操作)

python 复制代码
from pyspark.sql.types import BooleanType

# 直接使用lambda创建UDF
is_adult_udf = F.udf(lambda age: age >= 18 if age is not None else None, BooleanType())

使用方式

在 DataFrame API 中使用

python 复制代码
# 直接应用于DataFrame转换
df = df.withColumn("new_column", your_udf(F.col("existing_column")))

# 在select表达式中使用
df.select("col1", "col2", your_udf("col3").alias("transformed_col"))

# 在filter条件中使用
df.filter(your_udf("col1") > some_value)

在Spark SQL中使用

python 复制代码
# 注册UDF
spark.udf.register("sql_udf_name", your_python_function, returnType)

# 在SQL查询中使用
result = spark.sql("SELECT sql_udf_name(column) FROM my_table")

Arrow Python UDF

传统的UDF也可以引入Apache Arrow 的内存格式来优化数据在 JVM 和 Python 进程间的传输。这个功能从 Spark 2.3 版本开始提供,并在后续版本中持续优化。这可以带来以下优势:

  1. 性能提升:Arrow的列式内存格式和零拷贝特性可以显著减少序列化和反序列化的开销,尤其对于数值型数据或字符串数据,性能提升明显。
  2. 批处理:Arrow以批次(batch)的形式传输数据,而不是逐行传输,这减少了函数调用的次数,提高了处理效率。
  3. 内存效率:Arrow使用共享内存和固定内存布局,减少了内存占用和复制操作。

需要注意有部分类型不支持

python 复制代码
@udf(returnType='int')  # A default, pickled Python UDF
def slen(s):  # type: ignore[no-untyped-def]
    return len(s)

@udf(returnType='int', useArrow=True)  # An Arrow Python UDF
def arrow_slen(s):  # type: ignore[no-untyped-def]
    return len(s)

Pandas UDF(向量化处理)

Pandas UDF(用户定义函数)是 PySpark 中基于 Apache Arrow 和 pandas 实现的高性能向量化函数,专为优化大数据处理而设计。与传统逐行处理的 Python UDF 相比,其性能可提升 10--100 倍

向量化计算

  • 传统 Python UDF 逐行处理数据,每次调用需序列化/反序列化单行数据,开销巨大
  • Pandas UDF 以批次为单位处理数据 :将 Spark 数据分块转换为 pandas.Seriespandas.DataFrame,在内存中批量执行 pandas 操作(底层由 C 实现),显著减少函数调用次数。

零拷贝数据传输

  • 通过 Apache Arrow 直接在 JVM 内存与 Python 进程间交换数据,避免传统序列化的性能损耗
  • 支持列式存储,仅传递所需列,减少 I/O 开销

在Spark 3.0之前,Pandas UDF 通常使用pyspark.sql.functions.PandasUDFType来定义

在Spark 3.x之后,你可以使用Python类型提示(Type Hints)来定义Pandas UDF,使代码更简洁、更易理解,并减少错误

Spark 2.x 写法

python 复制代码
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType

# 必须显式声明 UDF 类型
@pandas_udf(DoubleType(), PandasUDFType.SCALAR)
def celsius_to_fahrenheit(temp_series):
return (temp_series * 9/5) + 32

# 使用示例
df = df.withColumn("temp_f", celsius_to_fahrenheit("temp_c"))

Spark 3.x 写法(利用类型提示)

python 复制代码
from pyspark.sql.functions import pandas_udf
import pandas as pd

# 无需额外类型声明,类型提示自动确定类型
@pandas_udf("double")
def celsius_to_fahrenheit(temp_series: pd.Series) -> pd.Series:
return (temp_series * 9/5) + 32

# 使用方式相同
df = df.withColumn("temp_f", celsius_to_fahrenheit("temp_c"))

标量 Pandas UDF (Series → Series)

udf函数的输入和输出的长度必须一样

python 复制代码
from pyspark.sql.functions import pandas_udf
import pandas as pd

# 方式一:使用装饰器
@pandas_udf('double')
def celsius_to_fahrenheit(temp_c: pd.Series) -> pd.Series:
    return (temp_c * 9/5) + 32

# 使用
df = df.withColumn("temp_f", celsius_to_fahrenheit("temp_c"))
python 复制代码
# 方式二:函数式创建
def fahrenheit_to_celsius(temp_f: pd.Series) -> pd.Series:
    return (temp_f - 32) * 5/9

ftoc_udf = pandas_udf(fahrenheit_to_celsius, 'double')

# 使用
df = df.withColumn("temp_f", ftoc_udf("temp_c"))

分组映射 Pandas UDF (DataFrame → DataFrame)

从Spark 3.0开始,分组映射 Pandas UDF被分类为一个单独的Pandas Function API,DataFrame.groupby().applyInPandas()。推荐使用applyInPandas()

python 复制代码
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

# 定义输出schema
output_schema = StructType([
    StructField("category", StringType(), True),
    StructField("avg_value", DoubleType(), True),
    StructField("max_value", DoubleType(), True)
])

@pandas_udf(output_schema)
def calculate_stats(df: pd.DataFrame) -> pd.DataFrame:
    result = pd.DataFrame({
        'category': [df['category'].iloc[0]],
        'avg_value': [df['value'].mean()],
        'max_value': [df['value'].max()]
    })
    return result

# 使用
result_df = df.groupby("category").apply(calculate_stats)

分组聚合 Pandas UDF (Series → Scalar)

python 复制代码
@pandas_udf('double')
def mean_udf(v: pd.Series) -> float:
    return v.mean()

# 使用
df.groupBy("department").agg(mean_udf("salary").alias("avg_salary"))

迭代器 Pandas UDF (Iterator[Series] → Iterator[Series])

python 复制代码
@pandas_udf('string')
def process_batch(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # 初始化昂贵资源(如模型)
    expensive_model = load_ml_model()
    
    for series in iterator:
        # 对每个批次应用模型
        yield expensive_model.predict(series)
        
# 使用
df.withColumn("prediction", process_batch("features"))

Pandas Function API

直接操作 整个 DataFrame 或分组,可自由转换数据结构

Pandas Function API 可以通过Pandas实例直接对整个DataFrame应用Python本机函数。在内部,它与Pandas UDF类似,使用Arrow传输数据,Pandas处理数据,这允许向量化操作。然而,Pandas Function API 在PySpark DataFrame而不是Column下表现得像一个常规的API,Pandass函数API中的Python类型提示是可选的,不会影响它目前的内部工作方式,尽管将来可能需要它们。

从Spark 3.0开始,分组映射 Pandas UDF被分类为一个单独的Pandas Function API,DataFrame.groupby().applyInPandas()

它仍然可以与pyspark.sql.functions.PandasUDFTypeDataFrame.groupby().apply() 保持原样;但是,最好直接使用DataFrame.groupby().applyInPandas()pyspark.sql.functions.PandasUDFType将在未来被弃用。

Grouped Map (DataFrame → DataFrame)

此API实现"split-apply-combine"模式,该模式由三个步骤组成:

  • 使用DataFrame.groupBy()将数据拆分为组
  • 对每个组应用函数。函数的输入和输出都是pandas.DataFrame。输入数据包含每个组的所有行和列。
  • 将结果合并到一个新的PySpark DataFrame中

要使用DataFrame.groupBy().applyInPandas(),用户需要定义以下内容:

  • 一个Python函数,定义每个组的计算。
  • 定义输出PySpark DataFrameshema的StructType或字符串。
python 复制代码
def process_group(group_df: pd.DataFrame) -> pd.DataFrame:
    # 组内计算:如添加排名
    group_df["rank"] = group_df["value"].rank(method="dense", ascending=False)
    return group_df

# 输出Schema需匹配返回的DataFrame结构
output_schema = df.schema.add("rank", IntegerType())

# 按group_key分组处理
result_df = df.groupBy("group_key").applyInPandas(process_group, schema=output_schema)

Map (Iterator[DataFrame] → Iterator[DataFrame])

类似Series to Series Pandas UDF,函数接收DataFrame的迭代器,并返回任意长度的DataFrame的迭代器

python 复制代码
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def filter_func(iterator: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(filter_func, schema=df.schema).show()
# +---+---+
# | id|age|
# +---+---+
# |  1| 21|
# +---+---+

协同分组 Co-grouped Map

协同分组允许对不同 DataFrame共享相同键但数据结构不同的关联数据进行联合处理。这超越了传统的 JOIN 操作,提供更灵活的数据组合方式。

它包括以下步骤:

  • 对数据进行Shuffle,使共享key的每个数据帧的组被组合在一起
  • 对每个协同组应用一个函数。该函数的输入是两个pandas.DataFrame(带有一个可选元组,代表key)。函数的输出是pandas.DataFrame
  • 将所有组中的pandas.DataFrame转换为新的PySpark pandas.DataFrame
python 复制代码
import pandas as pd

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))

df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))

def merge_ordered(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.merge_ordered(left, right)

df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    merge_ordered, schema="time int, id int, v1 double, v2 string").show()
# +--------+---+---+----+
# |    time| id| v1|  v2|
# +--------+---+---+----+
# |20000101|  1|1.0|   x|
# |20000102|  1|3.0|null|
# |20000101|  2|2.0|   y|
# |20000102|  2|4.0|null|
# +--------+---+---+----+

使用 Pandas Function API 的注意事项

Schema 必须显式声明 不同于 UDF 可自动推断类型,Function API 要求手动指定输出结构

避免超大分组 单个分组数据需能放入单机内存,可通过以下方式控制

优先使用向量化操作 即使在 Pandas 函数内部,也应避免 for 循环


✨ 微信公众号【凉凉的知识库】同步更新,欢迎关注获取最新最有用的知识 ✨

相关推荐
币圈小菜鸟5 小时前
Selenium 自动化测试实战:绕过登录直接获取 Cookie
linux·python·selenium·测试工具·ubuntu·自动化
YangYang9YangYan6 小时前
2025年数字化转型关键证书分析与选择指南
大数据·信息可视化
BD_Marathon6 小时前
【Flink】DataStream API (二)
大数据·flink
BD_Marathon6 小时前
【Flink】DataStream API (一)
大数据·flink
lifallen6 小时前
深入了解Flink核心:Slot资源管理机制
大数据·数据结构·数据库·算法·flink·apache
一百天成为python专家6 小时前
python爬虫之selenium库进阶(小白五分钟从入门到精通)
开发语言·数据库·pytorch·爬虫·python·深度学习·selenium
q_q王6 小时前
linux安装gitlab详细教程,本地管理源代码
git·python·gitlab·代码
zzywxc7877 小时前
苹果WWDC25开发秘鉴:AI、空间计算与Swift 6的融合之道
java·人工智能·python·spring cloud·dubbo·swift·空间计算
小白学大数据8 小时前
模拟登录与Cookie持久化:爬取中国汽车网用户专属榜单数据
开发语言·爬虫·python