目录
[1. 数据传输效率低下](#1. 数据传输效率低下)
[2. Python处理开销大](#2. Python处理开销大)
[3. 模型加载重复](#3. 模型加载重复)
一、问题背景
在最近项目中,需要将深度学习算法部署到spark集群中运行,部署过程中遇到一个坑,在此记录分享一下。
我是用pandas_udf的方式进行ONNX推理时报错,报错原因:
-
ONNX 模型推理时报错,输入数据形状不匹配
-
批量预测时数据维度混乱
二、遇到的问题
在分布式机器学习推理场景中,PySpark需要将预处理后的特征数据传输到Python UDF中进行模型推理。原始特征数据通常是多维时间序列(如形状为(batch_size, sequence_length, feature_size)),在Spark和Python之间传输时遇到性能瓶颈。
1. 数据传输效率低下
-
嵌套数组序列化慢 :复杂嵌套结构(如array<array<double>>)在JVM和Python间传输时需要递归序列化
-
内存占用高:字符串序列化方案产生额外编码开销
-
网络带宽浪费:非紧凑的数据格式增加传输量
2. Python处理开销大
-
逐条类型检查:传统UDF对每行数据单独处理,Python调用开销显著
-
循环效率低:无法利用NumPy向量化操作
-
内存碎片化:非连续内存布局降低CPU缓存命中率
3. 模型加载重复
- 每个任务或批次可能重复加载ONNX模型,造成资源浪费
三、问题根因
ONNX 模型期望输入格式:*(batch_size, seq_length=6, input_size=4)*的 float32 数组
但问题出在 PySpark Arrow 传递复杂嵌套类型时的数据格式:
-
PySpark 中 feature_seq 列类型是 array<array<double>>
-
通过 pandas_udf 传递时,数据可能被序列化为字符串形式或包含 Row 对象
-
直接用 np.array() 转换时失败,因为元素不是纯数值
性能瓶颈点
数据传输瓶颈 → 序列化/反序列化 → Python处理瓶颈 → 模型推理瓶颈
四、解决方案
| 方案1: UDF中处理复杂类型 | 方案2: 字符串序列化 | 方案3: 展平为一维数组 | |
|---|---|---|---|
| Arrow传输效率 | ❌ 低 (嵌套数组序列化慢) | ✅ 高 (字符串原生支持) | ✅ 最高 (一维数组最优) |
| 内存占用 | 中等 | 较高 (字符串开销) | 最低 |
| Python处理开销 | 高 (逐条类型检查) | 中 (字符串解析) | 低 (直接reshape) |
| 网络传输 | 中等 | 较高 | 最低 |
评估下来,方案3:展平为一维数组传输是最合适的。
下面具体分析下方案三:
技术实现
# 关键步骤1:在Spark端展平
df.withColumn(
"feature_flat",
expr("""
flatten(
transform(
sequence(0, {seq_length-1}),
i -> array(Tdb_list[i], Te_list[i], C_list[i], F_list[i])
)
)
""")
)
# 关键步骤2:在UDF中还原
batch_array = np.array(valid_data, dtype=np.float32).reshape(-1, SEQUENCE_LENGTH, INPUT_SIZE)
举例
┌─────────────────────────────────────────────────────────────────────────────┐
│ 完整数据流转过程 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ 原始格式 (在PySpark中): │
│ feature_seq = [[32.0, 25.0, 68.0, 600.0], ← 时间步 t-5 │
│ [31.5, 17.0, 75.0, 600.0], ← 时间步 t-4 │
│ [30.5, 16.0, 75.0, 600.0], ← 时间步 t-3 │
│ [29.5, 16.0, 75.0, 600.0], ← 时间步 t-2 │
│ [29.0, 15.0, 75.0, 600.0], ← 时间步 t-1 │
│ [28.5, 15.0, 75.0, 600.0]] ← 时间步 t (当前) │
│ │
│ ↓ 展平 (flatten) 只为传输 │
│ │
│ 展平格式 (Arrow传输): │
│ feature_flat = [32.0, 25.0, 68.0, 600.0, 31.5, 17.0, 75.0, 600.0, │
│ 30.5, 16.0, 75.0, 600.0, 29.5, 16.0, 75.0, 600.0, │
│ 29.0, 15.0, 75.0, 600.0, 28.5, 15.0, 75.0, 600.0] │
│ │
│ ↓ reshape 还原形状 │
│ │
│ 模型输入格式 (ONNX推理时): │
│ batch_array.shape = (batch_size, 6, 4) │
│ ┌─────────────────────────────────┐ │
│ │ [[32.0, 25.0, 68.0, 600.0], │ ← 时间步 t-5 │
│ │ [31.5, 17.0, 75.0, 600.0], │ ← 时间步 t-4 │
│ │ [30.5, 16.0, 75.0, 600.0], │ ← 时间步 t-3 │
│ │ [29.5, 16.0, 75.0, 600.0], │ ← 时间步 t-2 │
│ │ [29.0, 15.0, 75.0, 600.0], │ ← 时间步 t-1 │
│ │ [28.5, 15.0, 75.0, 600.0]] │ ← 时间步 t (当前) │
│ └─────────────────────────────────┘ │
│ │
│ ✅ 与原始格式完全一致,模型结果不受任何影响 │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Arrow 数据传输性能 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ array<array<double>> ────► 慢 (需要递归序列化) │
│ │
│ string ────► 中等 (字符串编码开销) │
│ │
│ array<double> ────► 快 (Arrow原生支持,零拷贝) │
│ │
└─────────────────────────────────────────────────────────────────┘
原理
在了解了技术实现之后,来看下技术原理,即"为什么"这么做有效
执行位置:计算向数据移动
-
Driver端: 仅负责任务解析与调度,不执行UDF代码,避免了单点瓶颈。
-
Executor端: 真正的算力所在。每个Executor节点持有独立的数据分区,predict_udf直接在节点内存中执行,遵循"数据本地性"原则,避免了数据在网络中来回传输。
┌─────────────────────────────────────────────────────────────────────────────────────┐
│ Spark 分布式执行架构 │
├─────────────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────────────┐ │
│ │ Driver (主节点) │ │
│ │ │ │
│ │ • 解析代码,生成执行计划 │ │
│ │ • 调度任务到 Executor │ │
│ │ • 收集最终结果 │ │
│ │ • ❌ 不执行 UDF 代码 │ │
│ └─────────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌──────────────────┼──────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Executor 1 │ │ Executor 2 │ │ Executor N │ │
│ │ (工作节点) │ │ (工作节点) │ │ (工作节点) │ │
│ │ │ │ │ │ │ │
│ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │
│ │ │ Partition 1 │ │ │ │ Partition 2 │ │ │ │ Partition N │ │ │
│ │ │ (数据分区) │ │ │ │ (数据分区) │ │ │ │ (数据分区) │ │ │
│ │ │ 10000 rows │ │ │ │ 10000 rows │ │ │ │ 10000 rows │ │ │
│ │ └───────┬───────┘ │ │ └───────┬───────┘ │ │ └───────┬───────┘ │ │
│ │ │ │ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │ ▼ │ │
│ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │
│ │ │ predict_udf │ │ │ │ predict_udf │ │ │ │ predict_udf │ │ │
│ │ │ ✅ 在此执行 │ │ │ │ ✅ 在此执行 │ │ │ │ ✅ 在此执行 │ │ │
│ │ │ │ │ │ │ │ │ │ │ │ │ │
│ │ │ ONNX Session │ │ │ │ ONNX Session │ │ │ │ ONNX Session │ │ │
│ │ │ (单例模式) │ │ │ │ (单例模式) │ │ │ │ (单例模式) │ │ │
│ │ └───────────────┘ │ │ └───────────────┘ │ │ └───────────────┘ │ │
│ │ │ │ │ │ │ │
│ └─────────────────────┘ └─────────────────────┘ └─────────────────────┘ │
│ │
│ 并行执行:所有 Executor 同时处理各自的分区,互不干扰 │
│ │
└─────────────────────────────────────────────────────────────────────────────────────┘
四大高效引擎
**Arrow零拷贝:**JVM数据直接映射为Python内存对象,省去序列化步骤。
┌─────────────────────────────────────────────────────────────────┐
│ 传统序列化 vs Arrow │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 传统方式 (pickle): │
│ ┌─────────┐ 序列化 ┌─────────┐ 反序列化 ┌───────┐ │
│ │ JVM数据 │ ──────────► │ 网络传输 │ ──────────► │Python │ │
│ │ (Row) │ 慢! │ 字节流 │ 慢! │ 对象 │ │
│ └─────────┘ └─────────┘ └───────┘ │
│ │
│ Arrow 方式: │
│ ┌─────────┐ ┌─────────┐ ┌───────┐ │
│ │ JVM数据 │ 直接共享 │ 内存 │ 零拷贝 │ Pandas│ │
│ │ (列式) │ ──────────► │ 映射 │ ──────────► │Series │ │
│ └─────────┘ 快! └─────────┘ 快! └───────┘ │
│ │
│ 性能提升: 10-50倍 │
│ │
└─────────────────────────────────────────────────────────────────┘
向量化批量处理:pandas_udf一次处理成千上万行数据
# ❌ 传统 UDF - 逐行处理 (慢)
@udf(returnType=DoubleType())
def slow_udf(feature):
# 每行都要:Python调用开销 + 模型推理
return model.predict(feature) # 10000次调用 = 10000次推理
# ✅ pandas_udf - 批量处理 (快)
@pandas_udf(returnType=DoubleType())
def fast_udf(features: pd.Series):
# 一次处理10000行
batch = np.array(features.tolist()).reshape(-1, 6, 4)
return model.predict(batch) # 1次调用 = 1次批量推理
**单例模式:**ONNX模型在每个Executor进程启动时仅加载一次,后续推理复用该实例,消除IO开销。
**分布式并行:**所有Executor同时处理各自分区,线性扩展计算能力。
关键优势:
-
Arrow 零拷贝传输:一维 array<double> 是 Arrow 原生类型,可直接映射到 pandas/numpy
-
向量化处理:避免了 Python 循环,充分利用 NumPy 向量化操作
-
内存连续:连续内存布局,CPU 缓存命中率高
-
网络带宽低:数据紧凑,传输量最小
性能提升预估: 相比方案1可提升**30-50%**的整体吞吐量。
五、适用场景
-
✅ 分布式机器学习推理
-
✅ 大规模特征数据传输
-
✅ 时间序列或图像等多维数据
-
✅ Spark + Python混合架构
六、扩展优化
# 进一步优化:使用PySpark的向量化UDF
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, DoubleType
# 使用更高效的Arrow格式
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")