Rust Zero Overhead Abstraction
相比C/C++,Rust在保留内存可见性和底层操作能力的基础上,提供了更友好的语法和类型系统支持。为保障内存安全,Rust也给开发者做了重重限制。在此我罗列了一些细节,阐明如何在所有权与生命周期的限制下,充分利用Rust的零成本抽象能力,写出最易读且高性能的代码。本文按以下条目组织:
性能评估方式
以下所有的性能评估均基于对汇编代码的分析完成。我曾试图通过benchmark来获取对比,结果发现不同写法的差别可能仅是个位数条指令,benchmark自身的波动就足以掩盖这种差别了。因此最终选择了直接分析对比汇编代码的方式。 通过如下操作安装和调用 cargo-show-asm
库就可以比较简单完成对汇编代码的分析。
arduino
cargo install cargo-show-asm
cargo asm --rust --bin ntt ntt::main
一个简单的输出样例如下。 --rust
参数会将源码混合在汇编代码中输出。注意Rust在编译过程中做了大量的inline操作,因此你会看见大量标准库或第三方库中的代码。同时由于Rust在编译时会将编译期能确定的求值直接预计算好,嵌入到编译结果中,因此建议所有的输入都通过随机值产生,避免一些比较简单的计算逻辑被直接省略。
同时对于关注的函数,可以添加 #[inline(never)]
的标记,避免其被内联,方便分析。
asm
// /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ptr/mod.rs : 1178
crate::intrinsics::read_via_copy(src)
mov esi, dword ptr [rsp + 4*rax + 152]
// /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/index_range.rs : 61
self.start = unsafe { unchecked_add(value, 1) };
mov qword ptr [rsp + 968], rdx
// /home/winkar/rntt/src/main.rs : 337
!(fu >= lim || fu <= -lim || gu >= lim || gu <= -lim)
add esi, -50
xor ecx, ecx
cmp esi, -99
整型运算
Rust中不支持的整型操作
相比C/C++,Rust的类型系统更为严格。
-
无隐式类型转换 :
1u32 + 2u64
这种操作在C/C++中会触发隐式的向上提升类型转换,但在Rust中需要显式对前者通过as u64
做类型转换。这个问题很容易解决,在所有需要的地方加上
as [target_type]
就可以。但在转写C代码时要注意一个细节,有时会出现如下所示的代码:rustlet x: u64 = ...; ... let y = x as u32 as u64;
这跟直接的
y = x as u64
是不同的,会将x在u32的数据范围内截断。因此不能简化成一次as。 -
禁止无符号整型取负 :
let x = 3i32; let y = -x;
这样的操作在C/C++中是合法的,但在Rust中是非法的。无符号整型取负的实质,是求它的补码。因此我编写了如下的代码,能实现任意无符号整型取负的操作:
rust#[inline] pub fn neg<T>(x: T) -> T where T: Not<Output = T> + From<u8> + Add<T, Output = T>, { (!x) + T::from(1u8) }
也可以通过另一个方法:
0.wrapping_sub(x)
。它跟上面的代码是等价的,但使用的0需要根据x的类型做变换,比如x是u32时,就需要写成0u32.wrapping_sub(x)
。wrapping_sub的本质是模当前数据类型上限(对u32来说就是2^32)的减法运算。虽然说是取模,但实际上在汇编层面就是一条普通的忽略溢出的指令。上述两种方法开优化编译成汇编后都是一样的,在u32类型上,最终编译结果如下:
asmshr rax, 32 neg eax
-
默认进行溢出检查:Rust在dev模式下,会对所有的整型计算进行溢出检查,并在溢出时panic。
在Falcon的计算代码里,很多时候是将整型计算的溢出当一个feature去用的。绕过这种检查有三个方法:
-
使用wrapping计算方法
上面的wrapping_sub就是其中一个例子,它不会造成溢出。类似的还有wrapping_add, wrapping_sub
-
使用std::num::Wrapping类型
将所有用到i32的地方封装成Wrapping可以达到相同的效果。这个操作是零运行时成本的,但在源代码里会非常丑陋。
-
关闭dev模式下的溢出检查
实际上我现在就是这么做的。
ini[profile.dev] overflow-checks = false
-
slice、数组与Vec的选择
Rust中的slice &[T]
是以胖指针的形式实现的。其内部结构类似于下面的C结构体:
rust
struct slice {
void* ptr;
size_t len;
}
如果将其作为参数传递,它实际上会复制这样一个结构体。以x86_64为例,默认调用约定使用寄存器传递前两个参数,slice的长度和起始指针会被分别存放到两个寄存器中,作为两个参数传递。因此很容易发现,slice类型的大小是指针的两倍。当然,这个4~8个字节的区别对于非热点函数调用来说影响很小,几乎可以忽略。
但slice在使用过程中有一个问题需要注意:slice有强制的运行期范围检查(release下依然存在),在超出slice范围时会触发panic。因此在使用时,如果实际范围大小是确定的,几乎总是建议转换成确定大小的数组使用。这样编译器可以在编译期完成检查。
rust
fn testSlice(x: &[i32]) {
let mut s = 0;
for i in 0..200 {
s += x[i];
}
println!("{}", s);
}
rust
fn testSlice(x: &[i32]) {
let mut s = 0;
let x:[i32;100] = x.try_into().unwrap();
for i in 0..50 {
s += x[i];
}
println!("{}", s);
}
像上面这两段代码,Rust会为左边每次对slice的下标访问x[i]插入一个范围判断,对右边数组则不会。所有的判断在unwrap时已经完成了。
如果不确定大小,无法在源代码中做一个确定的转换,也可以通过assert
的方法来触发编译器的优化,达到同样的效果。几乎总是建议通过assert或者try_into来帮助编译器优化后续代码。
rust
fn testSlice(x: &[i32]) {
let mut s = 0;
assert!(x.len() > 200);
// let x = x[..200]也可以起到相同的效果
for i in 0..200 {
s += x[i];
}
println!("{}", s);
}
此外,还有个有意思的case:如果做了转换,unwrap正常,但访问的范围过大,会怎么样呢?
rust
fn testSlice(x: &[i32]) {
let mut s = 0;
let x:[i32;100] = x.try_into().unwrap();
for i in 0..200 {
s += x[i];
}
println!("{}", s);
}
答案是:编译器生成的汇编中会完成前一百次加和,然后在下标大于100时直接panic。
Vec的访问形式与slice类似,但其内存存储在堆上,且具有动态扩容的能力。其结构类似下面的格式
rust
struct Vec {
void* ptr;
size_t len;
size_t capacity;
}
比slice多了一个字段。因此建议如果没有动态扩容的需求,尽量传递slice而非vec。而如果有在堆上分配内存的需求,可以考虑Box<[u32]> ,它只占一个指针大小。但注意因为Rust不提供safe的内存分配API,所以当其长度不确定时,你实际上需要通过vec.into_boxed_slice
API来获取这个Box对象。当然,还是跟前面说的一样,这点传参的代价差别很小。
内存复用
在Falcon的实现中,为了避免分配内存,分配了很长的一段内存作为内存池,所有的中间结果都在这一段内存上存储。
在Rust中,这个写法就与所有权和类型规则产生了冲突:
-
同时持有内存池中不同偏移的引用(比如两个内存池中存储的向量相加)
这个问题非常简单,使用
split
和split_mut
API即可。其底层通过unsafe的API实现,但我们无需关心其中细节------它能将一个slice拆成两个引用返回给我们。所消耗的代价仅仅是一次额外的长度检查而已。而对于slice内部不同引用的操作(比如交换),Rust也提供了一系列封装好的方法,比如
slice::swap
slice::copy_within
等等。它们可以帮助我们实现一些在C/C++中很简单但在Rust中很困难的操作。 -
同一段内存以不同的类型引用(先以[u32]的形式访问,再以[u64]的形式访问)
对此,我给出了一些尝试性的解决方案。不过总的来说,这个问题在safe Rust中暂时可能是无解的:
-
仅访问不修改,[u64]→[u8]或[u8]→[u64]的情形
不局限于u64,其它整数类型亦可,但转换的另一端一定得是u8(或有符号的i8)。
将slice用struct做封装,[u8;8]→[u64]可以通过
from_ne_bytes
实现安全的转换,反之则可以通过to_ne_bytes
实现。将所有加减乘除操作封装后,就可以实现像普通数字一样的简单运算访问。而且由于Rust的优化,这一层封装几乎是无代价的,仅在数据初始化时会多一次长度检查。缺陷:首先是此封装类型具有一定的传染性,需要在上下文做显式的转换。其次由于Rust不提供赋值的重载,所以这个写法在需要修改对应slice时会非常复杂。
-
unsafe transmute
类似于C++中的reinterpret_cast,这应该是最简单也性能代价最低的解决方案。但它使用了unsafe API,让我们的代码不纯洁了。
-
复制一份,计算完成后再复制回去
额外增加两轮复制和一些内存分配(由于流程中很多时候需要的内存大小不固定,vec处理更简单)的操作,性能消耗很高,但写起来简单,而且它是safe的。我最终选择的就是这个方案。
-
分别为不同类型开设不同的内存池
我没实现过这个方案,它表面上看很美,但在整个过程中记录不同类型内存池的偏移并在层层调用中传递这个信息,可能会引发一些未预期的问题------而且这看起来很麻烦。
-
类型封装与操作符重载
Rust提供了struct Fpr(u64)
这种形式的匿名成员定义。它为我们提供了一个非常好用的零成本抽象手段。
在Falcon的实现中,它大量使用了Fpr(以整型存储的浮点值)类型,其加减乘除均和普通整数不同。用上面的类型封装后,再去重载所有相关的操作符,实现成员函数,就能在调用特定运算操作的同时,代码像简单的整数运算一样整洁。
截取一个Falcon中的片段说明这个问题:
rust
pub fn fpc_div(a_re: Fpr, a_im: Fpr, b_re: Fpr, b_im: Fpr) -> (Fpr, Fpr) {
let m = b_re.sqr() + b_im.sqr();
let m_inv: Fpr = m.inv();
let b_re_scaled = b_re * m_inv;
let b_im_scaled = b_im.neg() * m_inv;
let d_re = a_re * b_re_scaled - a_im * b_im_scaled;
let d_im = a_re * b_im_scaled + a_im * b_re_scaled;
(d_re, d_im)
}
对应的C代码如下所示。可以发现用操作符重载和成员函数重写后,代码的复杂程度大大降低。当然,实际上我还可以把fpc这个类型也做相同的处理。如此抽象迭代之后,可以大大提高可读性。
c
#define FPC_DIV(d_re, d_im, a_re, a_im, b_re, b_im) do { \
fpr fpct_a_re, fpct_a_im; \
fpr fpct_b_re, fpct_b_im; \
fpr fpct_d_re, fpct_d_im; \
fpr fpct_m; \
fpct_a_re = (a_re); \
fpct_a_im = (a_im); \
fpct_b_re = (b_re); \
fpct_b_im = (b_im); \
fpct_m = fpr_add(fpr_sqr(fpct_b_re), fpr_sqr(fpct_b_im)); \
fpct_m = fpr_inv(fpct_m); \
fpct_b_re = fpr_mul(fpct_b_re, fpct_m); \
fpct_b_im = fpr_mul(fpr_neg(fpct_b_im), fpct_m); \
fpct_d_re = fpr_sub( \
fpr_mul(fpct_a_re, fpct_b_re), \
fpr_mul(fpct_a_im, fpct_b_im)); \
fpct_d_im = fpr_add( \
fpr_mul(fpct_a_re, fpct_b_im), \
fpr_mul(fpct_a_im, fpct_b_re)); \
(d_re) = fpct_d_re; \
(d_im) = fpct_d_im; \
} while (0)
这些抽象并不会引入额外的代价。该调函数调函数,该内联就内联,与原本代码的性能完全一致。
除了上述例子之外,多项式等类型也可以做相同的处理。不过也有类型无法做这样的处理:我一直想对环上的modq运算做类似的类型构建,但额外存储一个q作为成员代价似乎过高,不存储q的话又不太匹配通用的加减乘除trait,代码无法简化。
迭代器
Rust中提供了丰富的迭代器语义,所有slice都可以转换成迭代器Iter,在其基础上进行takewhile, map, fold等操作。由于Rust中默认不提供C sytle for语句,很多时候可以通过迭代器写出语义更清晰的循环。但在使用时要注意,迭代器作为一种抽象,会引入一些额外的代价。
例如 x.iter.step_by(p).take(q)
这样的语句,它表面上看可能跟 for (int i=0, a=x; i<q; i++, a+=p)
这样的循环是等价的,但实际上step_by和take都会插入额外的判断语句。其中 step_by
会检查p是否为0(这只是单次检查,代价较小),而 take
会在每次迭代后检查是否超出长度限制(代价相对更大)。
避免这种检查的方式与上文相同:将x转换为固定长度的数组类型,或在前面插入对x.len()的assert判断,帮助编译器优化迭代内部的长度校验。当我们给编译器提供了充足信息时,它才能将迭代器优化为抽象代价最低的代码------Rust在默认的release级别上,会将简单的循环运算(迭代器也同样如此)做循环展开,并做SSE SIMD优化(下面有一个例子)。
迭代器API中的长度检查(step_by,take等)可以通过传入常量(比如 take(2)
)或者在循环之前对对应边界做 assertion
来规避。
在迭代两个slice时,有个常用的迭代器函数是 zip
。在数组上使用这个函数时,有个容易误用的点:
javascript
let x = [0u32;56];
let y = x;
for (&i, &j) in x.iter().zip(y) {
//...
}
这个循环的写法是合法的,能通过编译,也能输出符合预期的结果。但它有一个小问题:会在迭代的时候对y进行一次不必要的memcpy。
原因很简单:我们在zip(y)时直接传值,而数组作为一个实现了Copy trait的类型,不会被move,只会被值传递,因此这里就插入了一次额外的memcpy。
要避免这次复制也很简单:将迭代的代码改为 x.iter().zip(y.iter())
即可。
另外注意一点:Rust对于迭代器的优化好于直接循环。
scss
pub fn poly_add(a: &mut [Fpr], b: &[Fpr], logn: u32) {
let n = 1usize << logn;
assert!(a.len() >= n && b.len() >= n);
// Compare to implementation: below,
// iterator based code will not generate bound-check-and-panic code.
// for u in 0..n {
// a[u] = a[u] + b[u];
// }
for (ax, bx) in a.iter_mut().zip(b.iter()).take(n) {
*ax = *ax + *bx;
}
}
可以参考上面这组对比的代码,在语义上,毫无疑问我们已经通过assertion约束了循环中绝不会出现越界的情况,但实际上编译器仍然会为循环生成边界检查的代码。要避免这种情况,就可以使用下面的迭代器写法,虽然代码更难读了一些,但如此生成的代码性能更佳。
另外无论是哪一种写法,在启用了AVX(具体说明见SIMD节)之后,得到的核心汇编都如下:
asm
// /home/winkar/pqc-rust/src/falcon/fft.rs : 308
*ax = *ax + *bx;
vmovupd ymm0, ymmword ptr [rdi + 8*rsi]
vmovupd ymm1, ymmword ptr [rdi + 8*rsi + 32]
vmovupd ymm2, ymmword ptr [rdi + 8*rsi + 64]
vmovupd ymm3, ymmword ptr [rdi + 8*rsi + 96]
// /home/winkar/pqc-rust/src/falcon/fpr.rs : 443
Fpr((x + y).to_bits())
vaddpd ymm0, ymm0, ymmword ptr [rdx + 8*rsi]
vaddpd ymm1, ymm1, ymmword ptr [rdx + 8*rsi + 32]
vaddpd ymm2, ymm2, ymmword ptr [rdx + 8*rsi + 64]
vaddpd ymm3, ymm3, ymmword ptr [rdx + 8*rsi + 96]
// /home/winkar/pqc-rust/src/falcon/fft.rs : 308
*ax = *ax + *bx;
vmovupd ymmword ptr [rdi + 8*rsi], ymm0
vmovupd ymmword ptr [rdi + 8*rsi + 32], ymm1
vmovupd ymmword ptr [rdi + 8*rsi + 64], ymm2
vmovupd ymmword ptr [rdi + 8*rsi + 96], ymm3
可以看出是进行了循环展开又用SIMD做了向量化的高度优化的代码------而我们只需要写最原始的逻辑,循环展开和向量化的优化都由编译器自动完成。
基于宏和模板的代码复用
Rust中提供了泛型和宏的能力。虽然Rust的泛型自带了C++20才加入的concept支持,但很可惜stable Rust到现在还不支持const generic expr,导致我们无法在密码算法中简单地通过泛型参数来实现不同安全等级的版本。
上述需求通过泛型虽然无法实现,但通过宏可以非常简单地做到。一个样例类似于此:
rust
#[macro_export]
macro_rules! define_falcon_keypair {
($logn: expr, $sk_bytes: expr, $pk_bytes: expr) => {
// Generate keypair, return (sk, pk)
// ### Example
// ```
// # fn main() {
// let (sk, pk) = keypair();
// # }
// ```
pub fn keypair() -> (Vec<u8>, Vec<u8>) {
let mut pk = [0u8; $pk_bytes];
let mut sk = [0u8; $sk_bytes];
let mut seed = [0u8; SEED_BYTES];
// ...
}
}
}
但这个做法的缺点也是显然的:目前Rust-Analyzer对宏的解析支持不好,宏中无法进行代码跳转。因此仅建议在代码实现调试完成后再做宏封装。不然会非常影响开发效率。
另外当其中涉及到一些表达式较为复杂的常量参数,不适合直接作为单个参数传递时,也可以使用const function进行编译期求值:
rust
pub const fn mkn(logn: u32) -> u32 {
1 << logn
}
// 可以直接将mkn的eval结果作为常量使用,且mkn也仍然可以正确作用于变量
let mut logn = 10;
const logn_2 = 9;
let buffer = [0u32; mkn(logn_2) as usize];
let n = mkn(logn);
// 上述语句都是合法的。
SIMD
SIMD在计算密集型任务中非常常用。由于我自身目前的应用场景,在此暂时不讨论异构设备(比如GPU)上的SIMD,只讨论针对CPU的优化。使用合适的CPU指令,可以在一个指令周期中完成多次加法、乘法计算,对计算密集型程序有较大的增益。
目前主流架构较新型号的CPU均提供了SIMD指令集,在x86/x86_64上有SSE*(Streaming SIMD Extension)和AVX*(Advanced Vector Extension)系列指令,在ARM上有NEON,SVE*(Scalable Vector Extension)指令集,另外在MIPS、RISC-V上也均有对应的SIMD指令集。这些指令集提供的原语和接口各有不同,需要在开发时做针对性的优化。
在正式讨论SIMD代码的编写之前,我们需要先知道如何在Rust中启用SIMD。仍然以x86_64为例,可以通过 rustc --print=cfg | grep "target_feature"
查看默认启用的优化,在我的CPU上,可以看到默认只启用了SSE的优化,没有启用支持更大向量的AVX指令集。
ini
target_feature="fxsr"
target_feature="sse"
target_feature="sse2"
要启用AVX指令集,可以在编译之前设置RUSTFLAGS
环境变量,export RUSTFLAGS="-C target-feature=+avx2,+fma"
,如此便会在编译产物中自动引入AVX指令(fma指令是浮点乘加指令,是一个补充,可以不开启)。
感兴趣的读者可以使用rustc --print=target-features
查看编译器和当前CPU支持的所有拓展。
在启用上述编译选项后,无论是否调用了专用的SIMD函数,编译器都会为合适的语句生成AVX指令版的优化代码。
为便于编译器生成更合适的SIMD代码,可以参考这篇文章,揣摩如何写出编译器友好的代码------最简单的思路是:对我们的循环计算做展开,展开后的代码往往能更好地被向量化。
当然,相当多的时候,编译器并无法很好地自动对我们的代码进行向量化,这时候就需要我们手动调用对应的指令集进行运算。
这并不需要我们去在Rust当中嵌入汇编指令------虽然这确实可以做到------只需要调用std::arch::x86_64
(或者其它目标架构)中封装好的函数,比如_mm256_set1_pd
,_mm256_mul_pd
等等。即可用纯Rust代码实现向量化的计算。可惜的是,上述代码都是unsafe的。Rust unstable中提供了std::simd
库,它目前以portable_simd
的名字在github上发布。用它可以实现safe的向量化计算。但这个库当前还无法使用。如果确实有需要,也可以使用 wide
替代。不过无论是 portable_simd
还是wide
,它们提供的都是通用的向量化API,可能无法实现一些只有特定架构上的SIMD指令才能实现的操作。