目录
- 第一部分:基础理论
- 第一章:绪论与历史脉络
- 第二章:线性时不变系统与经典状态空间模型
- 第三章:离散化理论------从连续到离散的桥梁
- 第四章:序列建模视角------卷积与递推的对偶性
- 第二部分:长程依赖与 HiPPO 理论
- 第五章:长程依赖问题------为什么 RNN 会遗忘
- 第六章:HiPPO 框架------用多项式投影逼近历史
- 第七章:从 HiPPO 到 LSSL------线性状态空间层
- 第三部分:S4 结构化参数化
- 第八章:S4 的核心思想------结构化参数化与对角化
- 第九章:S4 的高效计算------Cauchy 核与 FFT
- 第十章:S4D 与 S5------对角化变体
- 第四部分:选择性机制与 Mamba
- 第十一章:选择性状态空间------从 LTI 到输入依赖
- 第十二章:Mamba 架构------选择性扫描与硬件感知设计
- 第十三章:Mamba-2 与结构化状态空间对偶性
- 第五部分:理论分析与实践
- 第十四章:SSM 与 Transformer 的理论对比
- 第十五章:完整可运行代码实现
- 第十六章:实验、应用与未来方向
- 附录
第一部分:基础理论
第一章:绪论与历史脉络
1.1 序列建模的核心问题
序列数据无处不在:自然语言是由词组成的序列,语音是随时间变化的声波序列,视频是帧的序列,甚至 DNA 也是碱基对的序列。序列建模的核心问题是:
如何高效地捕捉序列元素之间的依赖关系,同时保持对超长序列的可扩展性?
这个问题看似简单,实则蕴含着深刻的数学挑战。序列中的依赖关系可以是:
- 局部的:相邻词之间的语法关系
- 中程的:段落内的指代消解
- 长程的:跨越数千个 token 的主题一致性
不同架构对这些依赖关系的建模能力差异巨大。
1.2 三种范式的演进
1.2.1 循环神经网络(RNN)范式
RNN 的核心思想是维护一个隐状态 hth_tht,在每个时间步通过递推更新:
ht=f(ht−1,xt)h_t = f(h_{t-1}, x_t)ht=f(ht−1,xt)
这一范式的历史可以追溯到 Elman (1990) 和 Jordan (1986) 的开创性工作。其优势在于:
- 恒定的内存占用 O(d)O(d)O(d),不随序列长度增长
- 理论上可以处理任意长度的序列
但其致命缺陷在于梯度消失与爆炸问题 (Bengio et al., 1994)。在反向传播通过时间(BPTT)中,梯度需要经过 TTT 次矩阵乘法:
∂hT∂h0=∏t=1T∂ht∂ht−1\frac{\partial h_T}{\partial h_0} = \prod_{t=1}^{T} \frac{\partial h_t}{\partial h_{t-1}}∂h0∂hT=t=1∏T∂ht−1∂ht
当 TTT 很大时,这个乘积要么指数增长(梯度爆炸),要么指数衰减(梯度消失),使得模型无法学习长程依赖。
LSTM(Hochreiter & Schmidhuber, 1997)和 GRU(Cho et al., 2014)通过门控机制部分缓解了这一问题,但并未从根本上解决。它们的本质仍是非线性递推,理论上的有效记忆长度仍然有限。
1.2.2 Transformer 范式
Transformer(Vaswani et al., 2017)彻底抛弃了递推结构,改用自注意力机制:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V
其核心优势是任意两点之间的路径长度为 O(1)O(1)O(1)------任何位置都可以直接关注任何其他位置,不存在信息衰减。
但代价是 O(n2)O(n^2)O(n2) 的时间和空间复杂度,以及 O(n2)O(n^2)O(n2) 的 KV Cache 内存占用。当序列长度 nnn 从 2K 增长到 128K 乃至 1M 时,这个二次成本变得不可接受。
1.2.3 状态空间模型(SSM)范式
SSM 范式试图在两者之间找到一个最优平衡点:
- 像 RNN 一样具有恒定的内存和线性的复杂度
- 像 Transformer 一样能够捕捉长程依赖
- 同时支持递推 和卷积两种计算模式
这一范式的理论基础来自经典控制论中的状态空间表示,但经过了从 HiPPO 到 S4 再到 Mamba 的一系列深刻改造,才成为现代序列建模的有力竞争者。
1.3 历史脉络
| 年份 | 里程碑 | 贡献 |
|---|---|---|
| 1960s | Kalman 滤波器 | 状态空间模型用于信号处理和控制 |
| 1990 | Elman Network | 简单循环网络用于序列建模 |
| 1997 | LSTM | 门控机制缓解梯度消失 |
| 2017 | Transformer | 自注意力机制 |
| 2020 | HiPPO (Gu et al.) | 用正交多项式理论初始化状态矩阵 |
| 2021 | LSSL (Gu et al.) | 将 HiPPO 与线性 SSM 结合 |
| 2022 | S4 (Gu et al.) | 结构化参数化实现高效计算 |
| 2022 | S4D (Gu et al.) | 对角化简化 |
| 2023 | Mamba (Gu & Dao) | 选择性机制 + 硬件感知扫描 |
| 2023 | Mamba-2 (Dao & Gu) | 结构化状态空间对偶性(SSD) |
| 2024 | Jamba / Mamba-2-Hybrid | SSM + Attention 混合架构 |
1.4 本文的组织
本文将从经典控制论中的状态空间模型出发,逐步构建出现代 SSM 的完整理论体系。我们不仅关注"怎么做",更关注"为什么这样做"------每一个设计选择背后的数学动机。
第二章:线性时不变系统与经典状态空间模型
2.1 从微分方程到状态空间表示
2.1.1 动机:描述动态系统
考虑一个最简单的物理系统------弹簧-阻尼器-质量系统(spring-mass-damper):
mx¨(t)+cx˙(t)+kx(t)=F(t)m\ddot{x}(t) + c\dot{x}(t) + kx(t) = F(t)mx¨(t)+cx˙(t)+kx(t)=F(t)
这是一个二阶常微分方程(ODE)。为了用一阶 ODE 系统来描述它,我们引入状态变量:
s1(t)=x(t),s2(t)=x˙(t)s_1(t) = x(t), \quad s_2(t) = \dot{x}(t)s1(t)=x(t),s2(t)=x˙(t)
则:
s˙1(t)=s2(t)\dot{s}_1(t) = s_2(t)s˙1(t)=s2(t)
s˙2(t)=−kms1(t)−cms2(t)+1mF(t)\dot{s}_2(t) = -\frac{k}{m}s_1(t) - \frac{c}{m}s_2(t) + \frac{1}{m}F(t)s˙2(t)=−mks1(t)−mcs2(t)+m1F(t)
写成矩阵形式:
ddts1s2=01−k/m−c/ms1s2+01/mF(t)\frac{d}{dt}\begin{bmatrix} s_1 \\ s_2 \end{bmatrix} = \begin{bmatrix} 0 & 1 \\ -k/m & -c/m \end{bmatrix}\begin{bmatrix} s_1 \\ s_2 \end{bmatrix} + \begin{bmatrix} 0 \\ 1/m \end{bmatrix}F(t)dtds1s2=0−k/m1−c/ms1s2+01/mF(t)
这就是状态空间表示的原型。
2.1.2 一般形式
一个连续时间线性时不变(LTI)系统的一般形式为:
h˙(t)=Ah(t)+Bx(t)(状态方程)\dot{h}(t) = Ah(t) + Bx(t) \quad \text{(状态方程)}h˙(t)=Ah(t)+Bx(t)(状态方程)
y(t)=Ch(t)+Dx(t)(观测方程)y(t) = Ch(t) + Dx(t) \quad \text{(观测方程)}y(t)=Ch(t)+Dx(t)(观测方程)
其中:
- h(t)∈RNh(t) \in \mathbb{R}^Nh(t)∈RN:隐状态(hidden state),编码了系统的所有历史信息
- x(t)∈Rx(t) \in \mathbb{R}x(t)∈R:输入(input),作用于系统的外部信号
- y(t)∈Ry(t) \in \mathbb{R}y(t)∈R:输出(output),我们能观测到的量
- A∈RN×NA \in \mathbb{R}^{N \times N}A∈RN×N:状态矩阵(state matrix),描述状态自身的演化规律
- B∈RN×1B \in \mathbb{R}^{N \times 1}B∈RN×1:输入矩阵(input matrix),描述输入如何影响状态
- C∈R1×NC \in \mathbb{R}^{1 \times N}C∈R1×N:输出矩阵(output matrix),描述如何从状态中读取输出
- D∈RD \in \mathbb{R}D∈R:直通项(feedthrough),输入对输出的直接影响
在序列建模中,通常省略直通项(D=0D = 0D=0),并加上一个残差连接:
h˙(t)=Ah(t)+Bx(t)\dot{h}(t) = Ah(t) + Bx(t)h˙(t)=Ah(t)+Bx(t)
y(t)=Ch(t)+x(t)y(t) = Ch(t) + x(t)y(t)=Ch(t)+x(t)
这样 DDD 被固定为恒等映射,确保输入信号可以直接传递到输出,而隐状态负责捕捉更高阶的动态。
2.1.3 状态空间表示的物理含义
AAA 矩阵的本质是描述隐状态之间的耦合关系。它的特征值决定了系统的固有行为:
设 AAA 的特征值为 λ1,λ2,...,λN\lambda_1, \lambda_2, \dots, \lambda_Nλ1,λ2,...,λN(可能为复数),则系统的自由响应为:
h(t)=∑i=1Ncieλitvih(t) = \sum_{i=1}^{N} c_i e^{\lambda_i t} v_ih(t)=i=1∑Ncieλitvi
其中 viv_ivi 是对应的特征向量,cic_ici 由初始条件决定。
- 若 Re(λi)<0\text{Re}(\lambda_i) < 0Re(λi)<0:对应模态指数衰减(稳定)
- 若 Re(λi)=0\text{Re}(\lambda_i) = 0Re(λi)=0:对应模态等幅振荡(临界稳定)
- 若 Re(λi)>0\text{Re}(\lambda_i) > 0Re(λi)>0:对应模态指数增长(不稳定)
对于序列建模,我们希望系统稳定(所有模态衰减),以便远处的输入信号的影响随时间自然消退,同时又不要衰减得太快,以保留长程信息。
2.2 线性时不变系统的解析解
2.2.1 矩阵指数
LTI 系统的自由响应(x(t)=0x(t) = 0x(t)=0)可以通过矩阵指数精确求解:
h(t)=eAth(0)h(t) = e^{At} h(0)h(t)=eAth(0)
其中矩阵指数定义为:
eAt=∑k=0∞(At)kk!=I+At+(At)22!+(At)33!+⋯e^{At} = \sum_{k=0}^{\infty} \frac{(At)^k}{k!} = I + At + \frac{(At)^2}{2!} + \frac{(At)^3}{3!} + \cdotseAt=k=0∑∞k!(At)k=I+At+2!(At)2+3!(At)3+⋯
当输入不为零时,解为:
h(t)=eAth(0)+∫0teA(t−τ)Bx(τ) dτh(t) = e^{At} h(0) + \int_0^t e^{A(t-\tau)} B x(\tau) \, d\tauh(t)=eAth(0)+∫0teA(t−τ)Bx(τ)dτ
这个积分称为卷积积分------状态是初始条件的衰减与输入信号的卷积之和。
2.2.2 输出的卷积形式
输出为:
y(t)=Ch(t)=CeAth(0)+∫0tCeA(t−τ)Bx(τ) dτy(t) = C h(t) = C e^{At} h(0) + \int_0^t C e^{A(t-\tau)} B x(\tau) \, d\tauy(t)=Ch(t)=CeAth(0)+∫0tCeA(t−τ)Bx(τ)dτ
如果我们从零初始条件开始(h(0)=0h(0) = 0h(0)=0),则:
y(t)=∫0tCeA(t−τ)B⏟脉冲响应 K(t−τ)x(τ) dτ=(K∗x)(t)y(t) = \int_0^t \underbrace{C e^{A(t-\tau)} B}_{\text{脉冲响应 } K(t-\tau)} x(\tau) \, d\tau = (K * x)(t)y(t)=∫0t脉冲响应 K(t−τ) CeA(t−τ)Bx(τ)dτ=(K∗x)(t)
其中 K(t)=CeAtBK(t) = Ce^{At}BK(t)=CeAtB 是系统的脉冲响应 (impulse response),也称为卷积核。
这意味着 LTI 系统完全由其脉冲响应刻画。这是一个极其深刻的结论------我们不需要知道内部状态如何演化,只需要知道脉冲响应,就能计算任意输入的输出。
2.2.3 脉冲响应的性质
对于 NNN 维 LTI 系统,脉冲响应 K(t)=CeAtBK(t) = Ce^{At}BK(t)=CeAtB 是 NNN 个指数函数的线性组合:
K(t)=∑i=1NαieλitK(t) = \sum_{i=1}^{N} \alpha_i e^{\lambda_i t}K(t)=i=1∑Nαieλit
其中 λi\lambda_iλi 是 AAA 的特征值,αi\alpha_iαi 由 A,B,CA, B, CA,B,C 共同决定。
这给了我们直觉:
- 如果 AAA 有大量不同的特征值,脉冲响应可以是任意复杂的
- 特征值的实部决定了每个模态的衰减率
- 特征值的虚部决定了振荡频率
2.3 多变量推广
在实际的序列建模中,输入和输出通常是高维的。将单变量推广到多变量:
h˙(t)=Ah(t)+Bx(t)\dot{h}(t) = Ah(t) + Bx(t)h˙(t)=Ah(t)+Bx(t)
y(t)=Ch(t)y(t) = Ch(t)y(t)=Ch(t)
其中:
- h(t)∈RNh(t) \in \mathbb{R}^Nh(t)∈RN:隐状态
- x(t)∈RDx(t) \in \mathbb{R}^Dx(t)∈RD:DDD 维输入
- y(t)∈RDy(t) \in \mathbb{R}^Dy(t)∈RD:DDD 维输出
- A∈RN×NA \in \mathbb{R}^{N \times N}A∈RN×N
- B∈RN×DB \in \mathbb{R}^{N \times D}B∈RN×D
- C∈RD×NC \in \mathbb{R}^{D \times N}C∈RD×N
在实践中,DDD 个通道通常是独立 的------每个通道有自己的 Bi∈RNB_i \in \mathbb{R}^NBi∈RN 和 Ci∈RNC_i \in \mathbb{R}^NCi∈RN,但共享同一个 AAA。这样:
yi(t)=CiTeAth(0)+∫0tCiTeA(t−τ)Bixi(τ) dτy_i(t) = C_i^T e^{At} h(0) + \int_0^t C_i^T e^{A(t-\tau)} B_i x_i(\tau) \, d\tauyi(t)=CiTeAth(0)+∫0tCiTeA(t−τ)Bixi(τ)dτ
每个通道有自己的卷积核 Ki(t)=CiTeAtBiK_i(t) = C_i^T e^{At} B_iKi(t)=CiTeAtBi,但状态转移矩阵 AAA 是共享的。这种共享确保了所有通道"看到"相同的动态特性,但可以以不同的方式与输入和输出交互。
2.4 连续时间 LTI 系统的性质总结
定理 2.1(LTI 系统的基本性质)
连续时间 LTI 系统 h˙(t)=Ah(t)+Bx(t)\dot{h}(t) = Ah(t) + Bx(t)h˙(t)=Ah(t)+Bx(t), y(t)=Ch(t)y(t) = Ch(t)y(t)=Ch(t) 具有以下性质:
-
线性性 :若输入 x1(t)x_1(t)x1(t) 产生输出 y1(t)y_1(t)y1(t),输入 x2(t)x_2(t)x2(t) 产生输出 y2(t)y_2(t)y2(t),则 αx1(t)+βx2(t)\alpha x_1(t) + \beta x_2(t)αx1(t)+βx2(t) 产生 αy1(t)+βy2(t)\alpha y_1(t) + \beta y_2(t)αy1(t)+βy2(t)。
-
时不变性 :若输入 x(t)x(t)x(t) 产生输出 y(t)y(t)y(t),则延迟输入 x(t−τ)x(t - \tau)x(t−τ) 产生延迟输出 y(t−τ)y(t - \tau)y(t−τ)。
-
因果性 :输出 y(t)y(t)y(t) 只依赖于 ttt 时刻及之前的输入。
-
卷积性 :输出等于输入与脉冲响应的卷积:y=K∗xy = K * xy=K∗x。
-
稳定性 (BIBO):当且仅当 AAA 的所有特征值实部为负时,有界输入产生有界输出。
第三章:离散化理论------从连续到离散的桥梁
3.1 为什么需要离散化?
虽然连续时间 SSM 在数学上非常优雅,但计算机处理的是离散的时间序列。我们需要将连续系统转换为等价的离散系统:
ht=Aˉht−1+Bˉxth_t = \bar{A} h_{t-1} + \bar{B} x_tht=Aˉht−1+Bˉxt
yt=Cˉhty_t = \bar{C} h_tyt=Cˉht
其中 Aˉ,Bˉ\bar{A}, \bar{B}Aˉ,Bˉ 是离散化后的参数,t=0,1,2,...t = 0, 1, 2, \dotst=0,1,2,... 是离散时间步。
关键问题 :如何从连续参数 (A,B,C)(A, B, C)(A,B,C) 推导出离散参数 (Aˉ,Bˉ,Cˉ)(\bar{A}, \bar{B}, \bar{C})(Aˉ,Bˉ,Cˉ)?
3.2 零阶保持(Zero-Order Hold, ZOH)
3.2.1 定义与推导
零阶保持是最常用的离散化方法。其基本假设是:在每个采样间隔内,输入信号保持恒定。
即对于 t∈[kΔ,(k+1)Δ)t \in [k\Delta, (k+1)\Delta)t∈[kΔ,(k+1)Δ),x(t)=xkx(t) = x_kx(t)=xk(常数)。
在这一假设下,连续状态方程在 kΔ,(k+1)Δk\\Delta, (k+1)\\DeltakΔ,(k+1)Δ 上的解为:
h((k+1)Δ)=eAΔh(kΔ)+∫kΔ(k+1)ΔeA((k+1)Δ−τ)B⋅xk dτh((k+1)\Delta) = e^{A\Delta} h(k\Delta) + \int_{k\Delta}^{(k+1)\Delta} e^{A((k+1)\Delta - \tau)} B \cdot x_k \, d\tauh((k+1)Δ)=eAΔh(kΔ)+∫kΔ(k+1)ΔeA((k+1)Δ−τ)B⋅xkdτ
令 σ=(k+1)Δ−τ\sigma = (k+1)\Delta - \tauσ=(k+1)Δ−τ,则积分变为:
∫0ΔeAσB dσ⋅xk\int_0^{\Delta} e^{A\sigma} B \, d\sigma \cdot x_k∫0ΔeAσBdσ⋅xk
因此:
Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ
Bˉ=(∫0ΔeAσ dσ)B=A−1(eAΔ−I)B=A−1(Aˉ−I)B\bar{B} = \left(\int_0^{\Delta} e^{A\sigma} \, d\sigma\right) B = A^{-1}(e^{A\Delta} - I) B = A^{-1}(\bar{A} - I) BBˉ=(∫0ΔeAσdσ)B=A−1(eAΔ−I)B=A−1(Aˉ−I)B
输出矩阵不变:Cˉ=C\bar{C} = CCˉ=C。
3.2.2 ZOH 的计算
对于小规模矩阵,可以直接用 scipy.linalg.expm 计算矩阵指数。对于大规模对角矩阵,可以逐元素计算。
引理 3.1 :当 AAA 可逆时,Bˉ=A−1(Aˉ−I)B\bar{B} = A^{-1}(\bar{A} - I)BBˉ=A−1(Aˉ−I)B。
当 AAA 不可逆时(存在零特征值),需要使用极限形式或避免求逆:
Bˉ=(∫0ΔeAσdσ)B\bar{B} = \left(\int_0^{\Delta} e^{A\sigma} d\sigma\right) BBˉ=(∫0ΔeAσdσ)B
这个积分可以通过泰勒展开或 Padé 逼近来计算。
3.2.3 ZOH 的性质
定理 3.1(ZOH 等价性):ZOH 离散化保持了以下性质:
- 稳定性 :若 AAA 的所有特征值实部为负,则 Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ 的所有特征值的模小于 1。
- 特征值映射 :Aˉ\bar{A}Aˉ 的特征值为 λˉi=eλiΔ\bar{\lambda}_i = e^{\lambda_i \Delta}λˉi=eλiΔ。
- 在采样点处精确:对于阶跃输入,ZOH 离散化的解在采样点处与连续解完全一致。
证明:
-
若 Re(λi)<0\text{Re}(\lambda_i) < 0Re(λi)<0,则 ∣eλiΔ∣=eRe(λi)Δ<1|e^{\lambda_i \Delta}| = e^{\text{Re}(\lambda_i) \Delta} < 1∣eλiΔ∣=eRe(λi)Δ<1。□\square□
-
设 Avi=λiviA v_i = \lambda_i v_iAvi=λivi,则 eAΔvi=eλiΔvie^{A\Delta} v_i = e^{\lambda_i \Delta} v_ieAΔvi=eλiΔvi。□\square□
-
ZOH 的构造保证了在 t=kΔt = k\Deltat=kΔ 处,离散解等于连续解。□\square□
3.3 双线性变换(Bilinear Transform / Tustin Method)
3.3.1 定义
双线性变换(也称为梯形法或 Tustin 方法)基于以下近似:
s=2Δ⋅z−1z+1s = \frac{2}{\Delta} \cdot \frac{z - 1}{z + 1}s=Δ2⋅z+1z−1
其中 sss 是连续域的拉普拉斯变量,zzz 是离散域的 Z 变换变量。
对于状态矩阵:
Aˉ=(I−Δ2A)−1(I+Δ2A)\bar{A} = \left(I - \frac{\Delta}{2}A\right)^{-1}\left(I + \frac{\Delta}{2}A\right)Aˉ=(I−2ΔA)−1(I+2ΔA)
Bˉ=(I−Δ2A)−1ΔB\bar{B} = \left(I - \frac{\Delta}{2}A\right)^{-1} \Delta BBˉ=(I−2ΔA)−1ΔB
3.3.2 性质
定理 3.2(双线性变换的稳定性保持):双线性变换将连续域的左半平面映射到离散域的单位圆内部。
证明:设 λ\lambdaλ 是 AAA 的特征值,Re(λ)<0\text{Re}(\lambda) < 0Re(λ)<0。则 Aˉ\bar{A}Aˉ 对应的特征值为:
λˉ=1+Δ2λ1−Δ2λ\bar{\lambda} = \frac{1 + \frac{\Delta}{2}\lambda}{1 - \frac{\Delta}{2}\lambda}λˉ=1−2Δλ1+2Δλ
令 λ=a+bi\lambda = a + biλ=a+bi,其中 a<0a < 0a<0,则:
∣λˉ∣2=(1+Δ2a)2+(Δ2b)2(1−Δ2a)2+(Δ2b)2|\bar{\lambda}|^2 = \frac{(1 + \frac{\Delta}{2}a)^2 + (\frac{\Delta}{2}b)^2}{(1 - \frac{\Delta}{2}a)^2 + (\frac{\Delta}{2}b)^2}∣λˉ∣2=(1−2Δa)2+(2Δb)2(1+2Δa)2+(2Δb)2
由于 a<0a < 0a<0,分子中 (1+Δ2a)2<(1−Δ2a)2(1 + \frac{\Delta}{2}a)^2 < (1 - \frac{\Delta}{2}a)^2(1+2Δa)2<(1−2Δa)2,因此 ∣λˉ∣<1|\bar{\lambda}| < 1∣λˉ∣<1。□\square□
3.3.3 ZOH 与双线性变换的比较
| 性质 | ZOH | 双线性变换 |
|---|---|---|
| 采样点精度 | 在采样点处精确 | 有频率畸变(warping) |
| 频率响应 | 保持相位特性 | 保持增益特性 |
| 稳定性 | 保持 | 保持 |
| 计算 | 需要矩阵指数 | 只需要矩阵求逆 |
| S4 中的使用 | 是(主要方法) | 否 |
3.4 步长参数 Δ\DeltaΔ 的含义
离散化步长 Δ\DeltaΔ 是一个极其重要的超参数。它控制着连续信号被采样的粒度:
- Δ\DeltaΔ 较大:粗粒度采样,相邻离散步之间的跨度大,模型"看"到的是信号的低频成分
- Δ\DeltaΔ 较小:细粒度采样,相邻离散步之间的跨度小,模型能捕捉信号的高频变化
在 SSM 的上下文中,Δ\DeltaΔ 的选择直接影响离散化后的 Aˉ\bar{A}Aˉ 矩阵:
Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ
- 当 Δ→0\Delta \to 0Δ→0 时,Aˉ→I\bar{A} \to IAˉ→I,系统几乎不衰减,等价于恒等映射
- 当 Δ→∞\Delta \to \inftyΔ→∞ 时,Aˉ→0\bar{A} \to 0Aˉ→0,系统完全衰减,历史信息完全丢失
合适的 Δ\DeltaΔ 应该使得系统在有意义的时间尺度上衰减------足够慢以保留长程信息,又足够快以忘去过时的信息。
在 Mamba 中 ,Δ\DeltaΔ 不再是固定超参数,而是由输入数据动态生成的------这就是选择性机制的核心之一,我们将在第十一章详细讨论。
3.5 离散化后的频域分析
3.5.1 Z 变换
对离散 LTI 系统进行 Z 变换:
H(z)=Cˉ(zI−Aˉ)−1Bˉ+DH(z) = \bar{C}(zI - \bar{A})^{-1}\bar{B} + DH(z)=Cˉ(zI−Aˉ)−1Bˉ+D
这是一个有理函数(rational function),其极点是 Aˉ\bar{A}Aˉ 的特征值,零点取决于 B,CB, CB,C。
3.5.2 频率响应
令 z=ejωz = e^{j\omega}z=ejω(单位圆上的点),得到频率响应:
H(ejω)=Cˉ(ejωI−Aˉ)−1BˉH(e^{j\omega}) = \bar{C}(e^{j\omega}I - \bar{A})^{-1}\bar{B}H(ejω)=Cˉ(ejωI−Aˉ)−1Bˉ
∣H(ejω)∣|H(e^{j\omega})|∣H(ejω)∣ 描述了系统对不同频率正弦输入的增益。在序列建模中,这对应于模型对不同频率模式的敏感度。
3.6 一维输入情况的显式公式
当输入维度为 1 时,B∈RNB \in \mathbb{R}^NB∈RN,C∈RNC \in \mathbb{R}^NC∈RN,我们可以写出更显式的公式。
设 AAA 的特征分解为 A=VΛV−1A = V \Lambda V^{-1}A=VΛV−1(假设可对角化),则:
Aˉ=VeΛΔV−1\bar{A} = V e^{\Lambda \Delta} V^{-1}Aˉ=VeΛΔV−1
Kn=CAˉnBˉ=CV(eΛΔ)nV−1Bˉ=∑i=1NC~i(eλiΔ)nB~iK_n = C \bar{A}^n \bar{B} = C V (e^{\Lambda \Delta})^n V^{-1} \bar{B} = \sum_{i=1}^{N} \tilde{C}_i (e^{\lambda_i \Delta})^n \tilde{B}_iKn=CAˉnBˉ=CV(eΛΔ)nV−1Bˉ=i=1∑NC~i(eλiΔ)nB~i
其中 C~=CV\tilde{C} = C VC~=CV, B~=V−1Bˉ\tilde{B} = V^{-1}\bar{B}B~=V−1Bˉ, C~i,B~i\tilde{C}_i, \tilde{B}_iC~i,B~i 分别是第 iii 个分量。
记 λˉi=eλiΔ\bar{\lambda}_i = e^{\lambda_i \Delta}λˉi=eλiΔ,则:
Kn=∑i=1NC~iB~iλˉinK_n = \sum_{i=1}^{N} \tilde{C}_i \tilde{B}_i \bar{\lambda}_i^nKn=i=1∑NC~iB~iλˉin
这是一个指数和核 (exponential sum kernel)------NNN 个指数衰减/振荡模态的叠加。这正是 S4 高效计算的基础。
第四章:序列建模视角------卷积与递推的对偶性
4.1 离散 SSM 的两种等价计算模式
离散 LTI-SSM:
ht=Aˉht−1+Bˉxth_t = \bar{A} h_{t-1} + \bar{B} x_tht=Aˉht−1+Bˉxt
yt=Cˉhty_t = \bar{C} h_tyt=Cˉht
有两种等价的方式来计算输出序列 {y0,y1,...,yL−1}\{y_0, y_1, \dots, y_{L-1}\}{y0,y1,...,yL−1}。
4.1.1 递推模式(Recurrence Mode)
直接按时间步递推:
ht=Aˉht−1+Bˉxt,yt=Cˉhth_t = \bar{A} h_{t-1} + \bar{B} x_t, \quad y_t = \bar{C} h_tht=Aˉht−1+Bˉxt,yt=Cˉht
复杂度分析:
- 时间:O(BLN2)O(BLN^2)O(BLN2),其中 BBB 是 batch size,LLL 是序列长度,NNN 是状态维度
- 空间:O(BN)O(BN)O(BN)(只需要存储当前隐状态)
- 特点:串行------每个时间步依赖前一个,无法并行
递推模式适合自回归推理(inference),即一个一个 token 生成的场景。
4.1.2 卷积模式(Convolution Mode)
展开递推关系:
y0=CˉBˉx0y_0 = \bar{C}\bar{B} x_0y0=CˉBˉx0
y1=CˉAˉBˉx0+CˉBˉx1y_1 = \bar{C}\bar{A}\bar{B} x_0 + \bar{C}\bar{B} x_1y1=CˉAˉBˉx0+CˉBˉx1
y2=CˉAˉ2Bˉx0+CˉAˉBˉx1+CˉBˉx2y_2 = \bar{C}\bar{A}^2\bar{B} x_0 + \bar{C}\bar{A}\bar{B} x_1 + \bar{C}\bar{B} x_2y2=CˉAˉ2Bˉx0+CˉAˉBˉx1+CˉBˉx2
⋮\vdots⋮
yn=∑k=0nCˉAˉn−kBˉxk=∑k=0nKn−kxky_n = \sum_{k=0}^{n} \bar{C}\bar{A}^{n-k}\bar{B} x_k = \sum_{k=0}^{n} K_{n-k} x_kyn=k=0∑nCˉAˉn−kBˉxk=k=0∑nKn−kxk
其中卷积核为:
Kn=CˉAˉnBˉ,n=0,1,2,...,L−1K_n = \bar{C}\bar{A}^n\bar{B}, \quad n = 0, 1, 2, \dots, L-1Kn=CˉAˉnBˉ,n=0,1,2,...,L−1
这正是我们之前从连续系统推导出的脉冲响应的离散化形式。
因此,输出是输入与卷积核 KKK 的一维卷积:
y=K∗xy = K * xy=K∗x
复杂度分析(直接卷积):
- 时间:O(BL2N)O(BL^2N)O(BL2N)(对每个输出位置,求 O(L)O(L)O(L) 项的和)
- 空间:O(BL+BL)O(BL + BL)O(BL+BL)(存储输入和卷积核)
复杂度分析(FFT 卷积):
- 时间:O(BLlogL)O(BL \log L)O(BLlogL)(利用 FFT 将卷积转化为逐元素乘法)
- 空间:O(BL)O(BL)O(BL)
- 特点:完全并行------所有输出位置可以同时计算
卷积模式适合训练(前向传播),因为训练时整个输入序列已知,可以一次性计算所有输出。
4.2 核的计算问题
卷积模式的关键瓶颈在于如何高效计算卷积核 K={K0,K1,...,KL−1}K = \{K_0, K_1, \dots, K_{L-1}\}K={K0,K1,...,KL−1}。
朴素方法:逐次计算矩阵幂
Kn=CAˉnBˉK_n = C\bar{A}^n\bar{B}Kn=CAˉnBˉ
这需要 LLL 次矩阵-向量乘法,每次 O(N2)O(N^2)O(N2),总复杂度 O(LN2)O(LN^2)O(LN2)。
问题 :当 NNN 很大时(如 N=256N = 256N=256 或更大),O(N2)O(N^2)O(N2) 的矩阵-向量乘法成为瓶颈。
这就是 S4 要解决的核心计算问题------我们将在第九章详细讨论。
4.3 卷积-递推对偶性的深层含义
定理 4.1(卷积-递推对偶性):对于任何 LTI 系统,以下两个计算过程产生完全相同的输出:
- 递推模式 :ht=Aˉht−1+Bˉxth_t = \bar{A}h_{t-1} + \bar{B}x_tht=Aˉht−1+Bˉxt, yt=Cˉhty_t = \bar{C}h_tyt=Cˉht
- 卷积模式 :y=K∗xy = K * xy=K∗x, 其中 Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ
证明:由递推展开,hn=∑k=0nAˉn−kBˉxkh_n = \sum_{k=0}^{n} \bar{A}^{n-k}\bar{B}x_khn=∑k=0nAˉn−kBˉxk(假设 h0=0h_0 = 0h0=0),因此 yn=Cˉhn=∑k=0nCˉAˉn−kBˉxk=∑k=0nKn−kxky_n = \bar{C}h_n = \sum_{k=0}^{n}\bar{C}\bar{A}^{n-k}\bar{B}x_k = \sum_{k=0}^{n}K_{n-k}x_kyn=Cˉhn=∑k=0nCˉAˉn−kBˉxk=∑k=0nKn−kxk。□\square□
这个对偶性意味着:
- 训练用卷积模式(并行,利用 GPU 的吞吐量)
- 推理用递推模式(流式,只需要常数内存)
这是 SSM 相比 Transformer 的一个结构性优势:Transformer 的训练和推理都依赖注意力机制,都需要 O(n2)O(n^2)O(n2) 的计算;而 SSM 在训练和推理中可以分别使用最高效的方式。
4.4 从卷积核到全局依赖
让我们仔细分析卷积核 Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ 的衰减行为。
如果 Aˉ\bar{A}Aˉ 的特征值都在单位圆内(∣λˉi∣<1|\bar{\lambda}_i| < 1∣λˉi∣<1),则:
Kn=∑i=1NαiλˉinK_n = \sum_{i=1}^{N} \alpha_i \bar{\lambda}_i^nKn=i=1∑Nαiλˉin
随着 n→∞n \to \inftyn→∞,Kn→0K_n \to 0Kn→0。但衰减的速度取决于最大的 ∣λˉi∣|\bar{\lambda}_i|∣λˉi∣:
- 如果 ∣λˉmax∣|\bar{\lambda}_{\max}|∣λˉmax∣ 接近 1:衰减很慢,核具有长程记忆
- 如果 ∣λˉmax∣|\bar{\lambda}_{\max}|∣λˉmax∣ 很小:衰减很快,核只能记住最近的信息
这就是 SSM 的"记忆容量"------它由 Aˉ\bar{A}Aˉ 的谱性质决定,而 Aˉ\bar{A}Aˉ 又由连续参数 AAA 和步长 Δ\DeltaΔ 共同决定。
4.5 与 Transformer 注意力核的对比
Transformer 的注意力权重可以看作一个动态卷积核:
yi=∑jαijxj,αij=exp(qiTkj/d)∑jexp(qiTkj/d)y_i = \sum_j \alpha_{ij} x_j, \quad \alpha_{ij} = \frac{\exp(q_i^T k_j / \sqrt{d})}{\sum_j \exp(q_i^T k_j / \sqrt{d})}yi=j∑αijxj,αij=∑jexp(qiTkj/d )exp(qiTkj/d )
与 SSM 的卷积核对比:
| 性质 | SSM 卷积核 | 注意力核 |
|---|---|---|
| 参数化 | Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ | αij=softmax(qiTkj)\alpha_{ij} = \text{softmax}(q_i^T k_j)αij=softmax(qiTkj) |
| 是否依赖输入 | 否(LTI 情况) | 是 |
| 核形状 | 只依赖相对位置 n−kn - kn−k | 依赖绝对位置 iii 和 jjj |
| 计算复杂度 | O(LlogL)O(L \log L)O(LlogL)(FFT) | O(L2)O(L^2)O(L2) |
| 全局依赖 | 通过长衰减核间接实现 | 直接实现 |
SSM 的核心劣势在于:在 LTI 设定下,卷积核是固定的------不管输入是什么样的,核都不变。这意味着模型无法根据输入内容动态调整"看哪里"------它只能学习一个固定的衰减模式。
这正是 Mamba 的选择性机制要解决的问题,我们将在第十一章详细讨论。
4.6 高维输入的情况
对于 DDD 维输入(xt∈RDx_t \in \mathbb{R}^Dxt∈RD),如果每个通道独立处理(B∈RN×DB \in \mathbb{R}^{N \times D}B∈RN×D, C∈RD×NC \in \mathbb{R}^{D \times N}C∈RD×N),则:
yt,d=∑k=0tKt−k,d⋅xk,dy_{t,d} = \sum_{k=0}^{t} K_{t-k,d} \cdot x_{k,d}yt,d=k=0∑tKt−k,d⋅xk,d
其中 Kn,d=CdAˉnBdK_{n,d} = C_d \bar{A}^n B_dKn,d=CdAˉnBd 是第 ddd 个通道的卷积核。
总共有 DDD 个独立的卷积核,每个长度为 LLL。卷积核矩阵 K∈RD×LK \in \mathbb{R}^{D \times L}K∈RD×L 的形状是 D×LD \times LD×L。
在 S4 的实现中,这可以表示为一个二维卷积操作,其中:
- 第一个维度是通道维度(DDD 个独立的卷积)
- 第二个维度是时间维度(长度为 LLL 的卷积)
第二部分:长程依赖与 HiPPO 理论
第五章:长程依赖问题------为什么 RNN 会遗忘
5.1 梯度消失的严格分析
5.1.1 BPTT 中的梯度传播
考虑一个简单的 RNN:
ht=σ(Whht−1+Wxxt)h_t = \sigma(W_h h_{t-1} + W_x x_t)ht=σ(Whht−1+Wxxt)
其中 σ\sigmaσ 是激活函数(如 tanh)。
在反向传播通过时间(BPTT)中,损失 L\mathcal{L}L 对 hsh_shs(s<ts < ts<t)的梯度为:
∂L∂hs=∂L∂ht∏k=s+1t∂hk∂hk−1\frac{\partial \mathcal{L}}{\partial h_s} = \frac{\partial \mathcal{L}}{\partial h_t} \prod_{k=s+1}^{t} \frac{\partial h_k}{\partial h_{k-1}}∂hs∂L=∂ht∂Lk=s+1∏t∂hk−1∂hk
其中每一步的雅可比矩阵为:
∂hk∂hk−1=diag(σ′(zk))Wh\frac{\partial h_k}{\partial h_{k-1}} = \text{diag}(\sigma'(z_k)) W_h∂hk−1∂hk=diag(σ′(zk))Wh
其中 zk=Whhk−1+Wxxkz_k = W_h h_{k-1} + W_x x_kzk=Whhk−1+Wxxk。
因此:
∂L∂hs=∂L∂ht∏k=s+1tdiag(σ′(zk))Wh\frac{\partial \mathcal{L}}{\partial h_s} = \frac{\partial \mathcal{L}}{\partial h_t} \prod_{k=s+1}^{t} \text{diag}(\sigma'(z_k)) W_h∂hs∂L=∂ht∂Lk=s+1∏tdiag(σ′(zk))Wh
5.1.2 指数衰减的证明
取范数:
∥∂L∂hs∥≤∥∂L∂ht∥∏k=s+1t∥diag(σ′(zk))∥⋅∥Wh∥\left\|\frac{\partial \mathcal{L}}{\partial h_s}\right\| \leq \left\|\frac{\partial \mathcal{L}}{\partial h_t}\right\| \prod_{k=s+1}^{t} \|\text{diag}(\sigma'(z_k))\| \cdot \|W_h\| ∂hs∂L ≤ ∂ht∂L k=s+1∏t∥diag(σ′(zk))∥⋅∥Wh∥
对于 tanh 激活函数,∣σ′(z)∣=∣1−tanh2(z)∣≤1|\sigma'(z)| = |1 - \tanh^2(z)| \leq 1∣σ′(z)∣=∣1−tanh2(z)∣≤1,且在饱和区(∣z∣|z|∣z∣ 较大时)趋近于 0。
设 γ=maxk∥diag(σ′(zk))∥⋅∥Wh∥\gamma = \max_k \|\text{diag}(\sigma'(z_k))\| \cdot \|W_h\|γ=maxk∥diag(σ′(zk))∥⋅∥Wh∥,则:
∥∂L∂hs∥≤γt−s∥∂L∂ht∥\left\|\frac{\partial \mathcal{L}}{\partial h_s}\right\| \leq \gamma^{t-s} \left\|\frac{\partial \mathcal{L}}{\partial h_t}\right\| ∂hs∂L ≤γt−s ∂ht∂L
- 当 γ<1\gamma < 1γ<1 时:梯度以 γt−s\gamma^{t-s}γt−s 指数衰减(梯度消失)
- 当 γ>1\gamma > 1γ>1 时:梯度以 γt−s\gamma^{t-s}γt−s 指数增长(梯度爆炸)
5.1.3 有效记忆长度
定义 :有效记忆长度 LeffL_{\text{eff}}Leff 是指梯度衰减到初始值的 1/e1/e1/e 所需的时间步数。
γLeff=1e ⟹ Leff=1ln(1/γ)=−1lnγ\gamma^{L_{\text{eff}}} = \frac{1}{e} \implies L_{\text{eff}} = \frac{1}{\ln(1/\gamma)} = \frac{-1}{\ln \gamma}γLeff=e1⟹Leff=ln(1/γ)1=lnγ−1
对于典型的 RNN 设置(∥Wh∥≈1\|W_h\| \approx 1∥Wh∥≈1, ∣σ′∣≈0.5|\sigma'| \approx 0.5∣σ′∣≈0.5),γ≈0.5\gamma \approx 0.5γ≈0.5,则:
Leff≈1ln2≈1.44L_{\text{eff}} \approx \frac{1}{\ln 2} \approx 1.44Leff≈ln21≈1.44
这意味着梯度在仅约 1-2 个时间步后就衰减到 1/e1/e1/e!实际中,RNN 的有效记忆长度通常只有 10-20 个时间步。
5.1.4 数值验证
python
import numpy as np
import matplotlib.pyplot as plt
def simulate_gradient_decay(seq_len: int = 200, hidden_dim: int = 64, num_trials: int = 100) -> np.ndarray:
"""模拟 RNN 中梯度的指数衰减行为。
通过随机矩阵乘法的累积效应,展示梯度范数随时间步的衰减。
"""
decay_curves = np.zeros((num_trials, seq_len))
for trial in range(num_trials):
# 随机初始化权重矩阵,谱范数约 0.9(保证稳定但有衰减)
W = np.random.randn(hidden_dim, hidden_dim) * (0.9 / np.sqrt(hidden_dim))
# tanh 激活函数导数的平均值约为 0.6
sigma_prime_mean = 0.6
# 初始梯度
grad = np.ones(hidden_dim)
grad_norm_init = np.linalg.norm(grad)
for t in range(seq_len):
# 模拟一步 BPTT:乘以 diag(sigma') @ W
sigma_prime = np.random.uniform(0.3, 0.9, hidden_dim) # tanh' 的随机模拟
grad = (sigma_prime * (W.T @ grad))
grad_norm = np.linalg.norm(grad)
decay_curves[trial, t] = grad_norm / grad_norm_init if grad_norm_init > 0 else 0
return decay_curves
def demonstrate_gradient_decay():
"""演示 RNN 梯度衰减现象。"""
np.random.seed(42)
decay_curves = simulate_gradient_decay(seq_len=100, hidden_dim=64, num_trials=50)
mean_curve = np.mean(decay_curves, axis=0)
std_curve = np.std(decay_curves, axis=0)
print("RNN 梯度衰减模拟结果:")
print(f" 序列长度: 100")
print(f" 隐状态维度: 64")
print(f" 模拟次数: 50")
print(f" t=0 时梯度相对范数: {mean_curve[0]:.6f}")
print(f" t=10 时梯度相对范数: {mean_curve[10]:.6f}")
print(f" t=20 时梯度相对范数: {mean_curve[20]:.6f}")
print(f" t=50 时梯度相对范数: {mean_curve[50]:.6f}")
print(f" t=99 时梯度相对范数: {mean_curve[99]:.6f}")
print()
print("结论:梯度在约 10-20 步内衰减到接近零,解释了 RNN 的长程依赖困难。")
return mean_curve, std_curve
if __name__ == "__main__":
demonstrate_gradient_decay()
5.2 LSTM 和 GRU 的门控机制:缓解而非解决
5.2.1 LSTM 的遗忘门
LSTM 引入了遗忘门 ftf_tft 来控制信息的流动:
ft=σ(Wfht−1,xt+bf)f_t = \sigma(W_f h_{t-1}, x_t + b_f)ft=σ(Wfht−1,xt+bf)
it=σ(Wiht−1,xt+bi)i_t = \sigma(W_i h_{t-1}, x_t + b_i)it=σ(Wiht−1,xt+bi)
c~t=tanh(Wcht−1,xt+bc)\tilde{c}t = \tanh(W_c h_{t-1}, x_t + b_c)c~t=tanh(Wcht−1,xt+bc)
ct=ft⊙ct−1+it⊙c~tc_t = f_t \odot c{t-1} + i_t \odot \tilde{c}_tct=ft⊙ct−1+it⊙c~t
ot=σ(Woht−1,xt+bo)o_t = \sigma(W_o h_{t-1}, x_t + b_o)ot=σ(Woht−1,xt+bo)
ht=ot⊙tanh(ct)h_t = o_t \odot \tanh(c_t)ht=ot⊙tanh(ct)
关键观察:细胞状态 ctc_tct 的递推是线性的(加法而非乘法):
ct=ft⊙ct−1+it⊙c~tc_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_tct=ft⊙ct−1+it⊙c~t
当 ft=1f_t = 1ft=1(遗忘门全开)时,ct=ct−1+newc_t = c_{t-1} + \text{new}ct=ct−1+new,信息可以无损传递 。这使得 LSTM 的梯度可以通过 ctc_tct 的线性路径流通,不会指数衰减。
5.2.2 LSTM 的局限
但 LSTM 仍然有根本性的局限:
- 门控是输入依赖的 :ftf_tft 依赖于 xtx_txt,这意味着门控策略是针对每个输入独立决定的,缺乏全局规划
- 隐状态维度有限 :ct∈Rdc_t \in \mathbb{R}^dct∈Rd,ddd 通常较小(256-1024),信息容量有限
- 串行计算 :ctc_tct 依赖 ct−1c_{t-1}ct−1,无法并行
- 理论有效记忆长度仍然有限:虽然比简单 RNN 好很多,但实验表明 LSTM 的有效记忆长度约为 100-200 个时间步
5.3 长程依赖的数学刻画
5.3.1 Long Range Arena (LRA) 基准
Long Range Arena(Tay et al., 2021)是一个专门评估长程依赖建模能力的基准测试集,包含:
| 任务 | 序列长度 | 依赖类型 | 难度 |
|---|---|---|---|
| ListOps | 2000 | 层次结构 | 高 |
| Text | 4096 | 语义 | 中 |
| Retrieval | 4096 | 匹配 | 中 |
| Image | 1024 | 空间 | 中 |
| Pathfinder | 1024 | 路径追踪 | 高 |
| Path-X | 16384 | 超长路径 | 极高 |
Path-X 任务(序列长度 16384)是真正的试金石------大多数 Transformer 变体在这个任务上都失败了,因为 O(n2)O(n^2)O(n2) 的注意力使得内存不足。
5.3.2 信息论视角
从信息论的角度,长程依赖可以被理解为互信息的衰减:
I(Xt;Xt+k)=H(Xt)−H(Xt∣Xt+k)I(X_t; X_{t+k}) = H(X_t) - H(X_t | X_{t+k})I(Xt;Xt+k)=H(Xt)−H(Xt∣Xt+k)
对于马尔可夫过程,互信息随着 kkk 指数衰减。但自然语言中的长程依赖意味着 I(Xt;Xt+k)I(X_t; X_{t+k})I(Xt;Xt+k) 可以保持显著水平即使 kkk 很大。
一个理想的序列模型应该能够选择性地保留 与未来预测相关的信息,而选择性地丢弃无关信息。
5.4 状态空间模型的长程依赖理论
5.4.1 线性 SSM 的记忆能力
对于离散 LTI-SSM ht=Aˉht−1+Bˉxth_t = \bar{A} h_{t-1} + \bar{B} x_tht=Aˉht−1+Bˉxt,其"记忆"由卷积核 Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ 决定。
定理 5.1(LTI-SSM 的记忆衰减) :若 Aˉ\bar{A}Aˉ 的谱半径 ρ(Aˉ)<1\rho(\bar{A}) < 1ρ(Aˉ)<1,则:
∣Kn∣≤Mρ(Aˉ)n|K_n| \leq M \rho(\bar{A})^n∣Kn∣≤Mρ(Aˉ)n
对某个常数 M>0M > 0M>0。即卷积核以谱半径的指数速度衰减。
证明:由矩阵范数的性质,∥Aˉn∥≤Mρ(Aˉ)n\|\bar{A}^n\| \leq M \rho(\bar{A})^n∥Aˉn∥≤Mρ(Aˉ)n(Gelfand 公式),因此 ∣Kn∣=∣CˉAˉnBˉ∣≤∥Cˉ∥⋅∥Aˉn∥⋅∥Bˉ∥≤M′ρ(Aˉ)n|K_n| = |\bar{C}\bar{A}^n\bar{B}| \leq \|\bar{C}\| \cdot \|\bar{A}^n\| \cdot \|\bar{B}\| \leq M' \rho(\bar{A})^n∣Kn∣=∣CˉAˉnBˉ∣≤∥Cˉ∥⋅∥Aˉn∥⋅∥Bˉ∥≤M′ρ(Aˉ)n。□\square□
5.4.2 "选择性"的重要性
这个定理揭示了 LTI-SSM 的根本局限:卷积核的衰减是固定的,不依赖于输入内容。
考虑两个场景:
- 场景 A:"The cat sat on the mat. It was very fluffy." ------ "It" 指代 "cat",依赖距离为 6
- 场景 B:"The cat, which was adopted from the shelter last year and has been living happily ever since, sat on the mat." ------ 依赖距离为 20+
一个固定的衰减核无法同时很好地处理这两种情况。理想情况下,模型应该:
- 在场景 A 中快速衰减,关注最近的 "cat"
- 在场景 B 中保持记忆,跨越中间的修饰语
这就是选择性(selectivity)的核心动机------让衰减率依赖于输入内容。
第六章:HiPPO 框架------用多项式投影逼近历史
6.1 动机:如何最优地压缩历史信息?
给定一个连续信号 f:0,t→Rf: 0, t \to \mathbb{R}f:0,t→R,我们想要用一个有限维的状态向量 c(t)∈RNc(t) \in \mathbb{R}^Nc(t)∈RN 来最优地表示信号的历史。
"最优"在这里的含义是:在某个函数空间中,c(t)c(t)c(t) 所对应的近似函数 ftf_tft 是对历史 f(τ)f(\tau)f(τ)(τ≤t\tau \leq tτ≤t)在某种意义下的最佳逼近。
HiPPO(Hi gh-order P olynomial P rojection Operators)框架(Gu et al., 2020)正是回答了这个问题。
6.2 问题形式化
6.2.1 在线函数逼近
考虑一个实时到达的信号 f(t)f(t)f(t)。在时刻 ttt,我们希望找到一个多项式 ft(τ)f_t(\tau)ft(τ)(定义在 τ∈0,t\tau \in 0, tτ∈0,t 上),使得:
ft=argming∈PN∥f∣0,t−g∥μtf_t = \arg\min_{g \in \mathcal{P}N} \|f|{0,t} - g\|_{\mu_t}ft=argg∈PNmin∥f∣0,t−g∥μt
其中:
- PN\mathcal{P}_NPN 是次数不超过 N−1N-1N−1 的多项式空间
- ∥⋅∥μt\|\cdot\|_{\mu_t}∥⋅∥μt 是关于某个测度 μt\mu_tμt 的 L2L^2L2 范数
- f∣0,tf|_{0,t}f∣0,t 是 fff 在 0,t0, t0,t 上的限制
关键问题:如何选择测度 μt\mu_tμt?
6.2.2 测度的选择
不同的测度对应不同的"遗忘策略":
| 测度 | 定义 | 含义 |
|---|---|---|
| 均匀测度 | dμt(τ)=1tdτd\mu_t(\tau) = \frac{1}{t}d\taudμt(τ)=t1dτ | 历史上的所有时刻同等重要 |
| 衰减测度 | dμt(τ)=e−λ(t−τ)dτd\mu_t(\tau) = e^{-\lambda(t-\tau)}d\taudμt(τ)=e−λ(t−τ)dτ | 越近的时刻越重要 |
| 滑动窗口 | dμt(τ)=1τ∈\[t−w,t]dτd\mu_t(\tau) = \mathbb{1}\\tau \\in \[t-w, t] d\taudμt(τ)=1τ∈\[t−w,t]dτ | 只关心最近 www 个时间单位 |
| 阶梯测度 | dμt(τ)=∑kwkδ(τ−tk)d\mu_t(\tau) = \sum_{k} w_k \delta(\tau - t_k)dμt(τ)=∑kwkδ(τ−tk) | 只关心特定时刻 |
HiPPO 的核心贡献是:对于一大类测度,这个在线逼近问题有闭式解,并且解可以表示为一个线性 ODE。
6.3 HiPPO-LegS(缩放勒让德测度)
6.3.1 测度定义
HiPPO-LegS 使用缩放勒让德测度(Scaled Legendre Measure):
μt(s)=1t10,t(τ)\mu_t^{(s)} = \frac{1}{t} \mathbb{1}_{0, t}(\tau)μt(s)=t110,t(τ)
即在 0,t0, t0,t 上的均匀测度。"S" 代表 "Scaled"(随 ttt 缩放)。
6.3.2 多项式基
在 0,t0, t0,t 上定义缩放勒让德多项式 。令 s=τ/t∈0,1s = \tau / t \in 0, 1s=τ/t∈0,1,则标准勒让德多项式为:
P0(s)=1P_0(s) = 1P0(s)=1
P1(s)=2s−1P_1(s) = 2s - 1P1(s)=2s−1
P2(s)=12(3s2−6s+1)P_2(s) = \frac{1}{2}(3s^2 - 6s + 1)P2(s)=21(3s2−6s+1)
Pn(s)=1n!dndsn(s2−1)nP_n(s) = \frac{1}{n!}\frac{d^n}{ds^n}(s^2 - 1)^nPn(s)=n!1dsndn(s2−1)n
这些多项式在 0,10, 10,1 上关于 Lebesgue 测度正交:
∫01Pm(s)Pn(s) ds=12n+1δmn\int_0^1 P_m(s) P_n(s) \, ds = \frac{1}{2n + 1} \delta_{mn}∫01Pm(s)Pn(s)ds=2n+11δmn
6.3.3 投影系数
信号 fff 在多项式基上的投影系数为:
cn(t)=(2n+1)∫0tf(τ)Pn(τt)dτtc_n(t) = (2n + 1) \int_0^t f(\tau) P_n\left(\frac{\tau}{t}\right) \frac{d\tau}{t}cn(t)=(2n+1)∫0tf(τ)Pn(tτ)tdτ
这些系数随 ttt 变化,我们想知道它们的演化规律。
6.3.4 核心定理
定理 6.1(HiPPO-LegS ODE) :对于缩放勒让德测度,投影系数 c(t)=c0(t),c1(t),...,cN−1(t)Tc(t) = c_0(t), c_1(t), \\dots, c_{N-1}(t)^Tc(t)=c0(t),c1(t),...,cN−1(t)T 满足以下线性 ODE:
ddtc(t)=−1tAc(t)+1tBf(t)\frac{d}{dt} c(t) = -\frac{1}{t} A c(t) + \frac{1}{t} B f(t)dtdc(t)=−t1Ac(t)+t1Bf(t)
其中:
Ank={(2n+1)1/2(2k+1)1/2if n>kn+1if n=k0if n<kA_{nk} = \begin{cases} (2n + 1)^{1/2}(2k + 1)^{1/2} & \text{if } n > k \\ n + 1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}Ank=⎩ ⎨ ⎧(2n+1)1/2(2k+1)1/2n+10if n>kif n=kif n<k
Bn=(2n+1)1/2B_n = (2n + 1)^{1/2}Bn=(2n+1)1/2
即 AAA 是一个下三角矩阵加上对角项。
证明概要:
将 cn(t)=(2n+1)∫0tf(τ)Pn(τ/t)dτ/tc_n(t) = (2n+1) \int_0^t f(\tau) P_n(\tau/t) d\tau/tcn(t)=(2n+1)∫0tf(τ)Pn(τ/t)dτ/t 对 ttt 求导,利用勒让德多项式的递推关系:
(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)(2n + 1) s P_n(s) = (n + 1) P_{n+1}(s) + n P_{n-1}(s)(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)
以及正交性条件,可以推导出 ODE 的闭式表达。具体推导较为繁琐,涉及勒让德多项式的以下关键性质:
- 递推关系 :(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)(2n+1)s P_n(s) = (n+1)P_{n+1}(s) + n P_{n-1}(s)(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)
- 正交性 :∫01Pm(s)Pn(s)ds=δmn2n+1\int_0^1 P_m(s) P_n(s) ds = \frac{\delta_{mn}}{2n+1}∫01Pm(s)Pn(s)ds=2n+1δmn
- 导数公式 :Pn′(s)=∑k=0n−1(2k+1)Pk(s)P_n'(s) = \sum_{k=0}^{n-1} (2k+1) P_k(s)Pn′(s)=∑k=0n−1(2k+1)Pk(s)(对 n>0n > 0n>0)
- 边界值 :Pn(0)=(−1)nP_n(0) = (-1)^nPn(0)=(−1)n, Pn(1)=1P_n(1) = 1Pn(1)=1
利用这些性质,对 cn(t)c_n(t)cn(t) 关于 ttt 求导并整理,即可得到 HiPPO-LegS 的 ODE 形式。□\square□
6.3.5 简化形式
通过变量替换 c(t)→t1/2c(t)c(t) \to t^{1/2} c(t)c(t)→t1/2c(t),可以将含 1/t1/t1/t 的 ODE 化为标准的 LTI 形式:
ddtc(t)=Ac(t)+Bf(t)\frac{d}{dt} c(t) = A c(t) + B f(t)dtdc(t)=Ac(t)+Bf(t)
(注意这里 AAA 矩阵略有不同,但结构相同。)
这就是将在线函数逼近问题转化为线性状态空间模型的关键洞察。
6.4 HiPPO-LegT(平移勒让德测度)
6.4.1 测度定义
HiPPO-LegT 使用平移勒让德测度(Translated Legendre Measure):
μt(T)=1t−θ,t(τ)dτ\mu_t^{(T)} = \mathbb{1}_{t - \\theta, t}(\tau) d\tauμt(T)=1t−θ,t(τ)dτ
即一个固定宽度 θ\thetaθ 的滑动窗口。"T" 代表 "Translated"。
6.4.2 核心结果
定理 6.2(HiPPO-LegT ODE):对于平移勒让德测度,投影系数满足:
ddtc(t)=ATc(t)+BTf(t)\frac{d}{dt} c(t) = A_{\text{T}} c(t) + B_{\text{T}} f(t)dtdc(t)=ATc(t)+BTf(t)
其中 ATA_{\text{T}}AT 是一个常数矩阵 (不依赖 ttt),其元素为:
(AT)nk={2n+1θ(−1)n−k(2k+1)1/2(2n+1)1/2if n≥k−n+1θif n=k0if n<k(A_{\text{T}})_{nk} = \begin{cases} \frac{2n+1}{\theta}(-1)^{n-k}(2k+1)^{1/2}(2n+1)^{1/2} & \text{if } n \geq k \\ -\frac{n+1}{\theta} & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}(AT)nk=⎩ ⎨ ⎧θ2n+1(−1)n−k(2k+1)1/2(2n+1)1/2−θn+10if n≥kif n=kif n<k
BTB_{\text{T}}BT 的元素为:
(BT)n=(2n+1)1/2θ(−1)n(B_{\text{T}})_n = \frac{(2n+1)^{1/2}}{\theta}(-1)^n(BT)n=θ(2n+1)1/2(−1)n
HiPPO-LegT 的优势在于 AAA 矩阵不依赖时间 ttt,因此可以直接用标准的 LTI-SSM 来表示。它有一个明确的"遗忘窗口" θ\thetaθ。
6.5 HiPPO-LagT(拉盖尔测度)
6.5.1 测度定义
HiPPO-LagT 使用指数衰减测度(由拉盖尔多项式正交化):
μt(L)=e−(τ−t)1τ≥t−θ(τ)dτ\mu_t^{(L)} = e^{-(\tau - t)} \mathbb{1}_{\\tau \\geq t - \\theta}(\tau) d\tauμt(L)=e−(τ−t)1τ≥t−θ(τ)dτ
6.5.2 核心结果
定理 6.3(HiPPO-LagT ODE):对于拉盖尔测度,投影系数满足:
ddtc(t)=−12c(t)+ALc(t)+BLf(t)\frac{d}{dt} c(t) = -\frac{1}{2} c(t) + A_{\text{L}} c(t) + B_{\text{L}} f(t)dtdc(t)=−21c(t)+ALc(t)+BLf(t)
其中 ALA_{\text{L}}AL 是一个下双对角矩阵:
(AL)nk={(n+1)1/2if k=n+10otherwise(A_{\text{L}})_{nk} = \begin{cases} (n+1)^{1/2} & \text{if } k = n+1 \\ 0 & \text{otherwise} \end{cases}(AL)nk={(n+1)1/20if k=n+1otherwise
6.6 HiPPO 矩阵的数值实现
python
import numpy as np
import scipy.linalg
def hippo_legs_matrix(N: int) -> tuple[np.ndarray, np.ndarray]:
"""构造 HiPPO-LegS(缩放勒让德)矩阵。
Args:
N: 状态维度(多项式阶数)
Returns:
A: (N, N) HiPPO 矩阵
B: (N, 1) 输入矩阵
"""
P = np.arange(1, N + 1).astype(np.float64) # [1, 2, ..., N]
Q = np.arange(1, N + 1).astype(np.float64)
# A_{nk} = (2n+1)^{1/2} (2k+1)^{1/2} if n > k, else n+1 if n==k, else 0
A = np.zeros((N, N))
for n in range(N):
for k in range(N):
if n > k:
A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
elif n == k:
A[n, k] = n + 1
B = np.sqrt(2 * np.arange(1, N + 1) - 1).reshape(-1, 1)
# 修正 B: (2n+1)^{1/2}, n 从 0 开始
B = np.sqrt(2 * np.arange(N) + 1).reshape(-1, 1).astype(np.float64)
return A, B
def hippo_legt_matrix(N: int, theta: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
"""构造 HiPPO-LegT(平移勒让德)矩阵。
Args:
N: 状态维度
theta: 滑动窗口宽度
Returns:
A: (N, N) HiPPO-LegT 矩阵
B: (N, 1) 输入矩阵
"""
A = np.zeros((N, N))
for n in range(N):
for k in range(N):
if n > k:
A[n, k] = (2 * n + 1) * (-1) ** (n - k) * np.sqrt(2 * k + 1) * np.sqrt(2 * n + 1) / theta
elif n == k:
A[n, k] = -(n + 1) / theta
B = np.array([((-1) ** n) * np.sqrt(2 * n + 1) / theta for n in range(N)]).reshape(-1, 1)
return A, B
def hippo_lagt_matrix(N: int) -> tuple[np.ndarray, np.ndarray]:
"""构造 HiPPO-LagT(拉盖尔)矩阵。
Args:
N: 状态维度
Returns:
A: (N, N) HiPPO-LagT 矩阵
B: (N, 1) 输入矩阵
"""
A = np.zeros((N, N))
for n in range(N):
if n + 1 < N:
A[n, n + 1] = np.sqrt(n + 1)
A[n, n] = -0.5
B = np.array([np.sqrt(2 * n + 1) for n in range(N)]).reshape(-1, 1)
return A, B
def demonstrate_hippo_matrices():
"""展示不同 HiPPO 矩阵的结构。"""
N = 8
print("=" * 60)
print("HiPPO 矩阵结构展示 (N = 8)")
print("=" * 60)
A_legs, B_legs = hippo_legs_matrix(N)
print("\n--- HiPPO-LegS (缩放勒让德) ---")
print("A 矩阵:")
print(np.array2string(A_legs, precision=3, suppress_small=True))
print("B 向量:")
print(np.array2string(B_legs.flatten(), precision=3))
A_legt, B_legt = hippo_legt_matrix(N, theta=1.0)
print("\n--- HiPPO-LegT (平移勒让德, theta=1.0) ---")
print("A 矩阵:")
print(np.array2string(A_legt, precision=3, suppress_small=True))
print("B 向量:")
print(np.array2string(B_legt.flatten(), precision=3))
A_lagt, B_lagt = hippo_lagt_matrix(N)
print("\n--- HiPPO-LagT (拉盖尔) ---")
print("A 矩阵:")
print(np.array2string(A_lagt, precision=3, suppress_small=True))
print("B 向量:")
print(np.array2string(B_lagt.flatten(), precision=3))
# 检验 A 的特征值
print("\n--- A 矩阵的特征值 ---")
for name, A in [("LegS", A_legs), ("LegT", A_legt), ("LagT", A_lagt)]:
eigvals = np.linalg.eigvals(A)
print(f" {name}: max Re(lambda) = {np.max(np.real(eigvals)):.4f}, "
f"min Re(lambda) = {np.min(np.real(eigvals)):.4f}")
if __name__ == "__main__":
demonstrate_hippo_matrices()
6.7 HiPPO 的理论意义
6.7.1 从逼近理论到 SSM
HiPPO 的核心贡献是建立了在线函数逼近 和线性状态空间模型之间的桥梁:
在线函数逼近⏟信号处理视角⟷LTI-SSM⏟序列建模视角\underbrace{\text{在线函数逼近}}{\text{信号处理视角}} \quad \longleftrightarrow \quad \underbrace{\text{LTI-SSM}}{\text{序列建模视角}}信号处理视角 在线函数逼近⟷序列建模视角 LTI-SSM
- 逼近理论告诉我们应该学习什么(最优地压缩历史)
- SSM 告诉我们怎么计算(通过矩阵运算)
6.7.2 HiPPO 初始化 vs 随机初始化
定理 6.4(HiPPO 初始化的优势) :使用 HiPPO 矩阵初始化 AAA 的 SSM 在 Path-X 等超长程任务上显著优于随机初始化。
直觉上,HiPPO 矩阵具有以下良好性质:
- 特征值分布在左半平面:保证系统稳定
- 下三角结构:不同阶多项式之间有层次化的信息流动
- 特定的谱结构:使得脉冲响应具有多尺度的衰减特性
这些性质使得模型从一开始就具有"合理的"记忆行为,而不是依赖训练去发现这些结构。
6.7.3 高阶多项式逼近的误差
定理 6.5(多项式逼近误差) :设 fff 是 0,t0, t0,t 上的 Lipschitz 连续函数,NNN 阶 HiPPO-LegS 逼近的 L2L^2L2 误差为:
∥f−ft(N)∥L2≤CN\|f - f_t^{(N)}\|_{L^2} \leq \frac{C}{N}∥f−ft(N)∥L2≤NC
其中 CCC 是与 fff 的 Lipschitz 常数和 ttt 相关的常数。
这意味着增加状态维度 NNN 可以系统性地减小逼近误差 ,且收敛速率为 O(1/N)O(1/N)O(1/N)。
第七章:从 HiPPO 到 LSSL------线性状态空间层
7.1 LSSL 的基本框架
LSSL(L inear S tate S pace Layer, Gu et al., 2021)是将 HiPPO 理论与线性 SSM 结合的第一个成功尝试。
其基本框架是:
- 用 HiPPO 矩阵初始化 AAA:确保模型从合理的记忆行为开始
- 将 B,CB, CB,C 设为可学习参数:让模型适应具体的任务
- 用 ZOH 离散化:将连续 SSM 转化为离散序列模型
- 用卷积模式训练,递推模式推理:利用对偶性获得最佳效率
7.1.1 LSSL 的参数
- AAA:由 HiPPO-LegS 初始化,可学习(或固定)
- B∈RN×1B \in \mathbb{R}^{N \times 1}B∈RN×1:可学习
- C∈R1×NC \in \mathbb{R}^{1 \times N}C∈R1×N:可学习
- Δ\DeltaΔ:离散化步长,可学习标量
- 输入输出维度均为 1(多通道通过并行独立的 LSSL 实现)
7.1.2 LSSL 的计算流程
训练(卷积模式):
- 用 HiPPO 初始化 AAA
- 计算离散化参数 Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ, Bˉ=(AΔ)−1(eAΔ−I)(BΔ)\bar{B} = (A\Delta)^{-1}(e^{A\Delta} - I)(B\Delta)Bˉ=(AΔ)−1(eAΔ−I)(BΔ)
- 计算卷积核 K=CBˉ,CAˉBˉ,CAˉ2Bˉ,...,CAˉL−1BˉK = C\\bar{B}, C\\bar{A}\\bar{B}, C\\bar{A}\^2\\bar{B}, \\dots, C\\bar{A}\^{L-1}\\bar{B}K=CBˉ,CAˉBˉ,CAˉ2Bˉ,...,CAˉL−1Bˉ
- 用 FFT 计算 y=K∗xy = K * xy=K∗x
推理(递推模式):
- 使用相同的 Aˉ,Bˉ,C\bar{A}, \bar{B}, CAˉ,Bˉ,C
- 递推计算 ht=Aˉht−1+Bˉxth_t = \bar{A}h_{t-1} + \bar{B}x_tht=Aˉht−1+Bˉxt, yt=Chty_t = Ch_tyt=Cht
7.2 LSSL 的实验结果
7.2.1 在长程基准上的表现
LSSL 在当时(2021年)取得了一系列突破性的结果:
| 任务 | 序列长度 | Transformer | LSTM | LSSL |
|---|---|---|---|---|
| ListOps | 2000 | 36.4 | 26.0 | 60.1 |
| Text | 4096 | 64.3 | 62.4 | 85.4 |
| Retrieval | 4096 | 57.5 | 57.1 | 89.2 |
| Image (32x32) | 1024 | 42.4 | 33.2 | 87.7 |
| Pathfinder | 1024 | 71.4 | 62.8 | 92.2 |
这些结果表明,HiPPO 初始化的 SSM 能够有效地捕捉长程依赖。
7.2.2 LSSL 的局限
但 LSSL 有几个关键局限:
- AAA 矩阵的计算瓶颈 :eAΔe^{A\Delta}eAΔ 需要 O(N3)O(N^3)O(N3) 的矩阵指数运算
- 卷积核的计算 :CAˉnBC\bar{A}^n BCAˉnB 需要 O(N2)O(N^2)O(N2) 的矩阵-向量乘法
- Aˉ\bar{A}Aˉ 是稠密矩阵 :即使 AAA 有特殊结构,eAΔe^{A\Delta}eAΔ 通常是稠密的
- 梯度计算困难 :eAΔe^{A\Delta}eAΔ 关于 AAA 的梯度涉及矩阵指数的微分
这些局限促使了 S4 的诞生------通过结构化参数化来解决计算瓶颈。
7.3 从 LSSL 到 S4 的演进
LSSL 的核心问题可以总结为:
HiPPO 给了我们好的 AAA 矩阵,但没有给我们高效计算 eAΔe^{A\Delta}eAΔ 的方法。
S4 的核心洞察是:
通过选择特殊的 AAA 矩阵结构(如正规加低秩,NPLR),可以将矩阵指数的计算从 O(N3)O(N^3)O(N3) 降到 O(N)O(N)O(N) 或 O(NlogN)O(N \log N)O(NlogN)。
这需要我们在 HiPPO 矩阵的数学结构上下功夫------下一章将详细展开。
7.4 LSSL 的完整实现
python
import numpy as np
import scipy.linalg
class SimpleLSSL:
"""简单的 LSSL(Linear State Space Layer)实现。
用于演示 HiPPO 初始化 + 离散化 + 卷积计算的基本流程。
注意:这是一个教学用的简化实现,不做反向传播。
"""
def __init__(self, N: int = 64, dt: float = 0.001):
"""
Args:
N: 状态维度
dt: 离散化步长 Delta
"""
self.N = N
self.dt = dt
# 用 HiPPO-LegS 初始化 A
self.A, self.B_cont = self._hippo_legs(N)
self.C = np.random.randn(1, N) * (1.0 / np.sqrt(N))
# ZOH 离散化
self._discretize()
@staticmethod
def _hippo_legs(N: int) -> tuple[np.ndarray, np.ndarray]:
"""构造 HiPPO-LegS 矩阵。"""
A = np.zeros((N, N))
for n in range(N):
for k in range(N):
if n > k:
A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
elif n == k:
A[n, k] = n + 1
B = np.sqrt(2 * np.arange(N) + 1).reshape(-1, 1).astype(np.float64)
return A, B
def _discretize(self):
"""使用 ZOH 进行离散化。"""
dtA = self.A * self.dt
# bar_A = expm(A * dt)
self.bar_A = scipy.linalg.expm(dtA)
# bar_B = A^{-1} (bar_A - I) B dt
# 当 A 可能奇异时,使用积分形式
self.bar_B = np.linalg.solve(
self.A,
(self.bar_A - np.eye(self.N)) @ self.B_cont
) * self.dt # 注意:这里实际上不需要再乘 dt,因为 B_cont 已经是 B
# 修正:bar_B = A^{-1} (exp(A*dt) - I) B
# 不需要额外乘 dt
self.bar_B = np.linalg.solve(
self.A,
(self.bar_A - np.eye(self.N)) @ self.B_cont
)
def compute_kernel(self, L: int) -> np.ndarray:
"""计算长度为 L 的卷积核。
Args:
L: 序列长度
Returns:
K: (L,) 卷积核
"""
K = np.zeros(L)
A_power_B = self.bar_B.copy() # bar_A^0 @ bar_B = bar_B
for n in range(L):
K[n] = (self.C @ A_power_B).item()
A_power_B = self.bar_A @ A_power_B
return K
def forward_conv(self, x: np.ndarray) -> np.ndarray:
"""卷积模式前向传播。
Args:
x: (L,) 输入序列
Returns:
y: (L,) 输出序列
"""
L = len(x)
K = self.compute_kernel(L)
# 使用 FFT 计算卷积
X = np.fft.rfft(x, n=2 * L)
K_f = np.fft.rfft(K, n=2 * L)
y = np.fft.irfft(X * K_f, n=2 * L)[:L]
return y
def forward_recur(self, x: np.ndarray) -> np.ndarray:
"""递推模式前向传播。
Args:
x: (L,) 输入序列
Returns:
y: (L,) 输出序列
"""
L = len(x)
h = np.zeros((self.N, 1))
y = np.zeros(L)
for t in range(L):
h = self.bar_A @ h + self.bar_B * x[t]
y[t] = (self.C @ h).item()
return y
def demonstrate_lssl():
"""演示 LSSL 的卷积模式和递推模式。"""
np.random.seed(42)
N = 32
L = 256
dt = 0.01
model = SimpleLSSL(N=N, dt=dt)
# 随机输入
x = np.random.randn(L)
# 两种模式
y_conv = model.forward_conv(x)
y_recur = model.forward_recur(x)
# 比较
diff = np.max(np.abs(y_conv - y_recur))
print(f"LSSL 双模式验证 (N={N}, L={L})")
print(f" 卷积模式 vs 递推模式最大差异: {diff:.2e}")
print(f" 输出均值: {np.mean(y_conv):.4f}")
print(f" 输出标准差: {np.std(y_conv):.4f}")
# 展示卷积核
K = model.compute_kernel(L)
print(f"\n卷积核统计:")
print(f" K[0] = {K[0]:.6f}")
print(f" K[L//2] = {K[L//2]:.6f}")
print(f" K[L-1] = {K[L-1]:.6f}")
print(f" ||K||_1 = {np.sum(np.abs(K)):.4f}")
# 展示卷积核的频谱
K_fft = np.abs(np.fft.rfft(K))
print(f"\n卷积核频谱:")
print(f" 低频能量 (前10%): {np.sum(K_fft[:len(K_fft)//10]**2) / np.sum(K_fft**2):.4f}")
print(f" 高频能量 (后10%): {np.sum(K_fft[-len(K_fft)//10:]**2) / np.sum(K_fft**2):.4f}")
if __name__ == "__main__":
demonstrate_lssl()