线性 RNN 能实现并行计算的核心,是线性递推关系满足结合律,从而可以通过 * 并行扫描(Parallel Scan)* 的分治算法,将原本 O (T) 的串行状态迭代,压缩为 O (log T) 的并行计算深度,充分利用 GPU 的并行算力加速训练。以下是完整的原理拆解:
一、先明确:线性 RNN 的数学定义
传统非线性 RNN 的状态更新包含激活函数,无法并行:
而线性 RNN (也叫线性状态空间模型)的状态转移是纯线性仿射变换,无非线性激活作用在状态传递路径上,标准形式为: 其中:
是时刻 t 的状态转移矩阵(可以是标量、对角矩阵,也可以是随输入动态变化的参数,比如 Mamba 的选择性门控)
是时刻 t 的输入项,由当前输入\(x_t\)线性投影得到
- 输出通常为
正是这个纯线性的递推结构,为并行化提供了数学基础。
二、核心前提:线性递推满足结合律
并行计算的本质是 "分而治之"------ 把长序列拆成多个子段,子段内部先并行计算,最后再合并结果。要让分治成立,运算必须满足结合律(就像加法、乘法可以并行求和 / 求积一样)。
1. 把单步操作 "打包"
我们把每个时间步的线性变换,封装成一个二元操作元 ,它的作用是: 输入前一时刻状态
,输出当前时刻状态
。
2. 定义复合算子
两个连续的操作元可以合并成一个等效的操作元。假设先执行第i步、再执行第j步,我们定义复合算子:
这个式子的物理意义很直观:
- 先做i步变换:
- 再做j步变换:
- 代入合并后:
,正好对应合并后的操作元。
3. 关键性质:算子满足结合律
对任意三个连续操作元,都有:
这意味着:无论我们怎么分组、按什么顺序合并相邻操作,最终的整体效果完全一致。结合律是所有分治并行算法的基石,也是线性 RNN 能并行计算的根本原因。
对比:传统非线性 RNN 因为有
、
等激活函数,无法写成线性仿射形式,也就定义不出满足结合律的复合算子,只能逐时刻串行计算。
三、并行扫描的具体执行过程
基于满足结合律的算子,我们可以用分治策略 计算全序列所有时刻的状态,算法分为「向上规约」和「向下传播」两个阶段,总计算深度为(T为序列长度)。
我们以长度为 8 的序列为例:
阶段 1:向上规约(Reduce)------ 自底向上合并子段
从最细的单步操作开始,逐层两两合并,得到各个区间的等效操作元:
- 第 1 层(步长 1):并行合并相邻单步,得到 4 个双步操作元:
- 第 2 层(步长 2):并行合并相邻双步,得到 2 个四步操作元:
- 第 3 层(步长 4):合并两个四步,得到整个序列的总操作元:
这个阶段只需要步,每一层的所有合并操作都可以完全并行执行。
阶段 2:向下传播(Downsweep)------ 计算每个位置的前缀结果
向上规约只得到了各个子段的内部等效操作,但我们还需要得到每个时刻 t 对应的前缀累积操作 (也就是从第 1 步到第 t 步的总操作元),才能代入初始状态算出所有
。
向下传播的逻辑是:
- 从左到右,把左侧区间的累积结果,作为 "初始偏移" 传递给右侧子段
- 右侧子段的每个位置,用左侧总操作元复合自身的前缀操作,就能得到全局的前缀结果
最终,每个时刻 t 都会得到对应的前缀操作元,代入初始状态
即可一步算出:
四、关键补充说明
-
计算量与加速效果 并行扫描的总计算量仍然是 O (T),和串行计算完全相同,但它把串行的 T 步迭代,转换成了 O (log T) 步的并行运算。在 GPU 的大规模并行算力下,墙钟时间会大幅缩短,序列越长加速效果越明显。
-
数值稳定性 线性递推理论上可以直接展开成前缀和形式(指数加权和),但长序列下会出现数值溢出(衰减系数的倒数爆炸)。而并行扫描的分治合并方式,数值稳定性远好于直接展开,是工业界实际采用的落地方案。
-
时变与时不变的兼容
- 如果
不随时间变化(比如 S4 这类时不变状态空间模型),还可以额外用 FFT 卷积加速(O (T log T));
- 如果
随输入动态变化(比如 Mamba 的选择性扫描、RWKV 的时间混合),FFT 方法不再适用,但并行扫描仍然可以高效工作,这也是它通用性更强的原因。
- 如果
-
训练与推理的差异
- 训练阶段:输入是完整序列,使用并行扫描充分利用 GPU 并行,大幅缩短训练时间;
- 推理阶段 :仍然按普通 RNN 的方式串行执行,只需要保存上一时刻的状态
,显存占用为 O (1),生成速度恒定,这是线性 RNN 对比 Transformer 的核心优势之一。
五、典型应用
这套线性 RNN + 并行扫描的架构,已经成为长序列建模的主流方向之一,代表性模型包括:
- Mamba(选择性线性 RNN,带输入依赖的动态 A/B 参数)
- RWKV(时间混合模块为对角线性 RNN)
- xLSTM(mLSTM 的矩阵记忆线性递推)
- LRU、S5、GateLoop 等线性状态空间模型