高性能现代CPP--表达式模板(expression templates)

高性能CPP:expression templates

表达式模板是一种C++模板元编程技术,用于在编译时构建表示计算的结构,其中表达式仅根据需要进行计算 ,以便为整个计算生成高效的代码。

假设我们平常写出这样的代码:

C++ 复制代码
Vector a, b, c, d;
d = a + b + c;

如果 operator+ 是普通函数,那么编译器的执行方式大概是:先计算a+b,生成一个临时变量temp1,再计算temp1+c,生成另外一个临时变量temp2,最后把temp2赋值给d。这会导致生成多个中间对象(temp1, temp2`)。每个中间对象都要分配内存加上一次循环遍历,所以如果向量很大,这会非常低效。

表达式模板通过模板+运算符重载,把 a + b + c编译期重写为一个表达式树,而不是马上计算。也就是说,编译器看到的不是「先算 a+b」,而是保留一个抽象的组合对象,表示「d 等于 a+b+c」。等到真正需要赋值给 d 的时候,表达式模板会在一次循环里把所有操作融合起来,例如:

C++ 复制代码
for (int i = 0; i < N; i++) {
    d[i] = a[i] + b[i] + c[i]; // 一次遍历完成
}

这绕开了C++语言的常规规则,不需要按顺序计算每个+号,表达式模板可以延迟求值,直到最后赋值时才对整个表达式求值。

表达式模板常见于向量库中,一种常见的数学运算是按元素添加两个向量 uv,以生成一个新向量。这个操作我们平常会使用重载+号运算符实现,它返回一个新的向量对象:

C++ 复制代码
using std::array;

/// @brief class representing a mathematical 3D vector
class Vec3 : public array<double, 3> {
public:
    Vec3():
        array<double, 3>() {}
};


/// @brief sum 'u' and 'v' into a new instance of Vec3 
Vec3 operator+(Vec3 const& u, Vec3 const& v) {
    Vec3 sum;
    for (size_t i = 0; i < u.size(); i++) {
        sum[i] = u[i] + v[i];
    }
    return sum;
}

此时用户可以写出Vec3 x = a + b;这样的代码,其中a和b都是Vec3的实例,但是上面说过,这种方法会重复生成临时对象,拷贝开销大。所以我们使用表达式模板的延迟求值,让 operator+ 返回一个辅助类型的对象(例如 Vec3Sum)来在 C++ 中实现,该对象表示两个 Vec3 的未求值总和,或具有 Vec3Sum 的向量等。然后得到的表达式可以构建表达式树,这些表达式树仅在分配给实际的Vec3变量时才会被计算。

当你写 a + b + c 时,C++ 模板机制会把它变成一个表达式树类型(比如 VecSum<VecSum<Vec3,Vec3>,Vec3>)。这个表达式树是一个纯类型结构,只存在于编译期,不会在运行时创建多余对象。当你执行 Vec3 x = a + b + c,编译器会自动生成(实例化)一个 Vec3 的构造函数或赋值运算符,用来接受这种表达式对象。这个函数会在一次遍历循环中展开计算。所以只需要一次内存分配+一次循环就可以。

以下是表达式模板的实现示例:

C++ 复制代码
template <typename E>
class VecExpression {
public:
    static constexpr bool IS_LEAF = false;

    [[nodiscard]]
    double operator[](size_t i) const {
        // Delegation to the actual expression type. This avoids dynamic polymorphism (a.k.a. virtual functions in C++)
        return static_cast<E const&>(*this)[i];
    }

    [[nodiscard]]
    size_t size() const { 
        return static_cast<E const&>(*this).size(); 
    }
};

VecExpression<E>可以表示任何向量值表达式,他要根据实际的表达式类型E进行实例化。写一个像 VecExpression 这样的基类并不是实现表达式模板的必需条件。它的主要作用是:当我们在写函数(比如 operator+、构造函数)时,可以用 VecExpression 作为参数类型,以便在类型系统里"统一识别"这些表达式,而不需要为每个具体的表达式类型都写一份函数。

比如:

C++ 复制代码
template <typename L, typename R>
struct AddExpr {
    const L& lhs;
    const R& rhs;
    double operator[](size_t i) const { return lhs[i] + rhs[i]; }
};

这里根本没有用基类 VecExpression,照样可以构建表达式树:a+bAddExpr<Vec3, Vec3>,(a+b)+cAddExpr< AddExpr<Vec3,Vec3>, Vec3 >。编译器能直接通过模板推导出具体类型,所以从功能上确实不需要基类。

那这里为什么还要用呢,因为我们要简化函数签名,统一表达式类型。

没有基类的写法如下:

C++ 复制代码
template <typename L, typename R>
AddExpr<L,R> operator+(const L& lhs, const R& rhs) {
    return {lhs, rhs};
}

这个没问题,但要求L 和 R 必须是能支持 operator[] 的类型。否则传个别的类型会编译错误,但错误信息可能很长。这里就需要SFIANE模板推导或者C++20的concept约束。但是有了基类之后,只有派生自 VecExpression 的类型(向量、表达式节点)才能传进来。如果你传了个 int,编译器直接报"没有匹配的函数",而不是报一大堆模板展开错误。

顺便统一了接口,以后所有表达式节点、叶子节点都继承自 VecExpression,所以函数参数就可以写得很简洁(VecExpression<E>),不用担心 L/R 是 Vec3AddExpr、还是 MulExpr

C++ 复制代码
template <typename L, typename R>
AddExpr<L,R> operator+(VecExpression<L> const& lhs, VecExpression<R> const& rhs) {
    return {static_cast<L const&>(lhs), static_cast<R const&>(rhs)};
}

接下来我们回到VecExpression基类中,观察其成员:

C++ 复制代码
static constexpr bool IS_LEAF = false;

当你写一个这样的表达式:

C++ 复制代码
Vec3 a, b, c, d;
Vec3 x = a + b + c + d;

编译器会通过表达式模板把它解析成一棵"表达式树":

markdown 复制代码
          +
        /   \
       +     d
     /   \
    +     c
  /   \
 a     b

中间节点(+)是临时的表达式对象,比如 AddExpr<L, R>。叶子节点(a, b, c, d)是真正存放数据的向量对象 (Vec)。在表达式模板的框架中,我们需要区分真正持有数据的 Vec,即叶子节点。和只是表达式组合(如 AddExpr),本身不持有数据的内部节点。所以 is_leaf 就是一个编译期布尔标志,默认为非叶子节点。

而在 Vec 里面,它继承自 VecExpression<Vec>,并覆盖这个标志:

C++ 复制代码
class Vec : public VecExpression<Vec> {
public:
    static constexpr bool IS_LEAF = true;  // Vec 是叶子
    ...
private:
    std::array<double, 3> data; // 真正存放向量数据
};

这样,编译器在处理表达式树时就能在编译期区分,如果是leaf的话就可以直接去所以data[i],如果是internal的话就递归索引。实现基本上是重写operator[] 去取值。如果节点是表达式(比如 AddExpr),它会把请求递归下去: lhs[i] + rhs[i]。如果节点是 叶子 (比如 Vec),它直接访问存储的 data[i]

这样保证了所有节点都能用同一个接口访问 (operator[])。编译器在展开表达式时能区分是否需要递归。不需要运行时检查,全是编译期优化。

C++ 复制代码
[[nodiscard]]
double operator[](size_t i) const {
    // Delegation to the actual expression type. This avoids dynamic polymorphism (a.k.a. virtual functions in C++)
    return static_cast<E const&>(*this)[i];
}

这是一个下标运算符的重载,const表示它不能修改对象。这里的关键是CRTP,我之前的博客里面也分析过CRTP,感兴趣的可以去看看。*this 本来的类型是 VecExpression<E>,但实际上它是某个具体派生类(比如 VecVecSum 等)的实例。我们用 static_cast<E const&>(*this) 把当前对象强制转换为"派生类的常量引用"。这样就能调用 派生类自己实现的 operator[]CRTP是编译期静态解析调用的,比虚函数实现的多态性能更高,因为不需要虚表,没有运行时查表的开销。

C++ 复制代码
size_t size() const { 
    return static_cast<E const&>(*this).size(); 
}

这个同理,在此不过多赘述。下面我们来看表达式模板中的leaf节点:

C++ 复制代码
class Vec3 : public VecExpression<Vec3> {
    private    
        array<double, 3> elems;
    public:
    static constexpr bool IS_LEAF = true;

    [[nodiscard]]
    double operator[](size_t i) const noexcept { 
        return elems[i]; 
    }

    double& operator[](size_t i) noexcept { 
        return elems[i]; 
    }

    [[nodiscard]]
    size_t size() const noexcept { 
        return elems.size(); 
    }

    // construct Vec using initializer list 
    Vec3(std::initializer_list<double> init) {
        std::ranges::copy(init, elems.begin());
    }

    // A Vec can be constructed from any VecExpression, forcing its evaluation.
    template <typename E>
    Vec3(VecExpression<E> const& expr) {
        for (size_t i = 0; i != expr.size(); ++i) {
            elems[i] = expr[i];
        }
    }
};

class Vec3 : public VecExpression<Vec>通过继承基类模板实现CRTP,允许基类里的 operator[] 自动委托到派生类(前面讲过的 static_cast<E const&>(*this))。这里值得说的就是Vec3的两个构造函数,第一个就是initializer_list的构造函数,这允许使用花括号构造Vec3 v{1.0, 2.0, 3.0};std::ranges::copy 把初始化列表复制到 elems 里。

C++ 复制代码
// A Vec can be constructed from any VecExpression, forcing its evaluation.
template <typename E>
Vec3(VecExpression<E> const& expr) {
    for (size_t i = 0; i != expr.size(); ++i) {
        elems[i] = expr[i];
    }
}

这个构造函数是叶子节点的核心,可以用任意表达式(比如 a+b,它的类型是 VecSum<Vec3,Vec3>)来构造一个 Vec3。循环里会逐元素访问表达式,并把结果写入 elems。这样,表达式模板的延迟计算就被强制求值(evaluation),结果存进了一个真正的向量。

接下来是VecSum类,两个 Vec 的和由一个新的类型 VecSum 表示。VecSum 是一个类模板,它的模板参数是加法左右两边的类型,因此它可以应用于任意一对 Vec 表达式。重载的 operator+ 运算符只是 VecSum 构造函数的一种语法糖。

C++ 复制代码
template <typename E1, typename E2>
class Vec3Sum : public VecExpression<Vec3Sum<E1, E2>> {
private:
    // cref if leaf, copy otherwise
    typename conditional<E1::is_leaf, const E1&, const E1>::type u;
    typename conditional<E2::is_leaf, const E2&, const E2>::type v;
public:
    static constexpr bool IS_LEAF = false;

    Vec3Sum(E1 const& u, E2 const& v): 
        u{u}, v{v} {
        assert(u.size() == v.size());
    }

    [[nodiscard]]
    double operator[](size_t i) const noexcept { 
        return u[i] + v[i]; 
    }

    [[nodiscard]]
    size_t size() const noexcept { 
        return v.size(); 
    }
};
  
template <typename E1, typename E2>
Vec3Sum<E1, E2> operator+(VecExpression<E1> const& u, VecExpression<E2> const& v) {
   return Vec3Sum<E1, E2>(*static_cast<const E1*>(&u), *static_cast<const E2*>(&v));
}

Vec3Sum这个类表示两个三维向量表达式的加法结果,写 u + v 的时候,不会立刻计算,而是生成一个 Vec3Sum<E1, E2> 对象,用来延迟计算。关于: public VecExpression<Vec3Sum<E1, E2>>这句。这一句用到了CRTP,VecExpression<T> 是一个基类模板,里面提供了统一的接口(比如 operator[]size() 等)。Vec3Sum<E1, E2> 继承自 VecExpression<Vec3Sum<E1, E2>>,表示Vec3Sum 是一种 VecExpression 表达式,具体的类型就是Vec3Sum

C++ 复制代码
typename conditional<E1::is_leaf, const E1&, const E1>::type u;
typename conditional<E2::is_leaf, const E2&, const E2>::type v;

std::conditional<condition, T, F> 是标准库里的一个 条件选择模板 :如果 condition == true,那么它的 ::type 就是 T,如果 condition == false,那么它的 ::type 就是 F。这里表明了如果是Vec类型的话,就让type为引用,如果是VecSum类型的话就让type为拷贝。

为什么叶子节点就存引用,内部节点就存拷贝呢。因为叶子节点的生命周期是由用户控制的,比如用户声明一个Vec对象:Vec a(1,2,3);,如果后续使用这个对象a进行计算的话,计算完之后a对象还是存在,所以可以安全引用。但是内部节点就不一样了:

C++ 复制代码
Vec a(1,2,3), b(4,5,6), c(7,8,9);

// 构建一个表达式
auto expr = a + b + c;

对于上述表达式,a + b + c的构造过程是这样的,首先a + b生成一个临时的Vec3Sum<Vec, Vec>临时对象,然后这个临时对象再和 c 做加法 → 生成一个新的Vec3Sum<Vec3Sum<Vec, Vec>, Vec> 对象。第一步里的 (a+b) 是个临时变量,用完就销毁了。如果我们在第二步里只存 (a+b) 的引用,等表达式求值时,这个临时对象早就没了 → 悬空引用,未定义行为。所以内部节点只可以拷贝,否则会挂掉。

C++ 复制代码
static constexpr bool IS_LEAF = false;

这表示Vec3Sum是内部节点,与上面的Vec3相对。

C++ 复制代码
double operator[](size_t i) const noexcept { 
    return u[i] + v[i]; 
}

对于上述运算符重载,因为Vec3Sum是内部节点,所以理应将左右子节点对应下标处元素相加。

接下来就是最重要的+运算符重载函数:

C++ 复制代码
template <typename E1, typename E2>
Vec3Sum<E1, E2> operator+(VecExpression<E1> const& u, VecExpression<E2> const& v) {
    return Vec3Sum<E1, E2>(*static_cast<const E1*>(&u), *static_cast<const E2*>(&v));
}

当编译器看到 a + b 时,若 a 的类型是 Vec,它能把 Vec 匹配到参数类型 VecExpression<E1> const&,从而推导出 E1 = Vec(同理 b 推导 E2)。所以 E1E2 实际上就是传入对象的"真实派生类型"。return Vec3Sum<E1, E2>(...):返回一个延迟求值(lazy)的表达式对象,这个 Vec3Sum 是一个轻量的"表达式节点",内部保存 uv

解释以下为什么要static_cast<const E1*>(&u)u 的静态类型是 const VecExpression<E1>&(基类引用)。但 Vec3Sum 的构造需要 E1(或 const E1&)类型的参数 ------ 所以必须把基类引用转换回派生类类型。所以static_cast<const E1*>(&u) 的作用:是把 &u(指向基类的指针)转为指向派生类型 E1 的指针(const E1*)。再用 * 解引用得到 const E1&(传给 Vec3Sum 构造函数)。

这里不需要使用dynamic_cast,因为CRTP的保证,想了解的可以看我之前的博客。

这个+号运算符主要就是构造延迟求值的表达式对象。使用Vec3 x = a + b + c时,会先构造出Vec对象,然后进入到上面的重载中,最终得到Vec3Sum<Vec3Sum<Vec3, Vec3>, Vec3>对象。

下面给出示例程序,参考维基百科关于表达式模板的代码:

C++ 复制代码
int main() {
    Vec3 v0 = {23.4,  12.5,  144.56};
    Vec3 v1 = {67.12, 34.8,  90.34};
    Vec3 v2 = {34.90, 111.9, 45.12};
    
    // Following assignment will call the ctor of Vec3 which accept type of 
    // `VecExpression<E> const&`. Then expand the loop body to 
    // a.elems[i] + b.elems[i] + c.elems[i]
    Vec3 sumOfVecType = v0 + v1 + v2; 

    for (size_t i = 0; i < sumOfVecType.size(); ++i) {
        std::println("{}", sumOfVecType[i]);
    }

    // To avoid creating any extra storage, other than v0, v1, v2
    // one can do the following (Tested with C++11 on GCC 5.3.0)
    auto sum = v0 + v1 + v2;
    for (size_t i = 0; i < sum.size(); ++i) {
        std::println("{}", sum[i]);
    }
    // Observe that in this case typeid(sum) will be Vec3Sum<Vec3Sum<Vec3, Vec3>, Vec3>
    // and this chaining of operations can go on.
}
相关推荐
Java水解3 分钟前
Java 中间件:Dubbo 服务降级(Mock 机制)
java·后端
千寻girling6 分钟前
一份不可多得的 《 Python 》语言教程
人工智能·后端·python
南风9996 分钟前
Claude code安装使用保姆级教程
后端
爱泡脚的鸡腿7 分钟前
Node.js 拓展
前端·后端
蚂蚁背大象1 小时前
Rust 所有权系统是为了解决什么问题
后端·rust
子玖3 小时前
go实现通过ip解析城市
后端·go
Java不加班3 小时前
Java 后端定时任务实现方案与工程化指南
后端
心在飞扬3 小时前
RAG 进阶检索学习笔记
后端
Moment3 小时前
想要长期陪伴你的助理?先从部署一个 OpenClaw 开始 😍😍😍
前端·后端·github