1. 什么是 Apache Arrow
Apache Arrow 是一种内存列式数据格式。在 PySpark 里,它的核心作用是提升 JVM 与 Python 之间的数据传输效率,因此对经常使用 Pandas、NumPy 的 Python 用户尤其有价值。不过 Arrow 并不会自动在所有场景下生效,通常需要额外的配置或特定 API 才能启用。
2. 使用 Arrow 的前提
要在 PySpark 中使用 Arrow,首先要确保安装了推荐版本的 PyArrow。官方说明中提到,如果你通过 pip 安装 PySpark,可以使用下面的方式安装 SQL 相关依赖:
bash
pip install pyspark[sql]
如果不是这种安装方式,就需要手动保证 PyArrow 在所有集群节点上都可用。官方还说明,在 pyspark.sql 场景下,最低支持版本是 Pandas 2.2.0 和 PyArrow 11.0.0。
3. Spark 与 PyArrow Table 互转
从 Spark 4.0 开始,Spark DataFrame 与 PyArrow Table 可以直接互转:
SparkSession.createDataFrame(pyarrow_table)DataFrame.toArrow()
python
import pyarrow as pa
import numpy as np
table = pa.table(
[pa.array(np.random.rand(100)) for _ in range(3)],
names=["a", "b", "c"]
)
df = spark.createDataFrame(table)
result_table = df.select("*").toArrow()
print(result_table.schema)
需要注意的是,DataFrame.toArrow() 会把 DataFrame 全部收集到 Driver 端,因此只适合小数据集。另外,并不是所有 Spark 和 Arrow 类型都完全支持,遇到不支持的类型会直接报错。
4. 启用 Arrow 优化 Pandas 转换
Arrow 最常见的使用场景,是优化 Spark DataFrame 与 Pandas DataFrame 的互转:
SparkSession.createDataFrame(pandas_df)DataFrame.toPandas()
要启用这类优化,需要先打开配置:
python
import numpy as np
import pandas as pd
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
pdf = pd.DataFrame(np.random.rand(100, 3))
df = spark.createDataFrame(pdf)
result_pdf = df.select("*").toPandas()
print(result_pdf.describe())
默认情况下,这个配置是关闭的。官方还说明,如果在真正执行前发生错误,Spark 可以自动回退到非 Arrow 实现,这由 spark.sql.execution.arrow.pyspark.fallback.enabled 控制。
同样要记住:即使开启 Arrow,toPandas() 仍然会把所有数据收集到 Driver,因此它依然只适合小数据集。
5. Pandas UDF:Arrow 最常见的高性能入口
Pandas UDF 本质上就是借助 Arrow 传输数据,再用 Pandas 执行向量化计算,因此它通常比普通 Python UDF 更适合数值计算和批处理逻辑。官方文档说明,Pandas UDF 可以通过 pandas_udf() 装饰器定义,不需要额外配置。
5.1 Series to Series
最常见的一种形式是输入 pandas.Series,输出 pandas.Series:
python
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType
def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
return a * b
multiply = pandas_udf(multiply_func, returnType=LongType())
df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["x"]))
df.select(multiply(col("x"), col("x"))).show()
这种形式要求输出长度和输入长度一致。
5.2 Iterator of Series to Iterator of Series
如果你的函数需要一次初始化状态,然后对多个批次复用,可以使用迭代器形式:
python
from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
for x in iterator:
yield x + 1
这种方式适合"初始化一次,多批次重复使用"的场景。
5.3 多列输入
如果需要多个输入列,可以写成"多 Series 迭代器"形式:
python
from typing import Iterator, Tuple
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf("long")
def multiply_two_cols(
iterator: Iterator[Tuple[pd.Series, pd.Series]]
) -> Iterator[pd.Series]:
for a, b in iterator:
yield a * b
这种方式适合多列联动计算。
5.4 Series to Scalar
Pandas UDF 还可以做聚合,输入 Series,输出一个标量:
python
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window
df = spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
("id", "v")
)
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
return v.mean()
df.groupby("id").agg(mean_udf(df["v"])).show()
官方特别提醒,这类 UDF 不支持部分聚合,分组或窗口中的所有数据都会被加载到内存里,因此大组数据要特别小心内存压力。
6. Pandas Function APIs
除了 Pandas UDF,PySpark 还提供了 Pandas Function APIs。它们内部同样依赖 Arrow 进行数据传输,但对外表现为 DataFrame 级别 API,而不是列级别 API。官方重点提到三类:applyInPandas()、mapInPandas() 和 cogroup().applyInPandas()。
6.1 applyInPandas():按组处理
groupBy().applyInPandas() 适合对每个分组执行 pandas.DataFrame 级别的逻辑:
python
import pandas as pd
df = spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
("id", "v")
)
def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
return pdf.assign(v=pdf.v - pdf.v.mean())
df.groupby("id").applyInPandas(
subtract_mean,
schema="id long, v double"
).show()
官方明确说明:每个组的所有数据都会先加载到内存,再交给函数处理,因此如果分组倾斜严重,很容易出现 OOM。
6.2 mapInPandas():批次映射
mapInPandas() 可以把一个 pandas.DataFrame 迭代器映射成另一个迭代器,输出长度可以任意变化:
python
from typing import Iterable
import pandas as pd
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
def filter_func(iterator: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
for pdf in iterator:
yield pdf[pdf.id == 1]
df.mapInPandas(filter_func, schema=df.schema).show()
它很适合需要自定义过滤、拆分、重组批次数据的场景。
6.3 cogroup().applyInPandas():双表按组处理
如果需要两个 DataFrame 按同一键分组后再联合处理,可以使用 cogroup 版本:
python
import pandas as pd
df1 = spark.createDataFrame(
[(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
("time", "id", "v1")
)
df2 = spark.createDataFrame(
[(20000101, 1, "x"), (20000101, 2, "y")],
("time", "id", "v2")
)
def merge_ordered(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
return pd.merge_ordered(left, right)
df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
merge_ordered,
schema="time int, id int, v1 double, v2 string"
).show()
同样,官方提醒 cogroup 的所有数据也会先加载到内存中,因此要小心大组问题。
7. Arrow Python UDF
除了 Pandas UDF,官方还引入了 Arrow Python UDF。它依然是逐行执行,但使用 Arrow 进行高效批量传输和序列化。定义时需要在 udf() 里设置 useArrow=True,或者全局打开 spark.sql.execution.pythonUDF.arrow.enabled。
python
from pyspark.sql.functions import udf
@udf(returnType="int")
def slen(s):
return len(s)
@udf(returnType="int", useArrow=True)
def arrow_slen(s):
return len(s)
df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
df.select(slen("name"), arrow_slen("name")).show()
官方指出,相比默认的 pickled Python UDF,Arrow Python UDF 在类型转换机制上更一致,也更能减少类型不匹配带来的歧义和数据损失。
8. 使用 Arrow 时要注意的几个点
8.1 并非所有类型都支持
官方说明,目前 Arrow 转换支持大多数 Spark SQL 类型,但 ArrayType(TimestampType) 仍然不支持;而 MapType 和嵌套 StructType 的 ArrayType 需要 PyArrow 2.0.0 及以上版本。
8.2 批大小会影响内存
Spark 会把数据分区转换为 Arrow record batch。默认情况下,每个 batch 最大是 10000 行,由 spark.sql.execution.arrow.maxRecordsPerBatch 控制。如果列很多,应该适当降低这个值,以避免 JVM 内存压力过大。
8.3 时间戳语义要特别小心
Spark 内部以 UTC 存储时间戳,而 Pandas 使用的是 datetime64[ns]。官方说明:
- Spark → Pandas:会转换到 Spark session 时区,并显示为本地时间
- Spark → PyArrow Table:保持 UTC 和微秒精度
- Pandas / PyArrow → Spark:会转成 UTC 微秒,纳秒会被截断
所以,时间戳列在 Arrow 场景下虽然会自动转换,但跨系统和跨时区时一定要明确 session 时区配置。
8.4 toPandas() 和 toArrow() 仍然是 Driver 收集操作
即使开启 Arrow,这两个 API 本质上仍是"全量收集到 Driver",不是分布式持久化接口,因此不能把它们当成大数据量导出方案。
8.5 self_destruct 可以省内存
从 Spark 3.2 开始,可以开启 spark.sql.execution.arrow.pyspark.selfDestruct.enabled 来减少 toPandas() 或 toArrow() 转换时的内存占用,但这是实验特性,可能带来只读数组问题,甚至让部分 Pandas 操作报错。官方还提醒,这种模式通常会更慢,因为它是单线程的。
9. 什么时候优先考虑 Arrow
如果你的场景符合下面几类,Arrow 往往值得优先考虑:
- Spark DataFrame 与 Pandas DataFrame 频繁互转
- 需要使用 Pandas UDF 做向量化处理
- 需要
applyInPandas()、mapInPandas()这类 Pandas Function APIs - Python UDF 性能瓶颈明显,且可以尝试 Arrow Python UDF
如果只是普通 Spark SQL 查询,或者完全不涉及 Pandas / NumPy / Python UDF,那么 Arrow 带来的收益通常没那么明显。这个结论与官方对 Arrow 适用场景的整体定位是一致的。
10. 总结
Arrow 在 PySpark 里的价值,核心不是"多了一个配置项",而是它打通了 Spark、Pandas、PyArrow 之间更高效的数据交换路径。你可以把它理解成 PySpark 与 Python 数据生态之间的高速通道:DataFrame 转 Pandas 更快,Pandas UDF 更高效,applyInPandas() 和 mapInPandas() 这类 API 也能更自然地发挥作用。当然,它并不是万能开关,内存、类型支持、时间戳语义和 Driver 收集风险仍然要重点关注。只要把这些边界想清楚,Arrow 基本就是 PySpark Python 侧性能优化里绕不开的一环。