Apache Arrow 在 PySpark 中的使用提速 Pandas 转换与 UDF 的关键武器

1. 什么是 Apache Arrow

Apache Arrow 是一种内存列式数据格式。在 PySpark 里,它的核心作用是提升 JVM 与 Python 之间的数据传输效率,因此对经常使用 Pandas、NumPy 的 Python 用户尤其有价值。不过 Arrow 并不会自动在所有场景下生效,通常需要额外的配置或特定 API 才能启用。

2. 使用 Arrow 的前提

要在 PySpark 中使用 Arrow,首先要确保安装了推荐版本的 PyArrow。官方说明中提到,如果你通过 pip 安装 PySpark,可以使用下面的方式安装 SQL 相关依赖:

bash 复制代码
pip install pyspark[sql]

如果不是这种安装方式,就需要手动保证 PyArrow 在所有集群节点上都可用。官方还说明,在 pyspark.sql 场景下,最低支持版本是 Pandas 2.2.0 和 PyArrow 11.0.0。

3. Spark 与 PyArrow Table 互转

从 Spark 4.0 开始,Spark DataFrame 与 PyArrow Table 可以直接互转:

  • SparkSession.createDataFrame(pyarrow_table)
  • DataFrame.toArrow()
python 复制代码
import pyarrow as pa
import numpy as np

table = pa.table(
    [pa.array(np.random.rand(100)) for _ in range(3)],
    names=["a", "b", "c"]
)

df = spark.createDataFrame(table)

result_table = df.select("*").toArrow()
print(result_table.schema)

需要注意的是,DataFrame.toArrow() 会把 DataFrame 全部收集到 Driver 端,因此只适合小数据集。另外,并不是所有 Spark 和 Arrow 类型都完全支持,遇到不支持的类型会直接报错。

4. 启用 Arrow 优化 Pandas 转换

Arrow 最常见的使用场景,是优化 Spark DataFrame 与 Pandas DataFrame 的互转:

  • SparkSession.createDataFrame(pandas_df)
  • DataFrame.toPandas()

要启用这类优化,需要先打开配置:

python 复制代码
import numpy as np
import pandas as pd

spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

pdf = pd.DataFrame(np.random.rand(100, 3))
df = spark.createDataFrame(pdf)
result_pdf = df.select("*").toPandas()

print(result_pdf.describe())

默认情况下,这个配置是关闭的。官方还说明,如果在真正执行前发生错误,Spark 可以自动回退到非 Arrow 实现,这由 spark.sql.execution.arrow.pyspark.fallback.enabled 控制。

同样要记住:即使开启 Arrow,toPandas() 仍然会把所有数据收集到 Driver,因此它依然只适合小数据集。

5. Pandas UDF:Arrow 最常见的高性能入口

Pandas UDF 本质上就是借助 Arrow 传输数据,再用 Pandas 执行向量化计算,因此它通常比普通 Python UDF 更适合数值计算和批处理逻辑。官方文档说明,Pandas UDF 可以通过 pandas_udf() 装饰器定义,不需要额外配置。

5.1 Series to Series

最常见的一种形式是输入 pandas.Series,输出 pandas.Series

python 复制代码
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["x"]))
df.select(multiply(col("x"), col("x"))).show()

这种形式要求输出长度和输入长度一致。

5.2 Iterator of Series to Iterator of Series

如果你的函数需要一次初始化状态,然后对多个批次复用,可以使用迭代器形式:

python 复制代码
from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for x in iterator:
        yield x + 1

这种方式适合"初始化一次,多批次重复使用"的场景。

5.3 多列输入

如果需要多个输入列,可以写成"多 Series 迭代器"形式:

python 复制代码
from typing import Iterator, Tuple
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def multiply_two_cols(
    iterator: Iterator[Tuple[pd.Series, pd.Series]]
) -> Iterator[pd.Series]:
    for a, b in iterator:
        yield a * b

这种方式适合多列联动计算。

5.4 Series to Scalar

Pandas UDF 还可以做聚合,输入 Series,输出一个标量:

python 复制代码
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v")
)

@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df.groupby("id").agg(mean_udf(df["v"])).show()

官方特别提醒,这类 UDF 不支持部分聚合,分组或窗口中的所有数据都会被加载到内存里,因此大组数据要特别小心内存压力。

6. Pandas Function APIs

除了 Pandas UDF,PySpark 还提供了 Pandas Function APIs。它们内部同样依赖 Arrow 进行数据传输,但对外表现为 DataFrame 级别 API,而不是列级别 API。官方重点提到三类:applyInPandas()mapInPandas()cogroup().applyInPandas()

6.1 applyInPandas():按组处理

groupBy().applyInPandas() 适合对每个分组执行 pandas.DataFrame 级别的逻辑:

python 复制代码
import pandas as pd

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v")
)

def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").applyInPandas(
    subtract_mean,
    schema="id long, v double"
).show()

官方明确说明:每个组的所有数据都会先加载到内存,再交给函数处理,因此如果分组倾斜严重,很容易出现 OOM。

6.2 mapInPandas():批次映射

mapInPandas() 可以把一个 pandas.DataFrame 迭代器映射成另一个迭代器,输出长度可以任意变化:

