高性能CPP:expression templates
表达式模板是一种C++模板元编程技术,用于在编译时构建表示计算的结构,其中表达式仅根据需要进行计算 ,以便为整个计算生成高效的代码。
假设我们平常写出这样的代码:
C++
Vector a, b, c, d;
d = a + b + c;
如果 operator+
是普通函数,那么编译器的执行方式大概是:先计算a+b
,生成一个临时变量temp1
,再计算temp1+c
,生成另外一个临时变量temp2
,最后把temp2
赋值给d
。这会导致生成多个中间对象(temp1,
temp2`)。每个中间对象都要分配内存加上一次循环遍历,所以如果向量很大,这会非常低效。
表达式模板通过模板+运算符重载,把 a + b + c
编译期重写为一个表达式树,而不是马上计算。也就是说,编译器看到的不是「先算 a+b
」,而是保留一个抽象的组合对象,表示「d 等于 a+b+c」。等到真正需要赋值给 d
的时候,表达式模板会在一次循环里把所有操作融合起来,例如:
C++
for (int i = 0; i < N; i++) {
d[i] = a[i] + b[i] + c[i]; // 一次遍历完成
}
这绕开了C++语言的常规规则,不需要按顺序计算每个+号,表达式模板可以延迟求值,直到最后赋值时才对整个表达式求值。
表达式模板常见于向量库中,一种常见的数学运算是按元素添加两个向量 u 和 v,以生成一个新向量。这个操作我们平常会使用重载+号运算符实现,它返回一个新的向量对象:
C++
using std::array;
/// @brief class representing a mathematical 3D vector
class Vec3 : public array<double, 3> {
public:
Vec3():
array<double, 3>() {}
};
/// @brief sum 'u' and 'v' into a new instance of Vec3
Vec3 operator+(Vec3 const& u, Vec3 const& v) {
Vec3 sum;
for (size_t i = 0; i < u.size(); i++) {
sum[i] = u[i] + v[i];
}
return sum;
}
此时用户可以写出Vec3 x = a + b;
这样的代码,其中a和b都是Vec3
的实例,但是上面说过,这种方法会重复生成临时对象,拷贝开销大。所以我们使用表达式模板的延迟求值,让 operator+
返回一个辅助类型的对象(例如 Vec3Sum
)来在 C++ 中实现,该对象表示两个 Vec3
的未求值总和,或具有 Vec3Sum
的向量等。然后得到的表达式可以构建表达式树,这些表达式树仅在分配给实际的Vec3
变量时才会被计算。
当你写 a + b + c
时,C++ 模板机制会把它变成一个表达式树类型(比如 VecSum<VecSum<Vec3,Vec3>,Vec3>
)。这个表达式树是一个纯类型结构,只存在于编译期,不会在运行时创建多余对象。当你执行 Vec3 x = a + b + c
,编译器会自动生成(实例化)一个 Vec3
的构造函数或赋值运算符,用来接受这种表达式对象。这个函数会在一次遍历循环中展开计算。所以只需要一次内存分配+一次循环就可以。
以下是表达式模板的实现示例:
C++
template <typename E>
class VecExpression {
public:
static constexpr bool IS_LEAF = false;
[[nodiscard]]
double operator[](size_t i) const {
// Delegation to the actual expression type. This avoids dynamic polymorphism (a.k.a. virtual functions in C++)
return static_cast<E const&>(*this)[i];
}
[[nodiscard]]
size_t size() const {
return static_cast<E const&>(*this).size();
}
};
VecExpression<E>
可以表示任何向量值表达式,他要根据实际的表达式类型E
进行实例化。写一个像 VecExpression
这样的基类并不是实现表达式模板的必需条件。它的主要作用是:当我们在写函数(比如 operator+
、构造函数)时,可以用 VecExpression
作为参数类型,以便在类型系统里"统一识别"这些表达式,而不需要为每个具体的表达式类型都写一份函数。
比如:
C++
template <typename L, typename R>
struct AddExpr {
const L& lhs;
const R& rhs;
double operator[](size_t i) const { return lhs[i] + rhs[i]; }
};
这里根本没有用基类 VecExpression
,照样可以构建表达式树:a+b→
AddExpr<Vec3, Vec3>,(a+b)+c→
AddExpr< AddExpr<Vec3,Vec3>, Vec3 >。编译器能直接通过模板推导出具体类型,所以从功能上确实不需要基类。
那这里为什么还要用呢,因为我们要简化函数签名,统一表达式类型。
没有基类的写法如下:
C++
template <typename L, typename R>
AddExpr<L,R> operator+(const L& lhs, const R& rhs) {
return {lhs, rhs};
}
这个没问题,但要求L 和 R 必须是能支持 operator[]
的类型。否则传个别的类型会编译错误,但错误信息可能很长。这里就需要SFIANE模板推导或者C++20的concept约束。但是有了基类之后,只有派生自 VecExpression
的类型(向量、表达式节点)才能传进来。如果你传了个 int
,编译器直接报"没有匹配的函数",而不是报一大堆模板展开错误。
顺便统一了接口,以后所有表达式节点、叶子节点都继承自 VecExpression
,所以函数参数就可以写得很简洁(VecExpression<E>
),不用担心 L/R 是 Vec3
、AddExpr
、还是 MulExpr
。
C++
template <typename L, typename R>
AddExpr<L,R> operator+(VecExpression<L> const& lhs, VecExpression<R> const& rhs) {
return {static_cast<L const&>(lhs), static_cast<R const&>(rhs)};
}
接下来我们回到VecExpression
基类中,观察其成员:
C++
static constexpr bool IS_LEAF = false;
当你写一个这样的表达式:
C++
Vec3 a, b, c, d;
Vec3 x = a + b + c + d;
编译器会通过表达式模板把它解析成一棵"表达式树":
markdown
+
/ \
+ d
/ \
+ c
/ \
a b
中间节点(+
)是临时的表达式对象,比如 AddExpr<L, R>
。叶子节点(a
, b
, c
, d
)是真正存放数据的向量对象 (Vec
)。在表达式模板的框架中,我们需要区分真正持有数据的 Vec
,即叶子节点。和只是表达式组合(如 AddExpr
),本身不持有数据的内部节点。所以 is_leaf
就是一个编译期布尔标志,默认为非叶子节点。
而在 Vec
里面,它继承自 VecExpression<Vec>
,并覆盖这个标志:
C++
class Vec : public VecExpression<Vec> {
public:
static constexpr bool IS_LEAF = true; // Vec 是叶子
...
private:
std::array<double, 3> data; // 真正存放向量数据
};
这样,编译器在处理表达式树时就能在编译期区分,如果是leaf的话就可以直接去所以data[i]
,如果是internal的话就递归索引。实现基本上是重写operator[]
去取值。如果节点是表达式(比如 AddExpr
),它会把请求递归下去: lhs[i] + rhs[i]
。如果节点是 叶子 (比如 Vec
),它直接访问存储的 data[i]
。
这样保证了所有节点都能用同一个接口访问 (operator[]
)。编译器在展开表达式时能区分是否需要递归。不需要运行时检查,全是编译期优化。
C++
[[nodiscard]]
double operator[](size_t i) const {
// Delegation to the actual expression type. This avoids dynamic polymorphism (a.k.a. virtual functions in C++)
return static_cast<E const&>(*this)[i];
}
这是一个下标运算符的重载,const
表示它不能修改对象。这里的关键是CRTP
,我之前的博客里面也分析过CRTP
,感兴趣的可以去看看。*this
本来的类型是 VecExpression<E>
,但实际上它是某个具体派生类(比如 Vec
、VecSum
等)的实例。我们用 static_cast<E const&>(*this)
把当前对象强制转换为"派生类的常量引用"。这样就能调用 派生类自己实现的 operator[]
。CRTP
是编译期静态解析调用的,比虚函数实现的多态性能更高,因为不需要虚表,没有运行时查表的开销。
C++
size_t size() const {
return static_cast<E const&>(*this).size();
}
这个同理,在此不过多赘述。下面我们来看表达式模板中的leaf节点:
C++
class Vec3 : public VecExpression<Vec3> {
private
array<double, 3> elems;
public:
static constexpr bool IS_LEAF = true;
[[nodiscard]]
double operator[](size_t i) const noexcept {
return elems[i];
}
double& operator[](size_t i) noexcept {
return elems[i];
}
[[nodiscard]]
size_t size() const noexcept {
return elems.size();
}
// construct Vec using initializer list
Vec3(std::initializer_list<double> init) {
std::ranges::copy(init, elems.begin());
}
// A Vec can be constructed from any VecExpression, forcing its evaluation.
template <typename E>
Vec3(VecExpression<E> const& expr) {
for (size_t i = 0; i != expr.size(); ++i) {
elems[i] = expr[i];
}
}
};
class Vec3 : public VecExpression<Vec>
通过继承基类模板实现CRTP,允许基类里的 operator[]
自动委托到派生类(前面讲过的 static_cast<E const&>(*this)
)。这里值得说的就是Vec3
的两个构造函数,第一个就是initializer_list
的构造函数,这允许使用花括号构造Vec3 v{1.0, 2.0, 3.0};
,std::ranges::copy
把初始化列表复制到 elems
里。
C++
// A Vec can be constructed from any VecExpression, forcing its evaluation.
template <typename E>
Vec3(VecExpression<E> const& expr) {
for (size_t i = 0; i != expr.size(); ++i) {
elems[i] = expr[i];
}
}
这个构造函数是叶子节点的核心,可以用任意表达式(比如 a+b
,它的类型是 VecSum<Vec3,Vec3>
)来构造一个 Vec3
。循环里会逐元素访问表达式,并把结果写入 elems
。这样,表达式模板的延迟计算就被强制求值(evaluation),结果存进了一个真正的向量。
接下来是VecSum
类,两个 Vec
的和由一个新的类型 VecSum
表示。VecSum
是一个类模板,它的模板参数是加法左右两边的类型,因此它可以应用于任意一对 Vec
表达式。重载的 operator+
运算符只是 VecSum
构造函数的一种语法糖。
C++
template <typename E1, typename E2>
class Vec3Sum : public VecExpression<Vec3Sum<E1, E2>> {
private:
// cref if leaf, copy otherwise
typename conditional<E1::is_leaf, const E1&, const E1>::type u;
typename conditional<E2::is_leaf, const E2&, const E2>::type v;
public:
static constexpr bool IS_LEAF = false;
Vec3Sum(E1 const& u, E2 const& v):
u{u}, v{v} {
assert(u.size() == v.size());
}
[[nodiscard]]
double operator[](size_t i) const noexcept {
return u[i] + v[i];
}
[[nodiscard]]
size_t size() const noexcept {
return v.size();
}
};
template <typename E1, typename E2>
Vec3Sum<E1, E2> operator+(VecExpression<E1> const& u, VecExpression<E2> const& v) {
return Vec3Sum<E1, E2>(*static_cast<const E1*>(&u), *static_cast<const E2*>(&v));
}
Vec3Sum
这个类表示两个三维向量表达式的加法结果,写 u + v
的时候,不会立刻计算,而是生成一个 Vec3Sum<E1, E2>
对象,用来延迟计算。关于: public VecExpression<Vec3Sum<E1, E2>>
这句。这一句用到了CRTP,VecExpression<T>
是一个基类模板,里面提供了统一的接口(比如 operator[]
、size()
等)。Vec3Sum<E1, E2>
继承自 VecExpression<Vec3Sum<E1, E2>>
,表示Vec3Sum
是一种 VecExpression
表达式,具体的类型就是Vec3Sum
。
C++
typename conditional<E1::is_leaf, const E1&, const E1>::type u;
typename conditional<E2::is_leaf, const E2&, const E2>::type v;
std::conditional<condition, T, F>
是标准库里的一个 条件选择模板 :如果 condition == true
,那么它的 ::type
就是 T
,如果 condition == false
,那么它的 ::type
就是 F
。这里表明了如果是Vec类型的话,就让type为引用,如果是VecSum类型的话就让type为拷贝。
为什么叶子节点就存引用,内部节点就存拷贝呢。因为叶子节点的生命周期是由用户控制的,比如用户声明一个Vec对象:Vec a(1,2,3);
,如果后续使用这个对象a进行计算的话,计算完之后a对象还是存在,所以可以安全引用。但是内部节点就不一样了:
C++
Vec a(1,2,3), b(4,5,6), c(7,8,9);
// 构建一个表达式
auto expr = a + b + c;
对于上述表达式,a + b + c
的构造过程是这样的,首先a + b
生成一个临时的Vec3Sum<Vec, Vec>
临时对象,然后这个临时对象再和 c
做加法 → 生成一个新的Vec3Sum<Vec3Sum<Vec, Vec>, Vec>
对象。第一步里的 (a+b)
是个临时变量,用完就销毁了。如果我们在第二步里只存 (a+b)
的引用,等表达式求值时,这个临时对象早就没了 → 悬空引用,未定义行为。所以内部节点只可以拷贝,否则会挂掉。
C++
static constexpr bool IS_LEAF = false;
这表示Vec3Sum
是内部节点,与上面的Vec3
相对。
C++
double operator[](size_t i) const noexcept {
return u[i] + v[i];
}
对于上述运算符重载,因为Vec3Sum
是内部节点,所以理应将左右子节点对应下标处元素相加。
接下来就是最重要的+运算符重载函数:
C++
template <typename E1, typename E2>
Vec3Sum<E1, E2> operator+(VecExpression<E1> const& u, VecExpression<E2> const& v) {
return Vec3Sum<E1, E2>(*static_cast<const E1*>(&u), *static_cast<const E2*>(&v));
}
当编译器看到 a + b
时,若 a
的类型是 Vec
,它能把 Vec
匹配到参数类型 VecExpression<E1> const&
,从而推导出 E1 = Vec
(同理 b
推导 E2
)。所以 E1
、E2
实际上就是传入对象的"真实派生类型"。return Vec3Sum<E1, E2>(...)
:返回一个延迟求值(lazy)的表达式对象,这个 Vec3Sum
是一个轻量的"表达式节点",内部保存 u
和 v
。
解释以下为什么要static_cast<const E1*>(&u)
,u
的静态类型是 const VecExpression<E1>&
(基类引用)。但 Vec3Sum
的构造需要 E1
(或 const E1&
)类型的参数 ------ 所以必须把基类引用转换回派生类类型。所以static_cast<const E1*>(&u)
的作用:是把 &u
(指向基类的指针)转为指向派生类型 E1
的指针(const E1*
)。再用 *
解引用得到 const E1&
(传给 Vec3Sum
构造函数)。
这里不需要使用dynamic_cast
,因为CRTP的保证,想了解的可以看我之前的博客。
这个+号运算符主要就是构造延迟求值的表达式对象。使用Vec3 x = a + b + c
时,会先构造出Vec对象,然后进入到上面的重载中,最终得到Vec3Sum<Vec3Sum<Vec3, Vec3>, Vec3>
对象。
下面给出示例程序,参考维基百科关于表达式模板的代码:
C++
int main() {
Vec3 v0 = {23.4, 12.5, 144.56};
Vec3 v1 = {67.12, 34.8, 90.34};
Vec3 v2 = {34.90, 111.9, 45.12};
// Following assignment will call the ctor of Vec3 which accept type of
// `VecExpression<E> const&`. Then expand the loop body to
// a.elems[i] + b.elems[i] + c.elems[i]
Vec3 sumOfVecType = v0 + v1 + v2;
for (size_t i = 0; i < sumOfVecType.size(); ++i) {
std::println("{}", sumOfVecType[i]);
}
// To avoid creating any extra storage, other than v0, v1, v2
// one can do the following (Tested with C++11 on GCC 5.3.0)
auto sum = v0 + v1 + v2;
for (size_t i = 0; i < sum.size(); ++i) {
std::println("{}", sum[i]);
}
// Observe that in this case typeid(sum) will be Vec3Sum<Vec3Sum<Vec3, Vec3>, Vec3>
// and this chaining of operations can go on.
}