背景
Apache Datafusion Comet 是苹果公司开源的加速Spark运行的向量化项目。
本项目采用了 Spark插件化 + Protobuf + Arrow + DataFusion 架构形式
其中
- Spark插件是 利用 SparkPlugin 插件,其中分为 DriverPlugin 和 ExecutorPlugin ,这两个插件在driver和 Executor启动的时候就会调用
- Protobuf 是用来序列化 spark对应的表达式以及计划,用来传递给 native 引擎去执行,利用了 体积小,速度快的特性
- Arrow 是用来 spark 和 native 引擎进行高效的数据交换(native执行的结果或者spark执行的数据结果),主要在JNI中利用Arrow IPC 列式存储以及零拷贝等特点进行进程间数据交换
- DataFusion 主要是利用Rust native以及Arrow内存格式实现的向量化执行引擎,Spark中主要offload对应的算子到该引擎中去执行
本文基于 datafusion comet 截止到2026年1月13号的main分支的最新代码(对应的commit为 eef5f28a0727d9aef043fa2b87d6747ff68b827a)
主要分析 Rust Native 数据读取以及读取时候的零拷贝
读数据 decodeShuffleBlock
读取的主要在于 CometBlockStoreShuffleReader 这块,具体的流程参考Spark Datafusion Comet 向量化Rule--CometExecRule Shuffle分析,Spark Shuffle数据的读取都在都是由这个类来处理, 具体和JNI(rust)进行交互为部分为:
Java 侧:
val batch = nativeUtil.getNextBatch(
fieldCount,
(arrayAddrs, schemaAddrs) => {
native.decodeShuffleBlock(
dataBuf,
bytesToRead.toInt,
arrayAddrs,
schemaAddrs,
tracingEnabled)
})
这里的dataBuf(DirectByteBuffer类型)数据是对应的BlockId的字节流数据:
@native def decodeShuffleBlock(
shuffleBlock: ByteBuffer,
length: Int,
arrayAddrs: Array[Long],
schemaAddrs: Array[Long],
tracingEnabled: Boolean): Long
参数说明:
- shuffleBlock
包含一个已经 编码并压缩好的 native shuffle block, - length
表示 shuffleBlock 中 有效数据的字节数 - arrayAddrs
存放输出 Arrow 列数据的 FFI 结构体地址 - schemaAddrs
存放输出列 schema(Arrow DataType)的 FFI 结构体地址 - tracingEnabled
用于控制是否对该 decode 过程进行 trace/log, 默认为False(通过spark.comet.tracing.enabled配置)
Rust侧:
pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
e: JNIEnv,
_class: JClass,
byte_buffer: JByteBuffer,
length: jint,
array_addrs: JLongArray,
schema_addrs: JLongArray,
tracing_enabled: jboolean,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || {
let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
let length = length as usize;
let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
let batch = read_ipc_compressed(slice)?;
prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
})
})
}
参数中需要注意的一点是 java侧的ByteBuffer 对应Rust侧JByteBuffer.
try_unwrap_or_throw 和 with_trace已经在 Spark Datafusion Comet 向量化Rust Native-- 数据写入 中已经说明了,不再赘述
-
获取ByteBuffer对应的bytes切片
let raw_pointer: *mut u8 = env.get_direct_buffer_address(&byte_buffer)?;//获取直接内存的起始位置 let length = length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };// 获取切片引用对于Rust JNI的 JNIEnv 类中的方法可以参考JNIEnv
-
解压并解析IPC数据流为
RecordBatch
主要是通过read_ipc_compressed方式实现:pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> { match &bytes[0..4] { b"SNAP" => { let decoder = snap::read::FrameDecoder::new(&bytes[4..]); let mut reader = unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; reader.next().unwrap().map_err(|e| e.into()) } b"LZ4_" => { let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); let mut reader: StreamReader<lz4_flex::frame::FrameDecoder<&[u8]>> = unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; reader.next().unwrap().map_err(|e| e.into()) } b"ZSTD" => { let decoder = zstd::Decoder::new(&bytes[4..])?; let mut reader = unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; reader.next().unwrap().map_err(|e| e.into()) } b"NONE" => { let mut reader = unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) }; reader.next().unwrap().map_err(|e| e.into()) } other => Err(DataFusionError::Execution(format!( "Failed to decode batch: invalid compression codec: {other:?}" ))), } }-
根据压缩格式构造不同的FrameDecoder进行解析
bytes[0..4]获取该codec header,之后构造 StreamReader ,且调用 reader.next().unwrap() 读取第一个 RecordBatch -
把 RecordBatch 输出到 JVM 的地址数组中
主要是调用prepare_output方法(部分解释):
通过如下方法let array_addrs = unsafe { env.get_array_elements(&array_addrs, ReleaseMode::NoCopyBack)? }; let array_addrs: &[i64] = &*array_addrs;通过AutoElements的
deref方法获取到对应的 切片引用&[i64]while i < results.len() { let array_ref = results.get(i).ok_or(CometError::IndexOutOfBounds(i))?; if array_ref.offset() != 0 { // https://github.com/apache/datafusion-comet/issues/2051 // Bug with non-zero offset FFI, so take to a new array which will have an offset of 0. // We expect this to be a cold code path, hence the check_bounds: true and assert_eq. let indices = UInt32Array::from((0..num_rows as u32).collect::<Vec<u32>>()); let new_array = take( array_ref, &indices, Some(TakeOptions { check_bounds: true }), )?; assert_eq!(new_array.offset(), 0); new_array .to_data() .move_to_spark(array_addrs[i], schema_addrs[i])?; } else { array_ref .to_data() .move_to_spark(array_addrs[i], schema_addrs[i])?; } i += 1; }对于每一列的值(是个数组)调用
SparkArrowConvert的move_to_spark方法,从而把ArrayData数据array_addrs赋值到array_addrs(Java中传过来的)中去:fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { let array_ptr = array as *mut FFI_ArrowArray; let schema_ptr = schema as *mut FFI_ArrowSchema; let array_align = std::mem::align_of::<FFI_ArrowArray>(); let schema_align = std::mem::align_of::<FFI_ArrowSchema>(); // Check if the pointer alignment is correct. if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 { unsafe { std::ptr::write_unaligned(array_ptr, FFI_ArrowArray::new(self)); std::ptr::write_unaligned(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?); } } else { // SAFETY: `array_ptr` and `schema_ptr` are aligned correctly. unsafe { std::ptr::write(array_ptr, FFI_ArrowArray::new(self)); std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?); } } Ok(()) }其中有几个方法:
- std::mem::align_of 是 Rust 标准库中的一个函数,用于在编译时返回类型 T 的 ABI(应用程序二进制接口)所要求的最小内存对齐方式(以字节为单位)。该数值表示该类型值在内存中存放时,其地址必须是该对齐值的倍数。它常用于确定结构体字段的布局或进行底层内存操作
- align_offset 是指针(裸指针 *const T 或 *mut T)的一个方法,用于计算从当前指针位置开始,满足特定内存对齐要求(align)所需的最小偏移量(以元素数量为单位)。它常用于需要对齐内存进行高效 SIMD 操作或优化读取的底层开发场景
- std::ptr::write_unaligned 是一个 unsafe 函数,用于将一个值(src)写入到可能未对齐的内存地址(dst),而不会读取或丢弃旧值。它允许在不需要强制内存对齐的情况下安全地写入数据
- std::ptr::write(dst, src) 是一个 unsafe 函数,用于将值 src 覆盖写入到裸指针 dst 指向的内存地址。它直接按位复制数据,不会调用旧值的 Drop 析构函数,也不会复制 src 的所有权,避免了双重释放问题·
-
Arrow的零拷贝
-
Java端 传递过来的包含
Recordbatch字节流的ByteBuffer,在Rust端直接被引用了(from_raw_parts),而不是重新拷贝一份数据// jvm 侧 val batch = nativeUtil.getNextBatch( fieldCount, (arrayAddrs, schemaAddrs) => { native.decodeShuffleBlock( dataBuf, bytesToRead.toInt, arrayAddrs, schemaAddrs, tracingEnabled) }) // Rust侧 let raw_pointer: *mut u8 = env.get_direct_buffer_address(&byte_buffer)?; let length = length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; let batch = read_ipc_compressed(slice)?; -
Java 通过 JNI 调用 Arrow C 数据接口(ArrowArray 和 ArrowSchema 结构体),将内存指针传递给Rust端,Rust端直接使用(FFI_ArrowArray 和 FFI_ArrowSchema)兼容 C 数据接口把数据复制给 Java 端,Java端无需序列化和反序列化操作就可以直接使用该数据
// jvm 侧 def allocateArrowStructs(numCols: Int): (Array[ArrowArray], Array[ArrowSchema]) = { val arrays = new Array[ArrowArray](numCols) val schemas = new Array[ArrowSchema](numCols) (0 until numCols).foreach { index => val arrowSchema = ArrowSchema.allocateNew(allocator) val arrowArray = ArrowArray.allocateNew(allocator) arrays(index) = arrowArray schemas(index) = arrowSchema } (arrays, schemas) } // Rust 侧 if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 { unsafe { std::ptr::write_unaligned(array_ptr, FFI_ArrowArray::new(self)); std::ptr::write_unaligned(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?); } } else { // SAFETY: `array_ptr` and `schema_ptr` are aligned correctly. unsafe { std::ptr::write(array_ptr, FFI_ArrowArray::new(self)); std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?); } }