Spark Datafusion Comet 向量化Rust Native--读数据

背景

Apache Datafusion Comet 是苹果公司开源的加速Spark运行的向量化项目。

本项目采用了 Spark插件化 + Protobuf + Arrow + DataFusion 架构形式

其中

  • Spark插件是 利用 SparkPlugin 插件,其中分为 DriverPlugin 和 ExecutorPlugin ,这两个插件在driver和 Executor启动的时候就会调用
  • Protobuf 是用来序列化 spark对应的表达式以及计划,用来传递给 native 引擎去执行,利用了 体积小,速度快的特性
  • Arrow 是用来 spark 和 native 引擎进行高效的数据交换(native执行的结果或者spark执行的数据结果),主要在JNI中利用Arrow IPC 列式存储以及零拷贝等特点进行进程间数据交换
  • DataFusion 主要是利用Rust native以及Arrow内存格式实现的向量化执行引擎,Spark中主要offload对应的算子到该引擎中去执行

本文基于 datafusion comet 截止到2026年1月13号的main分支的最新代码(对应的commit为 eef5f28a0727d9aef043fa2b87d6747ff68b827a)
主要分析 Rust Native 数据读取以及读取时候的零拷贝

读数据 decodeShuffleBlock

读取的主要在于 CometBlockStoreShuffleReader 这块,具体的流程参考Spark Datafusion Comet 向量化Rule--CometExecRule Shuffle分析,Spark Shuffle数据的读取都在都是由这个类来处理, 具体和JNI(rust)进行交互为部分为:

Java 侧:

复制代码
   val batch = nativeUtil.getNextBatch(
      fieldCount,
      (arrayAddrs, schemaAddrs) => {
        native.decodeShuffleBlock(
          dataBuf,
          bytesToRead.toInt,
          arrayAddrs,
          schemaAddrs,
          tracingEnabled)
      })

这里的dataBuf(DirectByteBuffer类型)数据是对应的BlockId的字节流数据:

复制代码
  @native def decodeShuffleBlock(
      shuffleBlock: ByteBuffer,
      length: Int,
      arrayAddrs: Array[Long],
      schemaAddrs: Array[Long],
      tracingEnabled: Boolean): Long

参数说明:

  • shuffleBlock
    包含一个已经 编码并压缩好的 native shuffle block,
  • length
    表示 shuffleBlock 中 有效数据的字节数
  • arrayAddrs
    存放输出 Arrow 列数据的 FFI 结构体地址
  • schemaAddrs
    存放输出列 schema(Arrow DataType)的 FFI 结构体地址
  • tracingEnabled
    用于控制是否对该 decode 过程进行 trace/log, 默认为False(通过spark.comet.tracing.enabled配置)

Rust侧:

复制代码
pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
    e: JNIEnv,
    _class: JClass,
    byte_buffer: JByteBuffer,
    length: jint,
    array_addrs: JLongArray,
    schema_addrs: JLongArray,
    tracing_enabled: jboolean,
) -> jlong {
    try_unwrap_or_throw(&e, |mut env| {
        with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || {
            let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
            let length = length as usize;
            let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
            let batch = read_ipc_compressed(slice)?;
            prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
        })
    })
}