python 复制代码
from typing import Iterable
import pandas as pd

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def filter_func(iterator: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(filter_func, schema=df.schema).show()

它很适合需要自定义过滤、拆分、重组批次数据的场景。

6.3 cogroup().applyInPandas():双表按组处理

如果需要两个 DataFrame 按同一键分组后再联合处理,可以使用 cogroup 版本:

python 复制代码
import pandas as pd

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1")
)

df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2")
)

def merge_ordered(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.merge_ordered(left, right)

df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    merge_ordered,
    schema="time int, id int, v1 double, v2 string"
).show()

同样,官方提醒 cogroup 的所有数据也会先加载到内存中,因此要小心大组问题。

7. Arrow Python UDF

除了 Pandas UDF,官方还引入了 Arrow Python UDF。它依然是逐行执行,但使用 Arrow 进行高效批量传输和序列化。定义时需要在 udf() 里设置 useArrow=True,或者全局打开 spark.sql.execution.pythonUDF.arrow.enabled

python 复制代码
from pyspark.sql.functions import udf

@udf(returnType="int")
def slen(s):
    return len(s)

@udf(returnType="int", useArrow=True)
def arrow_slen(s):
    return len(s)

df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
df.select(slen("name"), arrow_slen("name")).show()

官方指出,相比默认的 pickled Python UDF,Arrow Python UDF 在类型转换机制上更一致,也更能减少类型不匹配带来的歧义和数据损失。

8. 使用 Arrow 时要注意的几个点

8.1 并非所有类型都支持

官方说明,目前 Arrow 转换支持大多数 Spark SQL 类型,但 ArrayType(TimestampType) 仍然不支持;而 MapType 和嵌套 StructTypeArrayType 需要 PyArrow 2.0.0 及以上版本。

8.2 批大小会影响内存

Spark 会把数据分区转换为 Arrow record batch。默认情况下,每个 batch 最大是 10000 行,由 spark.sql.execution.arrow.maxRecordsPerBatch 控制。如果列很多,应该适当降低这个值,以避免 JVM 内存压力过大。

8.3 时间戳语义要特别小心

Spark 内部以 UTC 存储时间戳,而 Pandas 使用的是 datetime64[ns]。官方说明:

  • Spark → Pandas:会转换到 Spark session 时区,并显示为本地时间
  • Spark → PyArrow Table:保持 UTC 和微秒精度
  • Pandas / PyArrow → Spark:会转成 UTC 微秒,纳秒会被截断

所以,时间戳列在 Arrow 场景下虽然会自动转换,但跨系统和跨时区时一定要明确 session 时区配置。

8.4 toPandas()toArrow() 仍然是 Driver 收集操作

即使开启 Arrow,这两个 API 本质上仍是"全量收集到 Driver",不是分布式持久化接口,因此不能把它们当成大数据量导出方案。

8.5 self_destruct 可以省内存

从 Spark 3.2 开始,可以开启 spark.sql.execution.arrow.pyspark.selfDestruct.enabled 来减少 toPandas()toArrow() 转换时的内存占用,但这是实验特性,可能带来只读数组问题,甚至让部分 Pandas 操作报错。官方还提醒,这种模式通常会更慢,因为它是单线程的。

9. 什么时候优先考虑 Arrow

如果你的场景符合下面几类,Arrow 往往值得优先考虑:

  • Spark DataFrame 与 Pandas DataFrame 频繁互转
  • 需要使用 Pandas UDF 做向量化处理
  • 需要 applyInPandas()mapInPandas() 这类 Pandas Function APIs
  • Python UDF 性能瓶颈明显,且可以尝试 Arrow Python UDF

如果只是普通 Spark SQL 查询,或者完全不涉及 Pandas / NumPy / Python UDF,那么 Arrow 带来的收益通常没那么明显。这个结论与官方对 Arrow 适用场景的整体定位是一致的。

10. 总结

Arrow 在 PySpark 里的价值,核心不是"多了一个配置项",而是它打通了 Spark、Pandas、PyArrow 之间更高效的数据交换路径。你可以把它理解成 PySpark 与 Python 数据生态之间的高速通道:DataFrame 转 Pandas 更快,Pandas UDF 更高效,applyInPandas()mapInPandas() 这类 API 也能更自然地发挥作用。当然,它并不是万能开关,内存、类型支持、时间戳语义和 Driver 收集风险仍然要重点关注。只要把这些边界想清楚,Arrow 基本就是 PySpark Python 侧性能优化里绕不开的一环。

相关推荐
Hello.Reader3 小时前
Pandas API on Spark 配置选项系统、默认索引与性能调优
大数据·spark·pandas
言之。4 小时前
Apache ZooKeeper 核心技术全解(面试+实战版)
zookeeper·面试·apache
AI架构师之家1 天前
Apache Camel使用教程一
apache
Python大数据分析@1 天前
Pandas相比Excel的优势是哪些?
excel·pandas
yzx9910131 天前
实时数据处理实战:使用 Apache Flink 消费 Kafka 数据并进行窗口聚合
flink·kafka·apache
Shepherd06192 天前
【IT 实战】Apache 反向代理 UniFi Controller 的终极指北(解决白屏、502、400 错误)
运维·网络·apache·it·unifi
额1292 天前
CentOS 7 安装apache部署discuz导入数据库表
数据库·centos·apache
Hello.Reader2 天前
Pandas API on Spark 快速入门像写 Pandas 一样使用 Spark
大数据·spark·pandas
qzhqbb2 天前
Nginx/Apache 访问规则
运维·nginx·apache