背景
Apache Datafusion Comet 是苹果公司开源的加速Spark运行的向量化项目。
本项目采用了 Spark插件化 + Protobuf + Arrow + DataFusion 架构形式
其中
- Spark插件是 利用 SparkPlugin 插件,其中分为 DriverPlugin 和 ExecutorPlugin ,这两个插件在driver和 Executor启动的时候就会调用
- Protobuf 是用来序列化 spark对应的表达式以及计划,用来传递给 native 引擎去执行,利用了 体积小,速度快的特性
- Arrow 是用来 spark 和 native 引擎进行高效的数据交换(native执行的结果或者spark执行的数据结果),主要在JNI中利用Arrow IPC 列式存储以及零拷贝等特点进行进程间数据交换
- DataFusion 主要是利用Rust native以及Arrow内存格式实现的向量化执行引擎,Spark中主要offload对应的算子到该引擎中去执行
本文基于 datafusion comet 截止到2026年1月13号的main分支的最新代码(对应的commit为 eef5f28a0727d9aef043fa2b87d6747ff68b827a)
主要分析 Rust Native 数据写入这块,在之前的Spark Datafusion Comet 向量化Rule--CometExecRule Shuffle分析已经介绍过了数据写入的大概流程,这里只是来分析Rust这部分的数据写入细节
注意:
shell
对于文件的写入和读取都offload到了Rust的native这一块,为什么会让Rust进行读写呢?
是因为Rust编译为机器码,无需GC,更加接近于操作系统,读写文件性能极高,而Java依赖GC,在高并发写操作时可能触发垃圾回收导致卡顿
写数据 writeSortedFileNative
这里涉及到的写数据, 主要有两个地方会涉及到,一个是CometBypassMergeSortShuffleWriter,另一个是CometUnsafeShuffleWriter,都会调用SpillWriter.doSpilling,从而调用Native.writeSortedFileNative:
@native def writeSortedFileNative(
addresses: Array[Long],
rowSizes: Array[Int],
datatypes: Array[Array[Byte]],
file: String,
preferDictionaryRatio: Double,
batchSize: Int,
checksumEnabled: Boolean,
checksumAlgo: Int,
currentChecksum: Long,
compressionCodec: String,
compressionLevel: Int,
tracingEnabled: Boolean): Array[Long]
其中 参数详细说明:
-
addresses
Spark UnsafeRow 的内存地址数组,每个元素是一个 UnsafeRow 在内存中的起始地址(指针)
-
rowSizes
每个 Spark UnsafeRow 的大小(字节数)数组,与
addresses配合使用,用于确定每个行的数据范围 -
datatypes
序列化的数据类型数组,每个元素是一个字段的数据类型序列化结果,描述每列的数据类型,用于在转换过程中正确解析和构建Spark数据类型
-
file
指定排序后的数据写入的目标文件路径
-
preferDictionaryRatio
字典编码偏好比例阈值(默认是10,通过
spark.comet.shuffle.preferDictionary.ratio配置),决定字符串列和二进制列是否使用字典编码,如果唯一值数量 * preferDictionaryRatio < 总行数,则使用字典编码并且只有
spark.comet.exec.shuffle.mode 为JVM时才有效 -
batchSize
批处理大小,用于控制行到列转换时的缓冲区大小,默认值为
8192(通过配置spark.comet.columnar.shuffle.batch.size设置), 方法内部会循环处理数据,每次处理batchSize行,每个批次会转换为一个 Arrow RecordBatch 并写入文件 -
checksumEnabled
启用校验和计算
-
checksumAlgo
校验和算法类型,默认是Adler32 校验和算法
-
currentChecksum
当前校验和值,用于增量校验和计算
-
compressionCodec
压缩编解码器名称,默认是LZ4,通过
spar.comet.exec.shuffle.compression.codec配置 -
compressionLevel
压缩级别,仅对zstd有效
-
tracingEnabled
是否启用跟踪/日志记录,默认是false,通过
spark.comet.tracing.enabled配置 -
返回值是Array,包含[写的字节大小,校验码] 两项
通过在jvm层级将这写参数传递给Rust对应的方法
pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative.pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative(
e: JNIEnv,
_class: JClass,
row_addresses: JLongArray,
row_sizes: JIntArray,
serialized_datatypes: JObjectArray,
file_path: JString,
prefer_dictionary_ratio: jdouble,
batch_size: jlong,
checksum_enabled: jboolean,
checksum_algo: jint,
current_checksum: jlong,
compression_codec: JString,
compression_level: jint,
tracing_enabled: jboolean,
) -> jlongArray {
try_unwrap_or_throw(&e, |mut env| unsafe {
with_trace(
"writeSortedFileNative",
tracing_enabled != JNI_FALSE,
|| {
let data_types = convert_datatype_arrays(&mut env, serialized_datatypes)?;let row_num = env.get_array_length(&row_addresses)? as usize; let row_addresses = env.get_array_elements(&row_addresses, ReleaseMode::NoCopyBack)?; let row_sizes = env.get_array_elements(&row_sizes, ReleaseMode::NoCopyBack)?; let row_addresses_ptr = row_addresses.as_ptr(); let row_sizes_ptr = row_sizes.as_ptr(); let output_path: String = env.get_string(&file_path).unwrap().into(); let checksum_enabled = checksum_enabled == 1; let current_checksum = if current_checksum == i64::MIN { // Initial checksum is not available. None } else { Some(current_checksum as u32) }; let compression_codec: String = env.get_string(&compression_codec).unwrap().into(); let compression_codec = match compression_codec.as_str() { "zstd" => CompressionCodec::Zstd(compression_level), "lz4" => CompressionCodec::Lz4Frame, "snappy" => CompressionCodec::Snappy, _ => CompressionCodec::Lz4Frame, }; let (written_bytes, checksum) = process_sorted_row_partition( row_num, batch_size as usize, row_addresses_ptr, row_sizes_ptr, &data_types, output_path, prefer_dictionary_ratio, checksum_enabled, checksum_algo, current_checksum, &compression_codec, )?; let checksum = if let Some(checksum) = checksum { checksum as i64 } else { // Spark checksums (CRC32 or Adler32) are both u32, so we use i64::MIN to indicate // checksum is not available. i64::MIN }; let long_array = env.new_long_array(2)?; env.set_long_array_region(&long_array, 0, &[written_bytes, checksum])?; Ok(long_array.into_raw()) }, ) })}
首先是数据类型,其中
JLongArray ↔ long[]
JIntArray ↔ int[]
JObjectArray ↔ Object[] / byte[][]
JString ↔ String
jdouble ↔ double
jlong ↔ long
jboolean ↔ boolean
返回值 jlongArray ↔ long[]
入参前两个参数是标准 JNI 签名:
e: JNIEnv:JNI 环境句柄,用于访问 Java 对象、数组、字符串等。
_class: JClass:Java 侧 Native 类对象(静态 native 方法的第二个参数),这里不需要所以命名为 _class。
-
首先是最外层的
try_unwrap_or_throw是将运行中的Rust方法抛出的异常统一转换为 JAVA 异常,并抛出来,且返回一个默认值
注意:异常处理机制:与 Java 不同,JNI 抛出异常后,C++ 代码会继续执行。因此,建议在 Throw 后紧跟 return 语句,确保将异常交给 Java 环境处理 -
闭包函数中的
with_trace默认是不开启的,如果开启的话,会记录对应操作的开始和结束的时间,并写到对应的文件中(comet-event-trace.json)
-
数据类型转换
convert_datatype_arrays因为java那端传递过来的是byte[][],最外层的byte数组记录一个序列化的field,所以会调用:
env.get_object_array_element inner_array.into()并调用
into方法转换为JByteArray类型,并最终反序列化为 protobuf中定义的DataType类型,最终会转换为 Arrow对应的类型,这里面的数据类型,决定了如何解释每个UnsafeRow里的字段
-
row_addresses 地址的转换获取
调用如下方法:
let row_addresses = env.get_array_elements(&row_addresses, ReleaseMode::NoCopyBack)?; let row_addresses_ptr = row_addresses.as_ptr();获取UnsafeRow数据的原生指针,row_sizes也是一样。
-
对于Jstring的获取
直接调用env.get_string()...unwrap().into()方法
-
process_sorted_row_partition写入对应的数据到指定的文件路径下首先这里的数据要么是同一个分区的unsafeRow数据,或者说是已经按照Partition排序好的数据
输入: row_num 行,已经按 partition id 排好序(上游 sortRowPartitionsNative 已经完成)。 每行由 (地址指针, 行大小) 描述。 schema(data_types)描述每列类型。 output_path 指定输出文件。 prefer_dictionary_ratio 控制字典编码偏好。 checksum_*、codec 控制校验和和压缩。-
首先是 初始化 ArrayBuilder 列构建器
let mut data_builders: Vec<Box<dyn ArrayBuilder>> = vec![]; schema.iter().try_for_each(|dt| { make_builders(dt, batch_size, prefer_dictionary_ratio) .map(|builder| data_builders.push(builder))?; Ok::<(), CometError>(()) })?;
这里会根据每个类型构造不同的
Arrow ArrayBuilder:比如说对于
DataType::Utf8,DataType::Binary会选择对应的字典构建器,比如说StringDictionaryBuilder,BinaryDictionaryBuilder-
新建输出文件流
let mut output_data = OpenOptions::new() .create(true) .append(true) .open(&output_path)?; -
按 batch_size(默认是8196) 批量读取 UnsafeRow,逐行解析字段,append 到各列 builder
macro_rules! append_column_to_builder { ($builder_type:ty, $accessor:expr) => {{ let element_builder = builder .as_any_mut() .downcast_mut::<$builder_type>() .expect(stringify!($builder_type)); let mut row = SparkUnsafeRow::new(schema); for i in row_start..row_end { let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); let is_null = row.is_null_at(column_idx); if is_null { // The element value is null. // Append a null value to the element builder. element_builder.append_null(); } else { $accessor(element_builder, &row, column_idx); } } }}; }
以上是定义了一个宏,通过构造
SparkUnsafeRow还原 Spark对应的UnsafeRow,后续根据不同的类型调用不同的方法获取对应的值,再append到对应的Arrow ArrayBuilder中-
builder_to_array 对 字典编码builder进行处理
对于DataType::Utf8,DataType::Binary类型如果不满足字典编码的条件(唯一值数量 * preferDictionaryRatio < 总行数),则转换为string array类型 -
构造Arrow RecordBatch
let fields = arrays .iter() .enumerate() .map(|(i, array)| Field::new(format!("c{i}"), array.data_type().clone(), true)) .collect::<Vec<_>>(); let schema = Arc::new(Schema::new(fields)); let options = RecordBatchOptions::new().with_row_count(Option::from(row_count)); RecordBatch::try_new_with_options(schema, arrays, &options)主要就是利用Arrow自带的api构造
RecordBatch:RecordBatch::try_new_with_options(schema, arrays, &options) -
写入文件
ShuffleBlockWriter 按 Arrow IPC 格式 + 压缩写入到 output_pathlet block_writer = ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone())?; written += block_writer.write_batch(&batch, &mut cursor, &ipc_time)?; if let Some(checksum) = &mut current_checksum { checksum.update(&mut cursor)?; } output_data.write_all(&frozen)?;这里的
write_batch方法会把对应的RecordBatch 以Arrow IPC bytes的形式写到文件中,主要是调用Arrow StreamWriter对应的方法,如果有指定压缩格式的话,则会对该批次的数据进行压缩。
返回写入的字节数和校验和 -
返回对应的值给JVM端
let long_array: jni::objects::JPrimitiveArray<'_, i64> = env.new_long_array(2)?; env.set_long_array_region(&long_array, 0, &[written_bytes, checksum])?; Ok(long_array.into_raw())通过构造一个
JPrimitiveArray将对应的值返回给Java调用端。
-
其他
JNI 里其实有两套"映射"要分清楚:
Java 语言层类型 → JNI 类型(C/Rust 侧 typedef)
Java 语言层类型 → 签名字符串(method descriptor)
Java 类型 ↔ JNI 签名字符串
JNI 在 method/field ID 查找时要用"签名字符串",规则是:
- 基本类型签名
Java 类型 签名
boolean Z
byte B
char C
short S
int I
long J
float F
double D
void V - 引用类型签名
任意类 com.example.Foo:
签名:Lcom/example/Foo;
接口、java.lang.String 等同样规则:
java.lang.String → Ljava/lang/String;
java.lang.Object → Ljava/lang/Object; - 数组类型签名
一维数组:在元素类型前加一个 [:
int[] → [I
String[] → [Ljava/lang/String;
多维数组:加多个 [:
int[][] → [[I
Foo[][] → [[Lcom/example/Foo; - 方法签名
格式:(参数类型...)返回类型
例如:
int foo(long x, String s)→ 签名:(JLjava/lang/String;)Ivoid bar(int[] a, Foo[][] f)(Foo 在 com.example)→ ([I[[Lcom/example/Foo;)VNative.writeSortedFileNative(...) : long[](简化):
Java 侧:long[] writeSortedFileNative(long[] addresses, int[] sizes, byte[][] datatypes, String file, ...)
签名开头部分:([J[ I[[BLjava/lang/String;...) [J → 全串类似([J[ I[[BLjava/lang/String;DIZI...) [J(这里只是说明结构)