参数中需要注意的一点是 java侧的ByteBuffer 对应Rust侧JByteBuffer.
try_unwrap_or_throw with_trace已经在 Spark Datafusion Comet 向量化Rust Native-- 数据写入 中已经说明了,不再赘述

  • 获取ByteBuffer对应的bytes切片

    复制代码
    let raw_pointer: *mut u8 = env.get_direct_buffer_address(&byte_buffer)?;//获取直接内存的起始位置
    let length = length as usize;
    let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };// 获取切片引用

    对于Rust JNI的 JNIEnv 类中的方法可以参考JNIEnv

  • 解压并解析IPC数据流为RecordBatch
    主要是通过read_ipc_compressed方式实现:

    复制代码
     pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
      match &bytes[0..4] {
          b"SNAP" => {
              let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
              let mut reader =
                  unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
              reader.next().unwrap().map_err(|e| e.into())
          }
          b"LZ4_" => {
              let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
              let mut reader: StreamReader<lz4_flex::frame::FrameDecoder<&[u8]>> =
                  unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
              reader.next().unwrap().map_err(|e| e.into())
          }
          b"ZSTD" => {
              let decoder = zstd::Decoder::new(&bytes[4..])?;
              let mut reader =
                  unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
              reader.next().unwrap().map_err(|e| e.into())
          }
          b"NONE" => {
              let mut reader =
                  unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) };
              reader.next().unwrap().map_err(|e| e.into())
          }
          other => Err(DataFusionError::Execution(format!(
              "Failed to decode batch: invalid compression codec: {other:?}"
          ))),
      }
     }
    • 根据压缩格式构造不同的FrameDecoder进行解析
      bytes[0..4]获取该codec header,之后构造 StreamReader ,且调用 reader.next().unwrap() 读取第一个 RecordBatch

    • 把 RecordBatch 输出到 JVM 的地址数组中
      主要是调用prepare_output方法(部分解释):
      通过如下方法

      复制代码
       let array_addrs = unsafe { env.get_array_elements(&array_addrs, ReleaseMode::NoCopyBack)? };
       let array_addrs: &[i64] = &*array_addrs;

      通过AutoElements的 deref 方法获取到对应的 切片引用 &[i64]

      复制代码
       while i < results.len() {
              let array_ref = results.get(i).ok_or(CometError::IndexOutOfBounds(i))?;
      
              if array_ref.offset() != 0 {
                  // https://github.com/apache/datafusion-comet/issues/2051
                  // Bug with non-zero offset FFI, so take to a new array which will have an offset of 0.
                  // We expect this to be a cold code path, hence the check_bounds: true and assert_eq.
                  let indices = UInt32Array::from((0..num_rows as u32).collect::<Vec<u32>>());
                  let new_array = take(
                      array_ref,
                      &indices,
                      Some(TakeOptions { check_bounds: true }),
                  )?;
      
                  assert_eq!(new_array.offset(), 0);
      
                  new_array
                      .to_data()
                      .move_to_spark(array_addrs[i], schema_addrs[i])?;
              } else {
                  array_ref
                      .to_data()
                      .move_to_spark(array_addrs[i], schema_addrs[i])?;
              }
              i += 1;
          }

      对于每一列的值(是个数组)调用SparkArrowConvertmove_to_spark方法,从而把ArrayData数据array_addrs赋值到array_addrs(Java中传过来的)中去:

      复制代码
       fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> {
          let array_ptr = array as *mut FFI_ArrowArray;
          let schema_ptr = schema as *mut FFI_ArrowSchema;
      
          let array_align = std::mem::align_of::<FFI_ArrowArray>();
          let schema_align = std::mem::align_of::<FFI_ArrowSchema>();
      
          // Check if the pointer alignment is correct.
          if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 {
              unsafe {
                  std::ptr::write_unaligned(array_ptr, FFI_ArrowArray::new(self));
                  std::ptr::write_unaligned(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?);
              }
          } else {
              // SAFETY: `array_ptr` and `schema_ptr` are aligned correctly.
              unsafe {
                  std::ptr::write(array_ptr, FFI_ArrowArray::new(self));
                  std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?);
              }
          }
      
          Ok(())
       }

      其中有几个方法:

      • std::mem::align_of 是 Rust 标准库中的一个函数,用于在编译时返回类型 T 的 ABI(应用程序二进制接口)所要求的最小内存对齐方式(以字节为单位)。该数值表示该类型值在内存中存放时,其地址必须是该对齐值的倍数。它常用于确定结构体字段的布局或进行底层内存操作
      • align_offset 是指针(裸指针 *const T 或 *mut T)的一个方法,用于计算从当前指针位置开始,满足特定内存对齐要求(align)所需的最小偏移量(以元素数量为单位)。它常用于需要对齐内存进行高效 SIMD 操作或优化读取的底层开发场景
      • std::ptr::write_unaligned 是一个 unsafe 函数,用于将一个值(src)写入到可能未对齐的内存地址(dst),而不会读取或丢弃旧值。它允许在不需要强制内存对齐的情况下安全地写入数据
      • std::ptr::write(dst, src) 是一个 unsafe 函数,用于将值 src 覆盖写入到裸指针 dst 指向的内存地址。它直接按位复制数据,不会调用旧值的 Drop 析构函数,也不会复制 src 的所有权,避免了双重释放问题·

