高性能现代CPP--表达式模板(expression templates)

高性能CPP:expression templates

表达式模板是一种C++模板元编程技术,用于在编译时构建表示计算的结构,其中表达式仅根据需要进行计算 ,以便为整个计算生成高效的代码。

假设我们平常写出这样的代码:

C++ 复制代码
Vector a, b, c, d;
d = a + b + c;

如果 operator+ 是普通函数,那么编译器的执行方式大概是:先计算a+b,生成一个临时变量temp1,再计算temp1+c,生成另外一个临时变量temp2,最后把temp2赋值给d。这会导致生成多个中间对象(temp1, temp2`)。每个中间对象都要分配内存加上一次循环遍历,所以如果向量很大,这会非常低效。

表达式模板通过模板+运算符重载,把 a + b + c编译期重写为一个表达式树,而不是马上计算。也就是说,编译器看到的不是「先算 a+b」,而是保留一个抽象的组合对象,表示「d 等于 a+b+c」。等到真正需要赋值给 d 的时候,表达式模板会在一次循环里把所有操作融合起来,例如:

C++ 复制代码
for (int i = 0; i < N; i++) {
    d[i] = a[i] + b[i] + c[i]; // 一次遍历完成
}

这绕开了C++语言的常规规则,不需要按顺序计算每个+号,表达式模板可以延迟求值,直到最后赋值时才对整个表达式求值。

表达式模板常见于向量库中,一种常见的数学运算是按元素添加两个向量 uv,以生成一个新向量。这个操作我们平常会使用重载+号运算符实现,它返回一个新的向量对象:

C++ 复制代码
using std::array;

/// @brief class representing a mathematical 3D vector
class Vec3 : public array<double, 3> {
public:
    Vec3():
        array<double, 3>() {}
};


/// @brief sum 'u' and 'v' into a new instance of Vec3 
Vec3 operator+(Vec3 const& u, Vec3 const& v) {
    Vec3 sum;
    for (size_t i = 0; i < u.size(); i++) {
        sum[i] = u[i] + v[i];
    }
    return sum;
}

此时用户可以写出Vec3 x = a + b;这样的代码,其中a和b都是Vec3的实例,但是上面说过,这种方法会重复生成临时对象,拷贝开销大。所以我们使用表达式模板的延迟求值,让 operator+ 返回一个辅助类型的对象(例如 Vec3Sum)来在 C++ 中实现,该对象表示两个 Vec3 的未求值总和,或具有 Vec3Sum 的向量等。然后得到的表达式可以构建表达式树,这些表达式树仅在分配给实际的Vec3变量时才会被计算。

当你写 a + b + c 时,C++ 模板机制会把它变成一个表达式树类型(比如 VecSum<VecSum<Vec3,Vec3>,Vec3>)。这个表达式树是一个纯类型结构,只存在于编译期,不会在运行时创建多余对象。当你执行 Vec3 x = a + b + c,编译器会自动生成(实例化)一个 Vec3 的构造函数或赋值运算符,用来接受这种表达式对象。这个函数会在一次遍历循环中展开计算。所以只需要一次内存分配+一次循环就可以。

以下是表达式模板的实现示例:

C++ 复制代码
template <typename E>
class VecExpression {
public:
    static constexpr bool IS_LEAF = false;

    [[nodiscard]]
    double operator[](size_t i) const {
        // Delegation to the actual expression type. This avoids dynamic polymorphism (a.k.a. virtual functions in C++)
        return static_cast<E const&>(*this)[i];
    }

    [[nodiscard]]
    size_t size() const { 
        return static_cast<E const&>(*this).size(); 
    }
};

VecExpression<E>可以表示任何向量值表达式,他要根据实际的表达式类型E进行实例化。写一个像 VecExpression 这样的基类并不是实现表达式模板的必需条件。它的主要作用是:当我们在写函数(比如 operator+、构造函数)时,可以用 VecExpression 作为参数类型,以便在类型系统里"统一识别"这些表达式,而不需要为每个具体的表达式类型都写一份函数。

比如:

C++ 复制代码
template <typename L, typename R>
struct AddExpr {
    const L& lhs;
    const R& rhs;
    double operator[](size_t i) const { return lhs[i] + rhs[i]; }
};

这里根本没有用基类 VecExpression,照样可以构建表达式树:a+bAddExpr<Vec3, Vec3>,(a+b)+cAddExpr< AddExpr<Vec3,Vec3>, Vec3 >。编译器能直接通过模板推导出具体类型,所以从功能上确实不需要基类。

那这里为什么还要用呢,因为我们要简化函数签名,统一表达式类型。

没有基类的写法如下:

C++ 复制代码
template <typename L, typename R>
AddExpr<L,R> operator+(const L& lhs, const R& rhs) {
    return {lhs, rhs};
}

这个没问题,但要求L 和 R 必须是能支持 operator[] 的类型。否则传个别的类型会编译错误,但错误信息可能很长。这里就需要SFIANE模板推导或者C++20的concept约束。但是有了基类之后,只有派生自 VecExpression 的类型(向量、表达式节点)才能传进来。如果你传了个 int,编译器直接报"没有匹配的函数",而不是报一大堆模板展开错误。

顺便统一了接口,以后所有表达式节点、叶子节点都继承自 VecExpression,所以函数参数就可以写得很简洁(VecExpression<E>),不用担心 L/R 是 Vec3AddExpr、还是 MulExpr

C++ 复制代码
template <typename L, typename R>
AddExpr<L,R> operator+(VecExpression<L> const& lhs, VecExpression<R> const& rhs) {
    return {static_cast<L const&>(lhs), static_cast<R const&>(rhs)};
}

接下来我们回到VecExpression基类中,观察其成员:

C++ 复制代码
static constexpr bool IS_LEAF = false;

当你写一个这样的表达式:

C++ 复制代码
Vec3 a, b, c, d;
Vec3 x = a + b + c + d;

编译器会通过表达式模板把它解析成一棵"表达式树":

markdown 复制代码
          +
        /   \
       +     d
     /   \
    +     c
  /   \
 a     b

中间节点(+)是临时的表达式对象,比如 AddExpr<L, R>。叶子节点(a, b, c, d)是真正存放数据的向量对象 (Vec)。在表达式模板的框架中,我们需要区分真正持有数据的 Vec,即叶子节点。和只是表达式组合(如 AddExpr),本身不持有数据的内部节点。所以 is_leaf 就是一个编译期布尔标志,默认为非叶子节点。

而在 Vec 里面,它继承自 VecExpression<Vec>,并覆盖这个标志:

C++ 复制代码
class Vec : public VecExpression<Vec> {
public:
    static constexpr bool IS_LEAF = true;  // Vec 是叶子
    ...
private:
    std::array<double, 3> data; // 真正存放向量数据
};

这样,编译器在处理表达式树时就能在编译期区分,如果是leaf的话就可以直接去所以data[i],如果是internal的话就递归索引。实现基本上是重写operator[] 去取值。如果节点是表达式(比如 AddExpr),它会把请求递归下去: lhs[i] + rhs[i]。如果节点是 叶子 (比如 Vec),它直接访问存储的 data[i]

这样保证了所有节点都能用同一个接口访问 (operator[])。编译器在展开表达式时能区分是否需要递归。不需要运行时检查,全是编译期优化。

C++ 复制代码
[[nodiscard]]
double operator[](size_t i) const {
    // Delegation to the actual expression type. This avoids dynamic polymorphism (a.k.a. virtual functions in C++)
    return static_cast<E const&>(*this)[i];
}

这是一个下标运算符的重载,const表示它不能修改对象。这里的关键是CRTP,我之前的博客里面也分析过CRTP,感兴趣的可以去看看。*this 本来的类型是 VecExpression<E>,但实际上它是某个具体派生类(比如 VecVecSum 等)的实例。我们用 static_cast<E const&>(*this) 把当前对象强制转换为"派生类的常量引用"。这样就能调用 派生类自己实现的 operator[]CRTP是编译期静态解析调用的,比虚函数实现的多态性能更高,因为不需要虚表,没有运行时查表的开销。

C++ 复制代码
size_t size() const { 
    return static_cast<E const&>(*this).size(); 
}

这个同理,在此不过多赘述。下面我们来看表达式模板中的leaf节点:

C++ 复制代码
class Vec3 : public VecExpression<Vec3> {
    private    
        array<double, 3> elems;
    public:
    static constexpr bool IS_LEAF = true;

    [[nodiscard]]
    double operator[](size_t i) const noexcept { 
        return elems[i]; 
    }

    double& operator[](size_t i) noexcept { 
        return elems[i]; 
    }

    [[nodiscard]]
    size_t size() const noexcept { 
        return elems.size(); 
    }

    // construct Vec using initializer list 
    Vec3(std::initializer_list<double> init) {
        std::ranges::copy(init, elems.begin());
    }

    // A Vec can be constructed from any VecExpression, forcing its evaluation.
    template <typename E>
    Vec3(VecExpression<E> const& expr) {
        for (size_t i = 0; i != expr.size(); ++i) {
            elems[i] = expr[i];
        }
    }
};

class Vec3 : public VecExpression<Vec>通过继承基类模板实现CRTP,允许基类里的 operator[] 自动委托到派生类(前面讲过的 static_cast<E const&>(*this))。这里值得说的就是Vec3的两个构造函数,第一个就是initializer_list的构造函数,这允许使用花括号构造Vec3 v{1.0, 2.0, 3.0};std::ranges::copy 把初始化列表复制到 elems 里。

C++ 复制代码
// A Vec can be constructed from any VecExpression, forcing its evaluation.
template <typename E>
Vec3(VecExpression<E> const& expr) {
    for (size_t i = 0; i != expr.size(); ++i) {
        elems[i] = expr[i];
    }
}

这个构造函数是叶子节点的核心,可以用任意表达式(比如 a+b,它的类型是 VecSum<Vec3,Vec3>)来构造一个 Vec3。循环里会逐元素访问表达式,并把结果写入 elems。这样,表达式模板的延迟计算就被强制求值(evaluation),结果存进了一个真正的向量。

接下来是VecSum类,两个 Vec 的和由一个新的类型 VecSum 表示。VecSum 是一个类模板,它的模板参数是加法左右两边的类型,因此它可以应用于任意一对 Vec 表达式。重载的 operator+ 运算符只是 VecSum 构造函数的一种语法糖。

C++ 复制代码
template <typename E1, typename E2>
class Vec3Sum : public VecExpression<Vec3Sum<E1, E2>> {
private:
    // cref if leaf, copy otherwise
    typename conditional<E1::is_leaf, const E1&, const E1>::type u;
    typename conditional<E2::is_leaf, const E2&, const E2>::type v;
public:
    static constexpr bool IS_LEAF = false;

    Vec3Sum(E1 const& u, E2 const& v): 
        u{u}, v{v} {
        assert(u.size() == v.size());
    }

    [[nodiscard]]
    double operator[](size_t i) const noexcept { 
        return u[i] + v[i]; 
    }

    [[nodiscard]]
    size_t size() const noexcept { 
        return v.size(); 
    }
};
  
template <typename E1, typename E2>
Vec3Sum<E1, E2> operator+(VecExpression<E1> const& u, VecExpression<E2> const& v) {
   return Vec3Sum<E1, E2>(*static_cast<const E1*>(&u), *static_cast<const E2*>(&v));
}

Vec3Sum这个类表示两个三维向量表达式的加法结果,写 u + v 的时候,不会立刻计算,而是生成一个 Vec3Sum<E1, E2> 对象,用来延迟计算。关于: public VecExpression<Vec3Sum<E1, E2>>这句。这一句用到了CRTP,VecExpression<T> 是一个基类模板,里面提供了统一的接口(比如 operator[]size() 等)。Vec3Sum<E1, E2> 继承自 VecExpression<Vec3Sum<E1, E2>>,表示Vec3Sum 是一种 VecExpression 表达式,具体的类型就是Vec3Sum

C++ 复制代码
typename conditional<E1::is_leaf, const E1&, const E1>::type u;
typename conditional<E2::is_leaf, const E2&, const E2>::type v;

std::conditional<condition, T, F> 是标准库里的一个 条件选择模板 :如果 condition == true,那么它的 ::type 就是 T,如果 condition == false,那么它的 ::type 就是 F。这里表明了如果是Vec类型的话,就让type为引用,如果是VecSum类型的话就让type为拷贝。

为什么叶子节点就存引用,内部节点就存拷贝呢。因为叶子节点的生命周期是由用户控制的,比如用户声明一个Vec对象:Vec a(1,2,3);,如果后续使用这个对象a进行计算的话,计算完之后a对象还是存在,所以可以安全引用。但是内部节点就不一样了:

C++ 复制代码
Vec a(1,2,3), b(4,5,6), c(7,8,9);

// 构建一个表达式
auto expr = a + b + c;

对于上述表达式,a + b + c的构造过程是这样的,首先a + b生成一个临时的Vec3Sum<Vec, Vec>临时对象,然后这个临时对象再和 c 做加法 → 生成一个新的Vec3Sum<Vec3Sum<Vec, Vec>, Vec> 对象。第一步里的 (a+b) 是个临时变量,用完就销毁了。如果我们在第二步里只存 (a+b) 的引用,等表达式求值时,这个临时对象早就没了 → 悬空引用,未定义行为。所以内部节点只可以拷贝,否则会挂掉。

C++ 复制代码
static constexpr bool IS_LEAF = false;

这表示Vec3Sum是内部节点,与上面的Vec3相对。

C++ 复制代码
double operator[](size_t i) const noexcept { 
    return u[i] + v[i]; 
}

对于上述运算符重载,因为Vec3Sum是内部节点,所以理应将左右子节点对应下标处元素相加。

接下来就是最重要的+运算符重载函数:

C++ 复制代码
template <typename E1, typename E2>
Vec3Sum<E1, E2> operator+(VecExpression<E1> const& u, VecExpression<E2> const& v) {
    return Vec3Sum<E1, E2>(*static_cast<const E1*>(&u), *static_cast<const E2*>(&v));
}

当编译器看到 a + b 时,若 a 的类型是 Vec,它能把 Vec 匹配到参数类型 VecExpression<E1> const&,从而推导出 E1 = Vec(同理 b 推导 E2)。所以 E1E2 实际上就是传入对象的"真实派生类型"。return Vec3Sum<E1, E2>(...):返回一个延迟求值(lazy)的表达式对象,这个 Vec3Sum 是一个轻量的"表达式节点",内部保存 uv

解释以下为什么要static_cast<const E1*>(&u)u 的静态类型是 const VecExpression<E1>&(基类引用)。但 Vec3Sum 的构造需要 E1(或 const E1&)类型的参数 ------ 所以必须把基类引用转换回派生类类型。所以static_cast<const E1*>(&u) 的作用:是把 &u(指向基类的指针)转为指向派生类型 E1 的指针(const E1*)。再用 * 解引用得到 const E1&(传给 Vec3Sum 构造函数)。

这里不需要使用dynamic_cast,因为CRTP的保证,想了解的可以看我之前的博客。

这个+号运算符主要就是构造延迟求值的表达式对象。使用Vec3 x = a + b + c时,会先构造出Vec对象,然后进入到上面的重载中,最终得到Vec3Sum<Vec3Sum<Vec3, Vec3>, Vec3>对象。

下面给出示例程序,参考维基百科关于表达式模板的代码:

C++ 复制代码
int main() {
    Vec3 v0 = {23.4,  12.5,  144.56};
    Vec3 v1 = {67.12, 34.8,  90.34};
    Vec3 v2 = {34.90, 111.9, 45.12};
    
    // Following assignment will call the ctor of Vec3 which accept type of 
    // `VecExpression<E> const&`. Then expand the loop body to 
    // a.elems[i] + b.elems[i] + c.elems[i]
    Vec3 sumOfVecType = v0 + v1 + v2; 

    for (size_t i = 0; i < sumOfVecType.size(); ++i) {
        std::println("{}", sumOfVecType[i]);
    }

    // To avoid creating any extra storage, other than v0, v1, v2
    // one can do the following (Tested with C++11 on GCC 5.3.0)
    auto sum = v0 + v1 + v2;
    for (size_t i = 0; i < sum.size(); ++i) {
        std::println("{}", sum[i]);
    }
    // Observe that in this case typeid(sum) will be Vec3Sum<Vec3Sum<Vec3, Vec3>, Vec3>
    // and this chaining of operations can go on.
}
相关推荐
莹Innsane3 小时前
使用 VictoriaLogs 存储和查询服务器日志
后端
karry_k3 小时前
BlockingQueue与SynchronousQueue
后端
前端伪大叔3 小时前
第15篇:Freqtrade策略不跑、跑错、跑飞?那可能是这几个参数没配好
前端·javascript·后端
Postkarte不想说话3 小时前
使用MSF生成反弹shell
后端
golang学习记3 小时前
Go 项目目录结构最佳实践:少即是多,实用至上
后端
合作小小程序员小小店3 小时前
web开发,在线%校园,论坛,社交管理%系统,基于html,css,python,django,mysql
数据库·后端·mysql·django·web app
用户4099322502123 小时前
PostgreSQL里的PL/pgSQL到底是啥?能让SQL从“说目标”变“讲步骤”?
后端·ai编程·trae
红烧code4 小时前
【Rust GUI开发入门】编写一个本地音乐播放器(9. 制作设置面板)
开发语言·后端·rust
你三大爷4 小时前
Safepoint的秘密探寻
java·后端