Arrow的零拷贝

  1. Java端 传递过来的包含 Recordbatch字节流的ByteBuffer,在Rust端直接被引用了(from_raw_parts),而不是重新拷贝一份数据

    复制代码
      // jvm 侧
      val batch = nativeUtil.getNextBatch(
       fieldCount,
       (arrayAddrs, schemaAddrs) => {
         native.decodeShuffleBlock(
           dataBuf,
           bytesToRead.toInt,
           arrayAddrs,
           schemaAddrs,
           tracingEnabled)
       })
      // Rust侧
      let raw_pointer: *mut u8 = env.get_direct_buffer_address(&byte_buffer)?;
      let length = length as usize;
      let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
      let batch = read_ipc_compressed(slice)?;
  2. Java 通过 JNI 调用 Arrow C 数据接口(ArrowArray 和 ArrowSchema 结构体),将内存指针传递给Rust端,Rust端直接使用(FFI_ArrowArray 和 FFI_ArrowSchema)兼容 C 数据接口把数据复制给 Java 端,Java端无需序列化和反序列化操作就可以直接使用该数据

    复制代码
     // jvm 侧
      def allocateArrowStructs(numCols: Int): (Array[ArrowArray], Array[ArrowSchema]) = {
      val arrays = new Array[ArrowArray](numCols)
      val schemas = new Array[ArrowSchema](numCols)
    
      (0 until numCols).foreach { index =>
        val arrowSchema = ArrowSchema.allocateNew(allocator)
        val arrowArray = ArrowArray.allocateNew(allocator)
        arrays(index) = arrowArray
        schemas(index) = arrowSchema
      }
    
      (arrays, schemas)
    }
    // Rust 侧
    if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 {
      unsafe {
          std::ptr::write_unaligned(array_ptr, FFI_ArrowArray::new(self));
          std::ptr::write_unaligned(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?);
      }
     } else {
      // SAFETY: `array_ptr` and `schema_ptr` are aligned correctly.
      unsafe {
          std::ptr::write(array_ptr, FFI_ArrowArray::new(self));
          std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?);
      }
    }
复制代码
相关推荐
gis分享者7 小时前
学习threejs,打造原生3D高斯溅落实时渲染器
spark·threejs·ply·高斯·splat·溅落·实时渲染器
看起来不那么蠢的昵称8 小时前
Apache Spark 开发与调优实战手册 (Java / Spark 2.x)
java·spark
硬汉嵌入式8 小时前
基于Rust构建的单片机Ariel RTOS,支持Cortex-M、RISC-V 和 Xtensa
单片机·rust·risc-v
看起来不那么蠢的昵称8 小时前
高性能 Spark UDF 开发手册
java·大数据·spark
亚林瓜子1 天前
AWS Glue任务中使用一个dynamic frame数据过滤另外一个dynamic frame数据
java·python·sql·spark·aws·df·py
低调滴开发1 天前
Tauri开发桌面端服务,配置指定防火墙端口
rust·tauri·桌面端·windows防火墙规则
咚为1 天前
Rust Cell使用与原理
开发语言·网络·rust
鹿衔`1 天前
Apache Spark 任务资源配置与优先级指南
python·spark
咸甜适中2 天前
rust的docx-rs库,自定义docx模版批量生成docx文档(逐行注释)
开发语言·rust·docx·docx-rs