状态空间模型:从经典控制论到现代序列建模——S4、Mamba 及其理论体系的完整论述(一)


目录

  • 第一部分:基础理论
    • 第一章:绪论与历史脉络
    • 第二章:线性时不变系统与经典状态空间模型
    • 第三章:离散化理论------从连续到离散的桥梁
    • 第四章:序列建模视角------卷积与递推的对偶性
  • 第二部分:长程依赖与 HiPPO 理论
    • 第五章:长程依赖问题------为什么 RNN 会遗忘
    • 第六章:HiPPO 框架------用多项式投影逼近历史
    • 第七章:从 HiPPO 到 LSSL------线性状态空间层
  • 第三部分:S4 结构化参数化
    • 第八章:S4 的核心思想------结构化参数化与对角化
    • 第九章:S4 的高效计算------Cauchy 核与 FFT
    • 第十章:S4D 与 S5------对角化变体
  • 第四部分:选择性机制与 Mamba
    • 第十一章:选择性状态空间------从 LTI 到输入依赖
    • 第十二章:Mamba 架构------选择性扫描与硬件感知设计
    • 第十三章:Mamba-2 与结构化状态空间对偶性
  • 第五部分:理论分析与实践
    • 第十四章:SSM 与 Transformer 的理论对比
    • 第十五章:完整可运行代码实现
    • 第十六章:实验、应用与未来方向
  • 附录

第一部分:基础理论


第一章:绪论与历史脉络

1.1 序列建模的核心问题

序列数据无处不在:自然语言是由词组成的序列,语音是随时间变化的声波序列,视频是帧的序列,甚至 DNA 也是碱基对的序列。序列建模的核心问题是:

如何高效地捕捉序列元素之间的依赖关系,同时保持对超长序列的可扩展性?

这个问题看似简单,实则蕴含着深刻的数学挑战。序列中的依赖关系可以是:

  • 局部的:相邻词之间的语法关系
  • 中程的:段落内的指代消解
  • 长程的:跨越数千个 token 的主题一致性

不同架构对这些依赖关系的建模能力差异巨大。

1.2 三种范式的演进

1.2.1 循环神经网络(RNN)范式

RNN 的核心思想是维护一个隐状态 hth_tht,在每个时间步通过递推更新:

ht=f(ht−1,xt)h_t = f(h_{t-1}, x_t)ht=f(ht−1,xt)

这一范式的历史可以追溯到 Elman (1990) 和 Jordan (1986) 的开创性工作。其优势在于:

  • 恒定的内存占用 O(d)O(d)O(d),不随序列长度增长
  • 理论上可以处理任意长度的序列

但其致命缺陷在于梯度消失与爆炸问题 (Bengio et al., 1994)。在反向传播通过时间(BPTT)中,梯度需要经过 TTT 次矩阵乘法:

∂hT∂h0=∏t=1T∂ht∂ht−1\frac{\partial h_T}{\partial h_0} = \prod_{t=1}^{T} \frac{\partial h_t}{\partial h_{t-1}}∂h0∂hT=t=1∏T∂ht−1∂ht

当 TTT 很大时,这个乘积要么指数增长(梯度爆炸),要么指数衰减(梯度消失),使得模型无法学习长程依赖。

LSTM(Hochreiter & Schmidhuber, 1997)和 GRU(Cho et al., 2014)通过门控机制部分缓解了这一问题,但并未从根本上解决。它们的本质仍是非线性递推,理论上的有效记忆长度仍然有限。

1.2.2 Transformer 范式

Transformer(Vaswani et al., 2017)彻底抛弃了递推结构,改用自注意力机制

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V

其核心优势是任意两点之间的路径长度为 O(1)O(1)O(1)------任何位置都可以直接关注任何其他位置,不存在信息衰减。

但代价是 O(n2)O(n^2)O(n2) 的时间和空间复杂度,以及 O(n2)O(n^2)O(n2) 的 KV Cache 内存占用。当序列长度 nnn 从 2K 增长到 128K 乃至 1M 时,这个二次成本变得不可接受。

1.2.3 状态空间模型(SSM)范式

SSM 范式试图在两者之间找到一个最优平衡点

  • 像 RNN 一样具有恒定的内存和线性的复杂度
  • 像 Transformer 一样能够捕捉长程依赖
  • 同时支持递推卷积两种计算模式

这一范式的理论基础来自经典控制论中的状态空间表示,但经过了从 HiPPO 到 S4 再到 Mamba 的一系列深刻改造,才成为现代序列建模的有力竞争者。

1.3 历史脉络

年份 里程碑 贡献
1960s Kalman 滤波器 状态空间模型用于信号处理和控制
1990 Elman Network 简单循环网络用于序列建模
1997 LSTM 门控机制缓解梯度消失
2017 Transformer 自注意力机制
2020 HiPPO (Gu et al.) 用正交多项式理论初始化状态矩阵
2021 LSSL (Gu et al.) 将 HiPPO 与线性 SSM 结合
2022 S4 (Gu et al.) 结构化参数化实现高效计算
2022 S4D (Gu et al.) 对角化简化
2023 Mamba (Gu & Dao) 选择性机制 + 硬件感知扫描
2023 Mamba-2 (Dao & Gu) 结构化状态空间对偶性(SSD)
2024 Jamba / Mamba-2-Hybrid SSM + Attention 混合架构

1.4 本文的组织

本文将从经典控制论中的状态空间模型出发,逐步构建出现代 SSM 的完整理论体系。我们不仅关注"怎么做",更关注"为什么这样做"------每一个设计选择背后的数学动机。


第二章:线性时不变系统与经典状态空间模型

2.1 从微分方程到状态空间表示

2.1.1 动机:描述动态系统

考虑一个最简单的物理系统------弹簧-阻尼器-质量系统(spring-mass-damper):

mx¨(t)+cx˙(t)+kx(t)=F(t)m\ddot{x}(t) + c\dot{x}(t) + kx(t) = F(t)mx¨(t)+cx˙(t)+kx(t)=F(t)

这是一个二阶常微分方程(ODE)。为了用一阶 ODE 系统来描述它,我们引入状态变量:

s1(t)=x(t),s2(t)=x˙(t)s_1(t) = x(t), \quad s_2(t) = \dot{x}(t)s1(t)=x(t),s2(t)=x˙(t)

则:

s˙1(t)=s2(t)\dot{s}_1(t) = s_2(t)s˙1(t)=s2(t)
s˙2(t)=−kms1(t)−cms2(t)+1mF(t)\dot{s}_2(t) = -\frac{k}{m}s_1(t) - \frac{c}{m}s_2(t) + \frac{1}{m}F(t)s˙2(t)=−mks1(t)−mcs2(t)+m1F(t)

写成矩阵形式:

ddts1s2=01−k/m−c/ms1s2+01/mF(t)\frac{d}{dt}\begin{bmatrix} s_1 \\ s_2 \end{bmatrix} = \begin{bmatrix} 0 & 1 \\ -k/m & -c/m \end{bmatrix}\begin{bmatrix} s_1 \\ s_2 \end{bmatrix} + \begin{bmatrix} 0 \\ 1/m \end{bmatrix}F(t)dtds1s2=0−k/m1−c/ms1s2+01/mF(t)

这就是状态空间表示的原型。

2.1.2 一般形式

一个连续时间线性时不变(LTI)系统的一般形式为:

h˙(t)=Ah(t)+Bx(t)(状态方程)\dot{h}(t) = Ah(t) + Bx(t) \quad \text{(状态方程)}h˙(t)=Ah(t)+Bx(t)(状态方程)
y(t)=Ch(t)+Dx(t)(观测方程)y(t) = Ch(t) + Dx(t) \quad \text{(观测方程)}y(t)=Ch(t)+Dx(t)(观测方程)

其中:

  • h(t)∈RNh(t) \in \mathbb{R}^Nh(t)∈RN:隐状态(hidden state),编码了系统的所有历史信息
  • x(t)∈Rx(t) \in \mathbb{R}x(t)∈R:输入(input),作用于系统的外部信号
  • y(t)∈Ry(t) \in \mathbb{R}y(t)∈R:输出(output),我们能观测到的量
  • A∈RN×NA \in \mathbb{R}^{N \times N}A∈RN×N:状态矩阵(state matrix),描述状态自身的演化规律
  • B∈RN×1B \in \mathbb{R}^{N \times 1}B∈RN×1:输入矩阵(input matrix),描述输入如何影响状态
  • C∈R1×NC \in \mathbb{R}^{1 \times N}C∈R1×N:输出矩阵(output matrix),描述如何从状态中读取输出
  • D∈RD \in \mathbb{R}D∈R:直通项(feedthrough),输入对输出的直接影响

在序列建模中,通常省略直通项(D=0D = 0D=0),并加上一个残差连接

h˙(t)=Ah(t)+Bx(t)\dot{h}(t) = Ah(t) + Bx(t)h˙(t)=Ah(t)+Bx(t)
y(t)=Ch(t)+x(t)y(t) = Ch(t) + x(t)y(t)=Ch(t)+x(t)

这样 DDD 被固定为恒等映射,确保输入信号可以直接传递到输出,而隐状态负责捕捉更高阶的动态。

2.1.3 状态空间表示的物理含义

AAA 矩阵的本质是描述隐状态之间的耦合关系。它的特征值决定了系统的固有行为:

设 AAA 的特征值为 λ1,λ2,...,λN\lambda_1, \lambda_2, \dots, \lambda_Nλ1,λ2,...,λN(可能为复数),则系统的自由响应为:

h(t)=∑i=1Ncieλitvih(t) = \sum_{i=1}^{N} c_i e^{\lambda_i t} v_ih(t)=i=1∑Ncieλitvi

其中 viv_ivi 是对应的特征向量,cic_ici 由初始条件决定。

  • 若 Re(λi)<0\text{Re}(\lambda_i) < 0Re(λi)<0:对应模态指数衰减(稳定)
  • 若 Re(λi)=0\text{Re}(\lambda_i) = 0Re(λi)=0:对应模态等幅振荡(临界稳定)
  • 若 Re(λi)>0\text{Re}(\lambda_i) > 0Re(λi)>0:对应模态指数增长(不稳定)

对于序列建模,我们希望系统稳定(所有模态衰减),以便远处的输入信号的影响随时间自然消退,同时又不要衰减得太快,以保留长程信息。

2.2 线性时不变系统的解析解

2.2.1 矩阵指数

LTI 系统的自由响应(x(t)=0x(t) = 0x(t)=0)可以通过矩阵指数精确求解:

h(t)=eAth(0)h(t) = e^{At} h(0)h(t)=eAth(0)

其中矩阵指数定义为:

eAt=∑k=0∞(At)kk!=I+At+(At)22!+(At)33!+⋯e^{At} = \sum_{k=0}^{\infty} \frac{(At)^k}{k!} = I + At + \frac{(At)^2}{2!} + \frac{(At)^3}{3!} + \cdotseAt=k=0∑∞k!(At)k=I+At+2!(At)2+3!(At)3+⋯

当输入不为零时,解为:

h(t)=eAth(0)+∫0teA(t−τ)Bx(τ) dτh(t) = e^{At} h(0) + \int_0^t e^{A(t-\tau)} B x(\tau) \, d\tauh(t)=eAth(0)+∫0teA(t−τ)Bx(τ)dτ

这个积分称为卷积积分------状态是初始条件的衰减与输入信号的卷积之和。

2.2.2 输出的卷积形式

输出为:

y(t)=Ch(t)=CeAth(0)+∫0tCeA(t−τ)Bx(τ) dτy(t) = C h(t) = C e^{At} h(0) + \int_0^t C e^{A(t-\tau)} B x(\tau) \, d\tauy(t)=Ch(t)=CeAth(0)+∫0tCeA(t−τ)Bx(τ)dτ

如果我们从零初始条件开始(h(0)=0h(0) = 0h(0)=0),则:

y(t)=∫0tCeA(t−τ)B⏟脉冲响应 K(t−τ)x(τ) dτ=(K∗x)(t)y(t) = \int_0^t \underbrace{C e^{A(t-\tau)} B}_{\text{脉冲响应 } K(t-\tau)} x(\tau) \, d\tau = (K * x)(t)y(t)=∫0t脉冲响应 K(t−τ) CeA(t−τ)Bx(τ)dτ=(K∗x)(t)

其中 K(t)=CeAtBK(t) = Ce^{At}BK(t)=CeAtB 是系统的脉冲响应 (impulse response),也称为卷积核

这意味着 LTI 系统完全由其脉冲响应刻画。这是一个极其深刻的结论------我们不需要知道内部状态如何演化,只需要知道脉冲响应,就能计算任意输入的输出。

2.2.3 脉冲响应的性质

对于 NNN 维 LTI 系统,脉冲响应 K(t)=CeAtBK(t) = Ce^{At}BK(t)=CeAtB 是 NNN 个指数函数的线性组合:

K(t)=∑i=1NαieλitK(t) = \sum_{i=1}^{N} \alpha_i e^{\lambda_i t}K(t)=i=1∑Nαieλit

其中 λi\lambda_iλi 是 AAA 的特征值,αi\alpha_iαi 由 A,B,CA, B, CA,B,C 共同决定。

这给了我们直觉:

  • 如果 AAA 有大量不同的特征值,脉冲响应可以是任意复杂的
  • 特征值的实部决定了每个模态的衰减率
  • 特征值的虚部决定了振荡频率

2.3 多变量推广

在实际的序列建模中,输入和输出通常是高维的。将单变量推广到多变量:

h˙(t)=Ah(t)+Bx(t)\dot{h}(t) = Ah(t) + Bx(t)h˙(t)=Ah(t)+Bx(t)
y(t)=Ch(t)y(t) = Ch(t)y(t)=Ch(t)

其中:

  • h(t)∈RNh(t) \in \mathbb{R}^Nh(t)∈RN:隐状态
  • x(t)∈RDx(t) \in \mathbb{R}^Dx(t)∈RD:DDD 维输入
  • y(t)∈RDy(t) \in \mathbb{R}^Dy(t)∈RD:DDD 维输出
  • A∈RN×NA \in \mathbb{R}^{N \times N}A∈RN×N
  • B∈RN×DB \in \mathbb{R}^{N \times D}B∈RN×D
  • C∈RD×NC \in \mathbb{R}^{D \times N}C∈RD×N

在实践中,DDD 个通道通常是独立 的------每个通道有自己的 Bi∈RNB_i \in \mathbb{R}^NBi∈RN 和 Ci∈RNC_i \in \mathbb{R}^NCi∈RN,但共享同一个 AAA。这样:

yi(t)=CiTeAth(0)+∫0tCiTeA(t−τ)Bixi(τ) dτy_i(t) = C_i^T e^{At} h(0) + \int_0^t C_i^T e^{A(t-\tau)} B_i x_i(\tau) \, d\tauyi(t)=CiTeAth(0)+∫0tCiTeA(t−τ)Bixi(τ)dτ

每个通道有自己的卷积核 Ki(t)=CiTeAtBiK_i(t) = C_i^T e^{At} B_iKi(t)=CiTeAtBi,但状态转移矩阵 AAA 是共享的。这种共享确保了所有通道"看到"相同的动态特性,但可以以不同的方式与输入和输出交互。

2.4 连续时间 LTI 系统的性质总结

定理 2.1(LTI 系统的基本性质)

连续时间 LTI 系统 h˙(t)=Ah(t)+Bx(t)\dot{h}(t) = Ah(t) + Bx(t)h˙(t)=Ah(t)+Bx(t), y(t)=Ch(t)y(t) = Ch(t)y(t)=Ch(t) 具有以下性质:

  1. 线性性 :若输入 x1(t)x_1(t)x1(t) 产生输出 y1(t)y_1(t)y1(t),输入 x2(t)x_2(t)x2(t) 产生输出 y2(t)y_2(t)y2(t),则 αx1(t)+βx2(t)\alpha x_1(t) + \beta x_2(t)αx1(t)+βx2(t) 产生 αy1(t)+βy2(t)\alpha y_1(t) + \beta y_2(t)αy1(t)+βy2(t)。

  2. 时不变性 :若输入 x(t)x(t)x(t) 产生输出 y(t)y(t)y(t),则延迟输入 x(t−τ)x(t - \tau)x(t−τ) 产生延迟输出 y(t−τ)y(t - \tau)y(t−τ)。

  3. 因果性 :输出 y(t)y(t)y(t) 只依赖于 ttt 时刻及之前的输入。

  4. 卷积性 :输出等于输入与脉冲响应的卷积:y=K∗xy = K * xy=K∗x。

  5. 稳定性 (BIBO):当且仅当 AAA 的所有特征值实部为负时,有界输入产生有界输出。


第三章:离散化理论------从连续到离散的桥梁

3.1 为什么需要离散化?

虽然连续时间 SSM 在数学上非常优雅,但计算机处理的是离散的时间序列。我们需要将连续系统转换为等价的离散系统:

ht=Aˉht−1+Bˉxth_t = \bar{A} h_{t-1} + \bar{B} x_tht=Aˉht−1+Bˉxt
yt=Cˉhty_t = \bar{C} h_tyt=Cˉht

其中 Aˉ,Bˉ\bar{A}, \bar{B}Aˉ,Bˉ 是离散化后的参数,t=0,1,2,...t = 0, 1, 2, \dotst=0,1,2,... 是离散时间步。

关键问题 :如何从连续参数 (A,B,C)(A, B, C)(A,B,C) 推导出离散参数 (Aˉ,Bˉ,Cˉ)(\bar{A}, \bar{B}, \bar{C})(Aˉ,Bˉ,Cˉ)?

3.2 零阶保持(Zero-Order Hold, ZOH)

3.2.1 定义与推导

零阶保持是最常用的离散化方法。其基本假设是:在每个采样间隔内,输入信号保持恒定

即对于 t∈[kΔ,(k+1)Δ)t \in [k\Delta, (k+1)\Delta)t∈[kΔ,(k+1)Δ),x(t)=xkx(t) = x_kx(t)=xk(常数)。

在这一假设下,连续状态方程在 kΔ,(k+1)Δk\\Delta, (k+1)\\DeltakΔ,(k+1)Δ 上的解为:

h((k+1)Δ)=eAΔh(kΔ)+∫kΔ(k+1)ΔeA((k+1)Δ−τ)B⋅xk dτh((k+1)\Delta) = e^{A\Delta} h(k\Delta) + \int_{k\Delta}^{(k+1)\Delta} e^{A((k+1)\Delta - \tau)} B \cdot x_k \, d\tauh((k+1)Δ)=eAΔh(kΔ)+∫kΔ(k+1)ΔeA((k+1)Δ−τ)B⋅xkdτ

令 σ=(k+1)Δ−τ\sigma = (k+1)\Delta - \tauσ=(k+1)Δ−τ,则积分变为:

∫0ΔeAσB dσ⋅xk\int_0^{\Delta} e^{A\sigma} B \, d\sigma \cdot x_k∫0ΔeAσBdσ⋅xk

因此:

Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ
Bˉ=(∫0ΔeAσ dσ)B=A−1(eAΔ−I)B=A−1(Aˉ−I)B\bar{B} = \left(\int_0^{\Delta} e^{A\sigma} \, d\sigma\right) B = A^{-1}(e^{A\Delta} - I) B = A^{-1}(\bar{A} - I) BBˉ=(∫0ΔeAσdσ)B=A−1(eAΔ−I)B=A−1(Aˉ−I)B

输出矩阵不变:Cˉ=C\bar{C} = CCˉ=C。

3.2.2 ZOH 的计算

对于小规模矩阵,可以直接用 scipy.linalg.expm 计算矩阵指数。对于大规模对角矩阵,可以逐元素计算。

引理 3.1 :当 AAA 可逆时,Bˉ=A−1(Aˉ−I)B\bar{B} = A^{-1}(\bar{A} - I)BBˉ=A−1(Aˉ−I)B。

当 AAA 不可逆时(存在零特征值),需要使用极限形式或避免求逆:

Bˉ=(∫0ΔeAσdσ)B\bar{B} = \left(\int_0^{\Delta} e^{A\sigma} d\sigma\right) BBˉ=(∫0ΔeAσdσ)B

这个积分可以通过泰勒展开或 Padé 逼近来计算。

3.2.3 ZOH 的性质

定理 3.1(ZOH 等价性):ZOH 离散化保持了以下性质:

  1. 稳定性 :若 AAA 的所有特征值实部为负,则 Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ 的所有特征值的模小于 1。
  2. 特征值映射 :Aˉ\bar{A}Aˉ 的特征值为 λˉi=eλiΔ\bar{\lambda}_i = e^{\lambda_i \Delta}λˉi=eλiΔ。
  3. 在采样点处精确:对于阶跃输入,ZOH 离散化的解在采样点处与连续解完全一致。

证明:

  1. 若 Re(λi)<0\text{Re}(\lambda_i) < 0Re(λi)<0,则 ∣eλiΔ∣=eRe(λi)Δ<1|e^{\lambda_i \Delta}| = e^{\text{Re}(\lambda_i) \Delta} < 1∣eλiΔ∣=eRe(λi)Δ<1。□\square□

  2. 设 Avi=λiviA v_i = \lambda_i v_iAvi=λivi,则 eAΔvi=eλiΔvie^{A\Delta} v_i = e^{\lambda_i \Delta} v_ieAΔvi=eλiΔvi。□\square□

  3. ZOH 的构造保证了在 t=kΔt = k\Deltat=kΔ 处,离散解等于连续解。□\square□

3.3 双线性变换(Bilinear Transform / Tustin Method)

3.3.1 定义

双线性变换(也称为梯形法或 Tustin 方法)基于以下近似:

s=2Δ⋅z−1z+1s = \frac{2}{\Delta} \cdot \frac{z - 1}{z + 1}s=Δ2⋅z+1z−1

其中 sss 是连续域的拉普拉斯变量,zzz 是离散域的 Z 变换变量。

对于状态矩阵:

Aˉ=(I−Δ2A)−1(I+Δ2A)\bar{A} = \left(I - \frac{\Delta}{2}A\right)^{-1}\left(I + \frac{\Delta}{2}A\right)Aˉ=(I−2ΔA)−1(I+2ΔA)

Bˉ=(I−Δ2A)−1ΔB\bar{B} = \left(I - \frac{\Delta}{2}A\right)^{-1} \Delta BBˉ=(I−2ΔA)−1ΔB

3.3.2 性质

定理 3.2(双线性变换的稳定性保持):双线性变换将连续域的左半平面映射到离散域的单位圆内部。

证明:设 λ\lambdaλ 是 AAA 的特征值,Re(λ)<0\text{Re}(\lambda) < 0Re(λ)<0。则 Aˉ\bar{A}Aˉ 对应的特征值为:

λˉ=1+Δ2λ1−Δ2λ\bar{\lambda} = \frac{1 + \frac{\Delta}{2}\lambda}{1 - \frac{\Delta}{2}\lambda}λˉ=1−2Δλ1+2Δλ

令 λ=a+bi\lambda = a + biλ=a+bi,其中 a<0a < 0a<0,则:

∣λˉ∣2=(1+Δ2a)2+(Δ2b)2(1−Δ2a)2+(Δ2b)2|\bar{\lambda}|^2 = \frac{(1 + \frac{\Delta}{2}a)^2 + (\frac{\Delta}{2}b)^2}{(1 - \frac{\Delta}{2}a)^2 + (\frac{\Delta}{2}b)^2}∣λˉ∣2=(1−2Δa)2+(2Δb)2(1+2Δa)2+(2Δb)2

由于 a<0a < 0a<0,分子中 (1+Δ2a)2<(1−Δ2a)2(1 + \frac{\Delta}{2}a)^2 < (1 - \frac{\Delta}{2}a)^2(1+2Δa)2<(1−2Δa)2,因此 ∣λˉ∣<1|\bar{\lambda}| < 1∣λˉ∣<1。□\square□

3.3.3 ZOH 与双线性变换的比较

性质 ZOH 双线性变换
采样点精度 在采样点处精确 有频率畸变(warping)
频率响应 保持相位特性 保持增益特性
稳定性 保持 保持
计算 需要矩阵指数 只需要矩阵求逆
S4 中的使用 是(主要方法)

3.4 步长参数 Δ\DeltaΔ 的含义

离散化步长 Δ\DeltaΔ 是一个极其重要的超参数。它控制着连续信号被采样的粒度

  • Δ\DeltaΔ 较大:粗粒度采样,相邻离散步之间的跨度大,模型"看"到的是信号的低频成分
  • Δ\DeltaΔ 较小:细粒度采样,相邻离散步之间的跨度小,模型能捕捉信号的高频变化

在 SSM 的上下文中,Δ\DeltaΔ 的选择直接影响离散化后的 Aˉ\bar{A}Aˉ 矩阵:

Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ

  • 当 Δ→0\Delta \to 0Δ→0 时,Aˉ→I\bar{A} \to IAˉ→I,系统几乎不衰减,等价于恒等映射
  • 当 Δ→∞\Delta \to \inftyΔ→∞ 时,Aˉ→0\bar{A} \to 0Aˉ→0,系统完全衰减,历史信息完全丢失

合适的 Δ\DeltaΔ 应该使得系统在有意义的时间尺度上衰减------足够慢以保留长程信息,又足够快以忘去过时的信息。

在 Mamba 中 ,Δ\DeltaΔ 不再是固定超参数,而是由输入数据动态生成的------这就是选择性机制的核心之一,我们将在第十一章详细讨论。

3.5 离散化后的频域分析

3.5.1 Z 变换

对离散 LTI 系统进行 Z 变换:

H(z)=Cˉ(zI−Aˉ)−1Bˉ+DH(z) = \bar{C}(zI - \bar{A})^{-1}\bar{B} + DH(z)=Cˉ(zI−Aˉ)−1Bˉ+D

这是一个有理函数(rational function),其极点是 Aˉ\bar{A}Aˉ 的特征值,零点取决于 B,CB, CB,C。

3.5.2 频率响应

令 z=ejωz = e^{j\omega}z=ejω(单位圆上的点),得到频率响应:

H(ejω)=Cˉ(ejωI−Aˉ)−1BˉH(e^{j\omega}) = \bar{C}(e^{j\omega}I - \bar{A})^{-1}\bar{B}H(ejω)=Cˉ(ejωI−Aˉ)−1Bˉ

∣H(ejω)∣|H(e^{j\omega})|∣H(ejω)∣ 描述了系统对不同频率正弦输入的增益。在序列建模中,这对应于模型对不同频率模式的敏感度。

3.6 一维输入情况的显式公式

当输入维度为 1 时,B∈RNB \in \mathbb{R}^NB∈RN,C∈RNC \in \mathbb{R}^NC∈RN,我们可以写出更显式的公式。

设 AAA 的特征分解为 A=VΛV−1A = V \Lambda V^{-1}A=VΛV−1(假设可对角化),则:

Aˉ=VeΛΔV−1\bar{A} = V e^{\Lambda \Delta} V^{-1}Aˉ=VeΛΔV−1

Kn=CAˉnBˉ=CV(eΛΔ)nV−1Bˉ=∑i=1NC~i(eλiΔ)nB~iK_n = C \bar{A}^n \bar{B} = C V (e^{\Lambda \Delta})^n V^{-1} \bar{B} = \sum_{i=1}^{N} \tilde{C}_i (e^{\lambda_i \Delta})^n \tilde{B}_iKn=CAˉnBˉ=CV(eΛΔ)nV−1Bˉ=i=1∑NC~i(eλiΔ)nB~i

其中 C~=CV\tilde{C} = C VC~=CV, B~=V−1Bˉ\tilde{B} = V^{-1}\bar{B}B~=V−1Bˉ, C~i,B~i\tilde{C}_i, \tilde{B}_iC~i,B~i 分别是第 iii 个分量。

记 λˉi=eλiΔ\bar{\lambda}_i = e^{\lambda_i \Delta}λˉi=eλiΔ,则:

Kn=∑i=1NC~iB~iλˉinK_n = \sum_{i=1}^{N} \tilde{C}_i \tilde{B}_i \bar{\lambda}_i^nKn=i=1∑NC~iB~iλˉin

这是一个指数和核 (exponential sum kernel)------NNN 个指数衰减/振荡模态的叠加。这正是 S4 高效计算的基础。


第四章:序列建模视角------卷积与递推的对偶性

4.1 离散 SSM 的两种等价计算模式

离散 LTI-SSM:

ht=Aˉht−1+Bˉxth_t = \bar{A} h_{t-1} + \bar{B} x_tht=Aˉht−1+Bˉxt
yt=Cˉhty_t = \bar{C} h_tyt=Cˉht

有两种等价的方式来计算输出序列 {y0,y1,...,yL−1}\{y_0, y_1, \dots, y_{L-1}\}{y0,y1,...,yL−1}。

4.1.1 递推模式(Recurrence Mode)

直接按时间步递推:

ht=Aˉht−1+Bˉxt,yt=Cˉhth_t = \bar{A} h_{t-1} + \bar{B} x_t, \quad y_t = \bar{C} h_tht=Aˉht−1+Bˉxt,yt=Cˉht

复杂度分析

  • 时间:O(BLN2)O(BLN^2)O(BLN2),其中 BBB 是 batch size,LLL 是序列长度,NNN 是状态维度
  • 空间:O(BN)O(BN)O(BN)(只需要存储当前隐状态)
  • 特点:串行------每个时间步依赖前一个,无法并行

递推模式适合自回归推理(inference),即一个一个 token 生成的场景。

4.1.2 卷积模式(Convolution Mode)

展开递推关系:

y0=CˉBˉx0y_0 = \bar{C}\bar{B} x_0y0=CˉBˉx0
y1=CˉAˉBˉx0+CˉBˉx1y_1 = \bar{C}\bar{A}\bar{B} x_0 + \bar{C}\bar{B} x_1y1=CˉAˉBˉx0+CˉBˉx1
y2=CˉAˉ2Bˉx0+CˉAˉBˉx1+CˉBˉx2y_2 = \bar{C}\bar{A}^2\bar{B} x_0 + \bar{C}\bar{A}\bar{B} x_1 + \bar{C}\bar{B} x_2y2=CˉAˉ2Bˉx0+CˉAˉBˉx1+CˉBˉx2
⋮\vdots⋮
yn=∑k=0nCˉAˉn−kBˉxk=∑k=0nKn−kxky_n = \sum_{k=0}^{n} \bar{C}\bar{A}^{n-k}\bar{B} x_k = \sum_{k=0}^{n} K_{n-k} x_kyn=k=0∑nCˉAˉn−kBˉxk=k=0∑nKn−kxk

其中卷积核为:

Kn=CˉAˉnBˉ,n=0,1,2,...,L−1K_n = \bar{C}\bar{A}^n\bar{B}, \quad n = 0, 1, 2, \dots, L-1Kn=CˉAˉnBˉ,n=0,1,2,...,L−1

这正是我们之前从连续系统推导出的脉冲响应的离散化形式。

因此,输出是输入与卷积核 KKK 的一维卷积

y=K∗xy = K * xy=K∗x

复杂度分析(直接卷积):

  • 时间:O(BL2N)O(BL^2N)O(BL2N)(对每个输出位置,求 O(L)O(L)O(L) 项的和)
  • 空间:O(BL+BL)O(BL + BL)O(BL+BL)(存储输入和卷积核)

复杂度分析(FFT 卷积):

  • 时间:O(BLlog⁡L)O(BL \log L)O(BLlogL)(利用 FFT 将卷积转化为逐元素乘法)
  • 空间:O(BL)O(BL)O(BL)
  • 特点:完全并行------所有输出位置可以同时计算

卷积模式适合训练(前向传播),因为训练时整个输入序列已知,可以一次性计算所有输出。

4.2 核的计算问题

卷积模式的关键瓶颈在于如何高效计算卷积核 K={K0,K1,...,KL−1}K = \{K_0, K_1, \dots, K_{L-1}\}K={K0,K1,...,KL−1}

朴素方法:逐次计算矩阵幂

Kn=CAˉnBˉK_n = C\bar{A}^n\bar{B}Kn=CAˉnBˉ

这需要 LLL 次矩阵-向量乘法,每次 O(N2)O(N^2)O(N2),总复杂度 O(LN2)O(LN^2)O(LN2)。

问题 :当 NNN 很大时(如 N=256N = 256N=256 或更大),O(N2)O(N^2)O(N2) 的矩阵-向量乘法成为瓶颈。

这就是 S4 要解决的核心计算问题------我们将在第九章详细讨论。

4.3 卷积-递推对偶性的深层含义

定理 4.1(卷积-递推对偶性):对于任何 LTI 系统,以下两个计算过程产生完全相同的输出:

  1. 递推模式 :ht=Aˉht−1+Bˉxth_t = \bar{A}h_{t-1} + \bar{B}x_tht=Aˉht−1+Bˉxt, yt=Cˉhty_t = \bar{C}h_tyt=Cˉht
  2. 卷积模式 :y=K∗xy = K * xy=K∗x, 其中 Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ

证明:由递推展开,hn=∑k=0nAˉn−kBˉxkh_n = \sum_{k=0}^{n} \bar{A}^{n-k}\bar{B}x_khn=∑k=0nAˉn−kBˉxk(假设 h0=0h_0 = 0h0=0),因此 yn=Cˉhn=∑k=0nCˉAˉn−kBˉxk=∑k=0nKn−kxky_n = \bar{C}h_n = \sum_{k=0}^{n}\bar{C}\bar{A}^{n-k}\bar{B}x_k = \sum_{k=0}^{n}K_{n-k}x_kyn=Cˉhn=∑k=0nCˉAˉn−kBˉxk=∑k=0nKn−kxk。□\square□

这个对偶性意味着:

  • 训练用卷积模式(并行,利用 GPU 的吞吐量)
  • 推理用递推模式(流式,只需要常数内存)

这是 SSM 相比 Transformer 的一个结构性优势:Transformer 的训练和推理都依赖注意力机制,都需要 O(n2)O(n^2)O(n2) 的计算;而 SSM 在训练和推理中可以分别使用最高效的方式。

4.4 从卷积核到全局依赖

让我们仔细分析卷积核 Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ 的衰减行为。

如果 Aˉ\bar{A}Aˉ 的特征值都在单位圆内(∣λˉi∣<1|\bar{\lambda}_i| < 1∣λˉi∣<1),则:

Kn=∑i=1NαiλˉinK_n = \sum_{i=1}^{N} \alpha_i \bar{\lambda}_i^nKn=i=1∑Nαiλˉin

随着 n→∞n \to \inftyn→∞,Kn→0K_n \to 0Kn→0。但衰减的速度取决于最大的 ∣λˉi∣|\bar{\lambda}_i|∣λˉi∣:

  • 如果 ∣λˉmax⁡∣|\bar{\lambda}_{\max}|∣λˉmax∣ 接近 1:衰减很慢,核具有长程记忆
  • 如果 ∣λˉmax⁡∣|\bar{\lambda}_{\max}|∣λˉmax∣ 很小:衰减很快,核只能记住最近的信息

这就是 SSM 的"记忆容量"------它由 Aˉ\bar{A}Aˉ 的谱性质决定,而 Aˉ\bar{A}Aˉ 又由连续参数 AAA 和步长 Δ\DeltaΔ 共同决定。

4.5 与 Transformer 注意力核的对比

Transformer 的注意力权重可以看作一个动态卷积核

yi=∑jαijxj,αij=exp⁡(qiTkj/d)∑jexp⁡(qiTkj/d)y_i = \sum_j \alpha_{ij} x_j, \quad \alpha_{ij} = \frac{\exp(q_i^T k_j / \sqrt{d})}{\sum_j \exp(q_i^T k_j / \sqrt{d})}yi=j∑αijxj,αij=∑jexp(qiTkj/d )exp(qiTkj/d )

与 SSM 的卷积核对比:

性质 SSM 卷积核 注意力核
参数化 Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ αij=softmax(qiTkj)\alpha_{ij} = \text{softmax}(q_i^T k_j)αij=softmax(qiTkj)
是否依赖输入 (LTI 情况)
核形状 只依赖相对位置 n−kn - kn−k 依赖绝对位置 iii 和 jjj
计算复杂度 O(Llog⁡L)O(L \log L)O(LlogL)(FFT) O(L2)O(L^2)O(L2)
全局依赖 通过长衰减核间接实现 直接实现

SSM 的核心劣势在于:在 LTI 设定下,卷积核是固定的------不管输入是什么样的,核都不变。这意味着模型无法根据输入内容动态调整"看哪里"------它只能学习一个固定的衰减模式。

这正是 Mamba 的选择性机制要解决的问题,我们将在第十一章详细讨论。

4.6 高维输入的情况

对于 DDD 维输入(xt∈RDx_t \in \mathbb{R}^Dxt∈RD),如果每个通道独立处理(B∈RN×DB \in \mathbb{R}^{N \times D}B∈RN×D, C∈RD×NC \in \mathbb{R}^{D \times N}C∈RD×N),则:

yt,d=∑k=0tKt−k,d⋅xk,dy_{t,d} = \sum_{k=0}^{t} K_{t-k,d} \cdot x_{k,d}yt,d=k=0∑tKt−k,d⋅xk,d

其中 Kn,d=CdAˉnBdK_{n,d} = C_d \bar{A}^n B_dKn,d=CdAˉnBd 是第 ddd 个通道的卷积核。

总共有 DDD 个独立的卷积核,每个长度为 LLL。卷积核矩阵 K∈RD×LK \in \mathbb{R}^{D \times L}K∈RD×L 的形状是 D×LD \times LD×L。

在 S4 的实现中,这可以表示为一个二维卷积操作,其中:

  • 第一个维度是通道维度(DDD 个独立的卷积)
  • 第二个维度是时间维度(长度为 LLL 的卷积)

第二部分:长程依赖与 HiPPO 理论


第五章:长程依赖问题------为什么 RNN 会遗忘

5.1 梯度消失的严格分析

5.1.1 BPTT 中的梯度传播

考虑一个简单的 RNN:

ht=σ(Whht−1+Wxxt)h_t = \sigma(W_h h_{t-1} + W_x x_t)ht=σ(Whht−1+Wxxt)

其中 σ\sigmaσ 是激活函数(如 tanh)。

在反向传播通过时间(BPTT)中,损失 L\mathcal{L}L 对 hsh_shs(s<ts < ts<t)的梯度为:

∂L∂hs=∂L∂ht∏k=s+1t∂hk∂hk−1\frac{\partial \mathcal{L}}{\partial h_s} = \frac{\partial \mathcal{L}}{\partial h_t} \prod_{k=s+1}^{t} \frac{\partial h_k}{\partial h_{k-1}}∂hs∂L=∂ht∂Lk=s+1∏t∂hk−1∂hk

其中每一步的雅可比矩阵为:

∂hk∂hk−1=diag(σ′(zk))Wh\frac{\partial h_k}{\partial h_{k-1}} = \text{diag}(\sigma'(z_k)) W_h∂hk−1∂hk=diag(σ′(zk))Wh

其中 zk=Whhk−1+Wxxkz_k = W_h h_{k-1} + W_x x_kzk=Whhk−1+Wxxk。

因此:

∂L∂hs=∂L∂ht∏k=s+1tdiag(σ′(zk))Wh\frac{\partial \mathcal{L}}{\partial h_s} = \frac{\partial \mathcal{L}}{\partial h_t} \prod_{k=s+1}^{t} \text{diag}(\sigma'(z_k)) W_h∂hs∂L=∂ht∂Lk=s+1∏tdiag(σ′(zk))Wh

5.1.2 指数衰减的证明

取范数:

∥∂L∂hs∥≤∥∂L∂ht∥∏k=s+1t∥diag(σ′(zk))∥⋅∥Wh∥\left\|\frac{\partial \mathcal{L}}{\partial h_s}\right\| \leq \left\|\frac{\partial \mathcal{L}}{\partial h_t}\right\| \prod_{k=s+1}^{t} \|\text{diag}(\sigma'(z_k))\| \cdot \|W_h\| ∂hs∂L ≤ ∂ht∂L k=s+1∏t∥diag(σ′(zk))∥⋅∥Wh∥

对于 tanh 激活函数,∣σ′(z)∣=∣1−tanh⁡2(z)∣≤1|\sigma'(z)| = |1 - \tanh^2(z)| \leq 1∣σ′(z)∣=∣1−tanh2(z)∣≤1,且在饱和区(∣z∣|z|∣z∣ 较大时)趋近于 0。

设 γ=max⁡k∥diag(σ′(zk))∥⋅∥Wh∥\gamma = \max_k \|\text{diag}(\sigma'(z_k))\| \cdot \|W_h\|γ=maxk∥diag(σ′(zk))∥⋅∥Wh∥,则:

∥∂L∂hs∥≤γt−s∥∂L∂ht∥\left\|\frac{\partial \mathcal{L}}{\partial h_s}\right\| \leq \gamma^{t-s} \left\|\frac{\partial \mathcal{L}}{\partial h_t}\right\| ∂hs∂L ≤γt−s ∂ht∂L

  • 当 γ<1\gamma < 1γ<1 时:梯度以 γt−s\gamma^{t-s}γt−s 指数衰减(梯度消失)
  • 当 γ>1\gamma > 1γ>1 时:梯度以 γt−s\gamma^{t-s}γt−s 指数增长(梯度爆炸)

5.1.3 有效记忆长度

定义 :有效记忆长度 LeffL_{\text{eff}}Leff 是指梯度衰减到初始值的 1/e1/e1/e 所需的时间步数。

γLeff=1e  ⟹  Leff=1ln⁡(1/γ)=−1ln⁡γ\gamma^{L_{\text{eff}}} = \frac{1}{e} \implies L_{\text{eff}} = \frac{1}{\ln(1/\gamma)} = \frac{-1}{\ln \gamma}γLeff=e1⟹Leff=ln(1/γ)1=lnγ−1

对于典型的 RNN 设置(∥Wh∥≈1\|W_h\| \approx 1∥Wh∥≈1, ∣σ′∣≈0.5|\sigma'| \approx 0.5∣σ′∣≈0.5),γ≈0.5\gamma \approx 0.5γ≈0.5,则:

Leff≈1ln⁡2≈1.44L_{\text{eff}} \approx \frac{1}{\ln 2} \approx 1.44Leff≈ln21≈1.44

这意味着梯度在仅约 1-2 个时间步后就衰减到 1/e1/e1/e!实际中,RNN 的有效记忆长度通常只有 10-20 个时间步。

5.1.4 数值验证

python 复制代码
import numpy as np
import matplotlib.pyplot as plt


def simulate_gradient_decay(seq_len: int = 200, hidden_dim: int = 64, num_trials: int = 100) -> np.ndarray:
    """模拟 RNN 中梯度的指数衰减行为。

    通过随机矩阵乘法的累积效应,展示梯度范数随时间步的衰减。
    """
    decay_curves = np.zeros((num_trials, seq_len))

    for trial in range(num_trials):
        # 随机初始化权重矩阵,谱范数约 0.9(保证稳定但有衰减)
        W = np.random.randn(hidden_dim, hidden_dim) * (0.9 / np.sqrt(hidden_dim))
        # tanh 激活函数导数的平均值约为 0.6
        sigma_prime_mean = 0.6

        # 初始梯度
        grad = np.ones(hidden_dim)
        grad_norm_init = np.linalg.norm(grad)

        for t in range(seq_len):
            # 模拟一步 BPTT:乘以 diag(sigma') @ W
            sigma_prime = np.random.uniform(0.3, 0.9, hidden_dim)  # tanh' 的随机模拟
            grad = (sigma_prime * (W.T @ grad))
            grad_norm = np.linalg.norm(grad)
            decay_curves[trial, t] = grad_norm / grad_norm_init if grad_norm_init > 0 else 0

    return decay_curves


def demonstrate_gradient_decay():
    """演示 RNN 梯度衰减现象。"""
    np.random.seed(42)
    decay_curves = simulate_gradient_decay(seq_len=100, hidden_dim=64, num_trials=50)
    mean_curve = np.mean(decay_curves, axis=0)
    std_curve = np.std(decay_curves, axis=0)

    print("RNN 梯度衰减模拟结果:")
    print(f"  序列长度: 100")
    print(f"  隐状态维度: 64")
    print(f"  模拟次数: 50")
    print(f"  t=0  时梯度相对范数: {mean_curve[0]:.6f}")
    print(f"  t=10 时梯度相对范数: {mean_curve[10]:.6f}")
    print(f"  t=20 时梯度相对范数: {mean_curve[20]:.6f}")
    print(f"  t=50 时梯度相对范数: {mean_curve[50]:.6f}")
    print(f"  t=99 时梯度相对范数: {mean_curve[99]:.6f}")
    print()
    print("结论:梯度在约 10-20 步内衰减到接近零,解释了 RNN 的长程依赖困难。")

    return mean_curve, std_curve


if __name__ == "__main__":
    demonstrate_gradient_decay()

5.2 LSTM 和 GRU 的门控机制:缓解而非解决

5.2.1 LSTM 的遗忘门

LSTM 引入了遗忘门 ftf_tft 来控制信息的流动:

ft=σ(Wfht−1,xt+bf)f_t = \sigma(W_f h_{t-1}, x_t + b_f)ft=σ(Wfht−1,xt+bf)
it=σ(Wiht−1,xt+bi)i_t = \sigma(W_i h_{t-1}, x_t + b_i)it=σ(Wiht−1,xt+bi)
c~t=tanh⁡(Wcht−1,xt+bc)\tilde{c}t = \tanh(W_c h_{t-1}, x_t + b_c)c~t=tanh(Wcht−1,xt+bc)
ct=ft⊙ct−1+it⊙c~tc_t = f_t \odot c
{t-1} + i_t \odot \tilde{c}_tct=ft⊙ct−1+it⊙c~t
ot=σ(Woht−1,xt+bo)o_t = \sigma(W_o h_{t-1}, x_t + b_o)ot=σ(Woht−1,xt+bo)
ht=ot⊙tanh⁡(ct)h_t = o_t \odot \tanh(c_t)ht=ot⊙tanh(ct)

关键观察:细胞状态 ctc_tct 的递推是线性的(加法而非乘法):

ct=ft⊙ct−1+it⊙c~tc_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_tct=ft⊙ct−1+it⊙c~t

当 ft=1f_t = 1ft=1(遗忘门全开)时,ct=ct−1+newc_t = c_{t-1} + \text{new}ct=ct−1+new,信息可以无损传递 。这使得 LSTM 的梯度可以通过 ctc_tct 的线性路径流通,不会指数衰减。

5.2.2 LSTM 的局限

但 LSTM 仍然有根本性的局限:

  1. 门控是输入依赖的 :ftf_tft 依赖于 xtx_txt,这意味着门控策略是针对每个输入独立决定的,缺乏全局规划
  2. 隐状态维度有限 :ct∈Rdc_t \in \mathbb{R}^dct∈Rd,ddd 通常较小(256-1024),信息容量有限
  3. 串行计算 :ctc_tct 依赖 ct−1c_{t-1}ct−1,无法并行
  4. 理论有效记忆长度仍然有限:虽然比简单 RNN 好很多,但实验表明 LSTM 的有效记忆长度约为 100-200 个时间步

5.3 长程依赖的数学刻画

5.3.1 Long Range Arena (LRA) 基准

Long Range Arena(Tay et al., 2021)是一个专门评估长程依赖建模能力的基准测试集,包含:

任务 序列长度 依赖类型 难度
ListOps 2000 层次结构
Text 4096 语义
Retrieval 4096 匹配
Image 1024 空间
Pathfinder 1024 路径追踪
Path-X 16384 超长路径 极高

Path-X 任务(序列长度 16384)是真正的试金石------大多数 Transformer 变体在这个任务上都失败了,因为 O(n2)O(n^2)O(n2) 的注意力使得内存不足。

5.3.2 信息论视角

从信息论的角度,长程依赖可以被理解为互信息的衰减

I(Xt;Xt+k)=H(Xt)−H(Xt∣Xt+k)I(X_t; X_{t+k}) = H(X_t) - H(X_t | X_{t+k})I(Xt;Xt+k)=H(Xt)−H(Xt∣Xt+k)

对于马尔可夫过程,互信息随着 kkk 指数衰减。但自然语言中的长程依赖意味着 I(Xt;Xt+k)I(X_t; X_{t+k})I(Xt;Xt+k) 可以保持显著水平即使 kkk 很大。

一个理想的序列模型应该能够选择性地保留 与未来预测相关的信息,而选择性地丢弃无关信息。

5.4 状态空间模型的长程依赖理论

5.4.1 线性 SSM 的记忆能力

对于离散 LTI-SSM ht=Aˉht−1+Bˉxth_t = \bar{A} h_{t-1} + \bar{B} x_tht=Aˉht−1+Bˉxt,其"记忆"由卷积核 Kn=CˉAˉnBˉK_n = \bar{C}\bar{A}^n\bar{B}Kn=CˉAˉnBˉ 决定。

定理 5.1(LTI-SSM 的记忆衰减) :若 Aˉ\bar{A}Aˉ 的谱半径 ρ(Aˉ)<1\rho(\bar{A}) < 1ρ(Aˉ)<1,则:

∣Kn∣≤Mρ(Aˉ)n|K_n| \leq M \rho(\bar{A})^n∣Kn∣≤Mρ(Aˉ)n

对某个常数 M>0M > 0M>0。即卷积核以谱半径的指数速度衰减。

证明:由矩阵范数的性质,∥Aˉn∥≤Mρ(Aˉ)n\|\bar{A}^n\| \leq M \rho(\bar{A})^n∥Aˉn∥≤Mρ(Aˉ)n(Gelfand 公式),因此 ∣Kn∣=∣CˉAˉnBˉ∣≤∥Cˉ∥⋅∥Aˉn∥⋅∥Bˉ∥≤M′ρ(Aˉ)n|K_n| = |\bar{C}\bar{A}^n\bar{B}| \leq \|\bar{C}\| \cdot \|\bar{A}^n\| \cdot \|\bar{B}\| \leq M' \rho(\bar{A})^n∣Kn∣=∣CˉAˉnBˉ∣≤∥Cˉ∥⋅∥Aˉn∥⋅∥Bˉ∥≤M′ρ(Aˉ)n。□\square□

5.4.2 "选择性"的重要性

这个定理揭示了 LTI-SSM 的根本局限:卷积核的衰减是固定的,不依赖于输入内容

考虑两个场景:

  • 场景 A:"The cat sat on the mat. It was very fluffy." ------ "It" 指代 "cat",依赖距离为 6
  • 场景 B:"The cat, which was adopted from the shelter last year and has been living happily ever since, sat on the mat." ------ 依赖距离为 20+

一个固定的衰减核无法同时很好地处理这两种情况。理想情况下,模型应该:

  • 在场景 A 中快速衰减,关注最近的 "cat"
  • 在场景 B 中保持记忆,跨越中间的修饰语

这就是选择性(selectivity)的核心动机------让衰减率依赖于输入内容。


第六章:HiPPO 框架------用多项式投影逼近历史

6.1 动机:如何最优地压缩历史信息?

给定一个连续信号 f:0,t→Rf: 0, t \to \mathbb{R}f:0,t→R,我们想要用一个有限维的状态向量 c(t)∈RNc(t) \in \mathbb{R}^Nc(t)∈RN 来最优地表示信号的历史。

"最优"在这里的含义是:在某个函数空间中,c(t)c(t)c(t) 所对应的近似函数 ftf_tft 是对历史 f(τ)f(\tau)f(τ)(τ≤t\tau \leq tτ≤t)在某种意义下的最佳逼近。

HiPPO(Hi gh-order P olynomial P rojection Operators)框架(Gu et al., 2020)正是回答了这个问题。

6.2 问题形式化

6.2.1 在线函数逼近

考虑一个实时到达的信号 f(t)f(t)f(t)。在时刻 ttt,我们希望找到一个多项式 ft(τ)f_t(\tau)ft(τ)(定义在 τ∈0,t\tau \in 0, tτ∈0,t 上),使得:

ft=arg⁡min⁡g∈PN∥f∣0,t−g∥μtf_t = \arg\min_{g \in \mathcal{P}N} \|f|{0,t} - g\|_{\mu_t}ft=argg∈PNmin∥f∣0,t−g∥μt

其中:

  • PN\mathcal{P}_NPN 是次数不超过 N−1N-1N−1 的多项式空间
  • ∥⋅∥μt\|\cdot\|_{\mu_t}∥⋅∥μt 是关于某个测度 μt\mu_tμt 的 L2L^2L2 范数
  • f∣0,tf|_{0,t}f∣0,t 是 fff 在 0,t0, t0,t 上的限制

关键问题:如何选择测度 μt\mu_tμt?

6.2.2 测度的选择

不同的测度对应不同的"遗忘策略":

测度 定义 含义
均匀测度 dμt(τ)=1tdτd\mu_t(\tau) = \frac{1}{t}d\taudμt(τ)=t1dτ 历史上的所有时刻同等重要
衰减测度 dμt(τ)=e−λ(t−τ)dτd\mu_t(\tau) = e^{-\lambda(t-\tau)}d\taudμt(τ)=e−λ(t−τ)dτ 越近的时刻越重要
滑动窗口 dμt(τ)=1τ∈\[t−w,t]dτd\mu_t(\tau) = \mathbb{1}\\tau \\in \[t-w, t] d\taudμt(τ)=1τ∈\[t−w,t]dτ 只关心最近 www 个时间单位
阶梯测度 dμt(τ)=∑kwkδ(τ−tk)d\mu_t(\tau) = \sum_{k} w_k \delta(\tau - t_k)dμt(τ)=∑kwkδ(τ−tk) 只关心特定时刻

HiPPO 的核心贡献是:对于一大类测度,这个在线逼近问题有闭式解,并且解可以表示为一个线性 ODE

6.3 HiPPO-LegS(缩放勒让德测度)

6.3.1 测度定义

HiPPO-LegS 使用缩放勒让德测度(Scaled Legendre Measure):

μt(s)=1t10,t(τ)\mu_t^{(s)} = \frac{1}{t} \mathbb{1}_{0, t}(\tau)μt(s)=t110,t(τ)

即在 0,t0, t0,t 上的均匀测度。"S" 代表 "Scaled"(随 ttt 缩放)。

6.3.2 多项式基

0,t0, t0,t 上定义缩放勒让德多项式 。令 s=τ/t∈0,1s = \tau / t \in 0, 1s=τ/t∈0,1,则标准勒让德多项式为:

P0(s)=1P_0(s) = 1P0(s)=1
P1(s)=2s−1P_1(s) = 2s - 1P1(s)=2s−1
P2(s)=12(3s2−6s+1)P_2(s) = \frac{1}{2}(3s^2 - 6s + 1)P2(s)=21(3s2−6s+1)
Pn(s)=1n!dndsn(s2−1)nP_n(s) = \frac{1}{n!}\frac{d^n}{ds^n}(s^2 - 1)^nPn(s)=n!1dsndn(s2−1)n

这些多项式在 0,10, 10,1 上关于 Lebesgue 测度正交:

∫01Pm(s)Pn(s) ds=12n+1δmn\int_0^1 P_m(s) P_n(s) \, ds = \frac{1}{2n + 1} \delta_{mn}∫01Pm(s)Pn(s)ds=2n+11δmn

6.3.3 投影系数

信号 fff 在多项式基上的投影系数为:

cn(t)=(2n+1)∫0tf(τ)Pn(τt)dτtc_n(t) = (2n + 1) \int_0^t f(\tau) P_n\left(\frac{\tau}{t}\right) \frac{d\tau}{t}cn(t)=(2n+1)∫0tf(τ)Pn(tτ)tdτ

这些系数随 ttt 变化,我们想知道它们的演化规律。

6.3.4 核心定理

定理 6.1(HiPPO-LegS ODE) :对于缩放勒让德测度,投影系数 c(t)=c0(t),c1(t),...,cN−1(t)Tc(t) = c_0(t), c_1(t), \\dots, c_{N-1}(t)^Tc(t)=c0(t),c1(t),...,cN−1(t)T 满足以下线性 ODE:

ddtc(t)=−1tAc(t)+1tBf(t)\frac{d}{dt} c(t) = -\frac{1}{t} A c(t) + \frac{1}{t} B f(t)dtdc(t)=−t1Ac(t)+t1Bf(t)

其中:

Ank={(2n+1)1/2(2k+1)1/2if n>kn+1if n=k0if n<kA_{nk} = \begin{cases} (2n + 1)^{1/2}(2k + 1)^{1/2} & \text{if } n > k \\ n + 1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}Ank=⎩ ⎨ ⎧(2n+1)1/2(2k+1)1/2n+10if n>kif n=kif n<k

Bn=(2n+1)1/2B_n = (2n + 1)^{1/2}Bn=(2n+1)1/2

即 AAA 是一个下三角矩阵加上对角项。

证明概要

将 cn(t)=(2n+1)∫0tf(τ)Pn(τ/t)dτ/tc_n(t) = (2n+1) \int_0^t f(\tau) P_n(\tau/t) d\tau/tcn(t)=(2n+1)∫0tf(τ)Pn(τ/t)dτ/t 对 ttt 求导,利用勒让德多项式的递推关系:

(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)(2n + 1) s P_n(s) = (n + 1) P_{n+1}(s) + n P_{n-1}(s)(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)

以及正交性条件,可以推导出 ODE 的闭式表达。具体推导较为繁琐,涉及勒让德多项式的以下关键性质:

  1. 递推关系 :(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)(2n+1)s P_n(s) = (n+1)P_{n+1}(s) + n P_{n-1}(s)(2n+1)sPn(s)=(n+1)Pn+1(s)+nPn−1(s)
  2. 正交性 :∫01Pm(s)Pn(s)ds=δmn2n+1\int_0^1 P_m(s) P_n(s) ds = \frac{\delta_{mn}}{2n+1}∫01Pm(s)Pn(s)ds=2n+1δmn
  3. 导数公式 :Pn′(s)=∑k=0n−1(2k+1)Pk(s)P_n'(s) = \sum_{k=0}^{n-1} (2k+1) P_k(s)Pn′(s)=∑k=0n−1(2k+1)Pk(s)(对 n>0n > 0n>0)
  4. 边界值 :Pn(0)=(−1)nP_n(0) = (-1)^nPn(0)=(−1)n, Pn(1)=1P_n(1) = 1Pn(1)=1

利用这些性质,对 cn(t)c_n(t)cn(t) 关于 ttt 求导并整理,即可得到 HiPPO-LegS 的 ODE 形式。□\square□

6.3.5 简化形式

通过变量替换 c(t)→t1/2c(t)c(t) \to t^{1/2} c(t)c(t)→t1/2c(t),可以将含 1/t1/t1/t 的 ODE 化为标准的 LTI 形式:

ddtc(t)=Ac(t)+Bf(t)\frac{d}{dt} c(t) = A c(t) + B f(t)dtdc(t)=Ac(t)+Bf(t)

(注意这里 AAA 矩阵略有不同,但结构相同。)

这就是将在线函数逼近问题转化为线性状态空间模型的关键洞察。

6.4 HiPPO-LegT(平移勒让德测度)

6.4.1 测度定义

HiPPO-LegT 使用平移勒让德测度(Translated Legendre Measure):

μt(T)=1t−θ,t(τ)dτ\mu_t^{(T)} = \mathbb{1}_{t - \\theta, t}(\tau) d\tauμt(T)=1t−θ,t(τ)dτ

即一个固定宽度 θ\thetaθ 的滑动窗口。"T" 代表 "Translated"。

6.4.2 核心结果

定理 6.2(HiPPO-LegT ODE):对于平移勒让德测度,投影系数满足:

ddtc(t)=ATc(t)+BTf(t)\frac{d}{dt} c(t) = A_{\text{T}} c(t) + B_{\text{T}} f(t)dtdc(t)=ATc(t)+BTf(t)

其中 ATA_{\text{T}}AT 是一个常数矩阵 (不依赖 ttt),其元素为:

(AT)nk={2n+1θ(−1)n−k(2k+1)1/2(2n+1)1/2if n≥k−n+1θif n=k0if n<k(A_{\text{T}})_{nk} = \begin{cases} \frac{2n+1}{\theta}(-1)^{n-k}(2k+1)^{1/2}(2n+1)^{1/2} & \text{if } n \geq k \\ -\frac{n+1}{\theta} & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}(AT)nk=⎩ ⎨ ⎧θ2n+1(−1)n−k(2k+1)1/2(2n+1)1/2−θn+10if n≥kif n=kif n<k

BTB_{\text{T}}BT 的元素为:

(BT)n=(2n+1)1/2θ(−1)n(B_{\text{T}})_n = \frac{(2n+1)^{1/2}}{\theta}(-1)^n(BT)n=θ(2n+1)1/2(−1)n

HiPPO-LegT 的优势在于 AAA 矩阵不依赖时间 ttt,因此可以直接用标准的 LTI-SSM 来表示。它有一个明确的"遗忘窗口" θ\thetaθ。

6.5 HiPPO-LagT(拉盖尔测度)

6.5.1 测度定义

HiPPO-LagT 使用指数衰减测度(由拉盖尔多项式正交化):

μt(L)=e−(τ−t)1τ≥t−θ(τ)dτ\mu_t^{(L)} = e^{-(\tau - t)} \mathbb{1}_{\\tau \\geq t - \\theta}(\tau) d\tauμt(L)=e−(τ−t)1τ≥t−θ(τ)dτ

6.5.2 核心结果

定理 6.3(HiPPO-LagT ODE):对于拉盖尔测度,投影系数满足:

ddtc(t)=−12c(t)+ALc(t)+BLf(t)\frac{d}{dt} c(t) = -\frac{1}{2} c(t) + A_{\text{L}} c(t) + B_{\text{L}} f(t)dtdc(t)=−21c(t)+ALc(t)+BLf(t)

其中 ALA_{\text{L}}AL 是一个下双对角矩阵:

(AL)nk={(n+1)1/2if k=n+10otherwise(A_{\text{L}})_{nk} = \begin{cases} (n+1)^{1/2} & \text{if } k = n+1 \\ 0 & \text{otherwise} \end{cases}(AL)nk={(n+1)1/20if k=n+1otherwise

6.6 HiPPO 矩阵的数值实现

python 复制代码
import numpy as np
import scipy.linalg


def hippo_legs_matrix(N: int) -> tuple[np.ndarray, np.ndarray]:
    """构造 HiPPO-LegS(缩放勒让德)矩阵。

    Args:
        N: 状态维度(多项式阶数)

    Returns:
        A: (N, N) HiPPO 矩阵
        B: (N, 1) 输入矩阵
    """
    P = np.arange(1, N + 1).astype(np.float64)  # [1, 2, ..., N]
    Q = np.arange(1, N + 1).astype(np.float64)

    # A_{nk} = (2n+1)^{1/2} (2k+1)^{1/2} if n > k, else n+1 if n==k, else 0
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = n + 1

    B = np.sqrt(2 * np.arange(1, N + 1) - 1).reshape(-1, 1)
    # 修正 B: (2n+1)^{1/2}, n 从 0 开始
    B = np.sqrt(2 * np.arange(N) + 1).reshape(-1, 1).astype(np.float64)

    return A, B


def hippo_legt_matrix(N: int, theta: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
    """构造 HiPPO-LegT(平移勒让德)矩阵。

    Args:
        N: 状态维度
        theta: 滑动窗口宽度

    Returns:
        A: (N, N) HiPPO-LegT 矩阵
        B: (N, 1) 输入矩阵
    """
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = (2 * n + 1) * (-1) ** (n - k) * np.sqrt(2 * k + 1) * np.sqrt(2 * n + 1) / theta
            elif n == k:
                A[n, k] = -(n + 1) / theta

    B = np.array([((-1) ** n) * np.sqrt(2 * n + 1) / theta for n in range(N)]).reshape(-1, 1)

    return A, B


def hippo_lagt_matrix(N: int) -> tuple[np.ndarray, np.ndarray]:
    """构造 HiPPO-LagT(拉盖尔)矩阵。

    Args:
        N: 状态维度

    Returns:
        A: (N, N) HiPPO-LagT 矩阵
        B: (N, 1) 输入矩阵
    """
    A = np.zeros((N, N))
    for n in range(N):
        if n + 1 < N:
            A[n, n + 1] = np.sqrt(n + 1)
        A[n, n] = -0.5

    B = np.array([np.sqrt(2 * n + 1) for n in range(N)]).reshape(-1, 1)

    return A, B


def demonstrate_hippo_matrices():
    """展示不同 HiPPO 矩阵的结构。"""
    N = 8

    print("=" * 60)
    print("HiPPO 矩阵结构展示 (N = 8)")
    print("=" * 60)

    A_legs, B_legs = hippo_legs_matrix(N)
    print("\n--- HiPPO-LegS (缩放勒让德) ---")
    print("A 矩阵:")
    print(np.array2string(A_legs, precision=3, suppress_small=True))
    print("B 向量:")
    print(np.array2string(B_legs.flatten(), precision=3))

    A_legt, B_legt = hippo_legt_matrix(N, theta=1.0)
    print("\n--- HiPPO-LegT (平移勒让德, theta=1.0) ---")
    print("A 矩阵:")
    print(np.array2string(A_legt, precision=3, suppress_small=True))
    print("B 向量:")
    print(np.array2string(B_legt.flatten(), precision=3))

    A_lagt, B_lagt = hippo_lagt_matrix(N)
    print("\n--- HiPPO-LagT (拉盖尔) ---")
    print("A 矩阵:")
    print(np.array2string(A_lagt, precision=3, suppress_small=True))
    print("B 向量:")
    print(np.array2string(B_lagt.flatten(), precision=3))

    # 检验 A 的特征值
    print("\n--- A 矩阵的特征值 ---")
    for name, A in [("LegS", A_legs), ("LegT", A_legt), ("LagT", A_lagt)]:
        eigvals = np.linalg.eigvals(A)
        print(f"  {name}: max Re(lambda) = {np.max(np.real(eigvals)):.4f}, "
              f"min Re(lambda) = {np.min(np.real(eigvals)):.4f}")


if __name__ == "__main__":
    demonstrate_hippo_matrices()

6.7 HiPPO 的理论意义

6.7.1 从逼近理论到 SSM

HiPPO 的核心贡献是建立了在线函数逼近线性状态空间模型之间的桥梁:

在线函数逼近⏟信号处理视角⟷LTI-SSM⏟序列建模视角\underbrace{\text{在线函数逼近}}{\text{信号处理视角}} \quad \longleftrightarrow \quad \underbrace{\text{LTI-SSM}}{\text{序列建模视角}}信号处理视角 在线函数逼近⟷序列建模视角 LTI-SSM

  • 逼近理论告诉我们应该学习什么(最优地压缩历史)
  • SSM 告诉我们怎么计算(通过矩阵运算)

6.7.2 HiPPO 初始化 vs 随机初始化

定理 6.4(HiPPO 初始化的优势) :使用 HiPPO 矩阵初始化 AAA 的 SSM 在 Path-X 等超长程任务上显著优于随机初始化。

直觉上,HiPPO 矩阵具有以下良好性质:

  1. 特征值分布在左半平面:保证系统稳定
  2. 下三角结构:不同阶多项式之间有层次化的信息流动
  3. 特定的谱结构:使得脉冲响应具有多尺度的衰减特性

这些性质使得模型从一开始就具有"合理的"记忆行为,而不是依赖训练去发现这些结构。

6.7.3 高阶多项式逼近的误差

定理 6.5(多项式逼近误差) :设 fff 是 0,t0, t0,t 上的 Lipschitz 连续函数,NNN 阶 HiPPO-LegS 逼近的 L2L^2L2 误差为:

∥f−ft(N)∥L2≤CN\|f - f_t^{(N)}\|_{L^2} \leq \frac{C}{N}∥f−ft(N)∥L2≤NC

其中 CCC 是与 fff 的 Lipschitz 常数和 ttt 相关的常数。

这意味着增加状态维度 NNN 可以系统性地减小逼近误差 ,且收敛速率为 O(1/N)O(1/N)O(1/N)。


第七章:从 HiPPO 到 LSSL------线性状态空间层

7.1 LSSL 的基本框架

LSSL(L inear S tate S pace Layer, Gu et al., 2021)是将 HiPPO 理论与线性 SSM 结合的第一个成功尝试。

其基本框架是:

  1. 用 HiPPO 矩阵初始化 AAA:确保模型从合理的记忆行为开始
  2. 将 B,CB, CB,C 设为可学习参数:让模型适应具体的任务
  3. 用 ZOH 离散化:将连续 SSM 转化为离散序列模型
  4. 用卷积模式训练,递推模式推理:利用对偶性获得最佳效率

7.1.1 LSSL 的参数

  • AAA:由 HiPPO-LegS 初始化,可学习(或固定)
  • B∈RN×1B \in \mathbb{R}^{N \times 1}B∈RN×1:可学习
  • C∈R1×NC \in \mathbb{R}^{1 \times N}C∈R1×N:可学习
  • Δ\DeltaΔ:离散化步长,可学习标量
  • 输入输出维度均为 1(多通道通过并行独立的 LSSL 实现)

7.1.2 LSSL 的计算流程

训练(卷积模式)

  1. 用 HiPPO 初始化 AAA
  2. 计算离散化参数 Aˉ=eAΔ\bar{A} = e^{A\Delta}Aˉ=eAΔ, Bˉ=(AΔ)−1(eAΔ−I)(BΔ)\bar{B} = (A\Delta)^{-1}(e^{A\Delta} - I)(B\Delta)Bˉ=(AΔ)−1(eAΔ−I)(BΔ)
  3. 计算卷积核 K=CBˉ,CAˉBˉ,CAˉ2Bˉ,...,CAˉL−1BˉK = C\\bar{B}, C\\bar{A}\\bar{B}, C\\bar{A}\^2\\bar{B}, \\dots, C\\bar{A}\^{L-1}\\bar{B}K=CBˉ,CAˉBˉ,CAˉ2Bˉ,...,CAˉL−1Bˉ
  4. 用 FFT 计算 y=K∗xy = K * xy=K∗x

推理(递推模式)

  1. 使用相同的 Aˉ,Bˉ,C\bar{A}, \bar{B}, CAˉ,Bˉ,C
  2. 递推计算 ht=Aˉht−1+Bˉxth_t = \bar{A}h_{t-1} + \bar{B}x_tht=Aˉht−1+Bˉxt, yt=Chty_t = Ch_tyt=Cht

7.2 LSSL 的实验结果

7.2.1 在长程基准上的表现

LSSL 在当时(2021年)取得了一系列突破性的结果:

任务 序列长度 Transformer LSTM LSSL
ListOps 2000 36.4 26.0 60.1
Text 4096 64.3 62.4 85.4
Retrieval 4096 57.5 57.1 89.2
Image (32x32) 1024 42.4 33.2 87.7
Pathfinder 1024 71.4 62.8 92.2

这些结果表明,HiPPO 初始化的 SSM 能够有效地捕捉长程依赖。

7.2.2 LSSL 的局限

但 LSSL 有几个关键局限:

  1. AAA 矩阵的计算瓶颈 :eAΔe^{A\Delta}eAΔ 需要 O(N3)O(N^3)O(N3) 的矩阵指数运算
  2. 卷积核的计算 :CAˉnBC\bar{A}^n BCAˉnB 需要 O(N2)O(N^2)O(N2) 的矩阵-向量乘法
  3. Aˉ\bar{A}Aˉ 是稠密矩阵 :即使 AAA 有特殊结构,eAΔe^{A\Delta}eAΔ 通常是稠密的
  4. 梯度计算困难 :eAΔe^{A\Delta}eAΔ 关于 AAA 的梯度涉及矩阵指数的微分

这些局限促使了 S4 的诞生------通过结构化参数化来解决计算瓶颈。

7.3 从 LSSL 到 S4 的演进

LSSL 的核心问题可以总结为:

HiPPO 给了我们好的 AAA 矩阵,但没有给我们高效计算 eAΔe^{A\Delta}eAΔ 的方法。

S4 的核心洞察是:

通过选择特殊的 AAA 矩阵结构(如正规加低秩,NPLR),可以将矩阵指数的计算从 O(N3)O(N^3)O(N3) 降到 O(N)O(N)O(N) 或 O(Nlog⁡N)O(N \log N)O(NlogN)。

这需要我们在 HiPPO 矩阵的数学结构上下功夫------下一章将详细展开。

7.4 LSSL 的完整实现

python 复制代码
import numpy as np
import scipy.linalg


class SimpleLSSL:
    """简单的 LSSL(Linear State Space Layer)实现。

    用于演示 HiPPO 初始化 + 离散化 + 卷积计算的基本流程。
    注意:这是一个教学用的简化实现,不做反向传播。
    """

    def __init__(self, N: int = 64, dt: float = 0.001):
        """
        Args:
            N: 状态维度
            dt: 离散化步长 Delta
        """
        self.N = N
        self.dt = dt

        # 用 HiPPO-LegS 初始化 A
        self.A, self.B_cont = self._hippo_legs(N)
        self.C = np.random.randn(1, N) * (1.0 / np.sqrt(N))

        # ZOH 离散化
        self._discretize()

    @staticmethod
    def _hippo_legs(N: int) -> tuple[np.ndarray, np.ndarray]:
        """构造 HiPPO-LegS 矩阵。"""
        A = np.zeros((N, N))
        for n in range(N):
            for k in range(N):
                if n > k:
                    A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
                elif n == k:
                    A[n, k] = n + 1
        B = np.sqrt(2 * np.arange(N) + 1).reshape(-1, 1).astype(np.float64)
        return A, B

    def _discretize(self):
        """使用 ZOH 进行离散化。"""
        dtA = self.A * self.dt
        # bar_A = expm(A * dt)
        self.bar_A = scipy.linalg.expm(dtA)
        # bar_B = A^{-1} (bar_A - I) B dt
        # 当 A 可能奇异时,使用积分形式
        self.bar_B = np.linalg.solve(
            self.A,
            (self.bar_A - np.eye(self.N)) @ self.B_cont
        ) * self.dt  # 注意:这里实际上不需要再乘 dt,因为 B_cont 已经是 B
        # 修正:bar_B = A^{-1} (exp(A*dt) - I) B
        # 不需要额外乘 dt
        self.bar_B = np.linalg.solve(
            self.A,
            (self.bar_A - np.eye(self.N)) @ self.B_cont
        )

    def compute_kernel(self, L: int) -> np.ndarray:
        """计算长度为 L 的卷积核。

        Args:
            L: 序列长度

        Returns:
            K: (L,) 卷积核
        """
        K = np.zeros(L)
        A_power_B = self.bar_B.copy()  # bar_A^0 @ bar_B = bar_B
        for n in range(L):
            K[n] = (self.C @ A_power_B).item()
            A_power_B = self.bar_A @ A_power_B
        return K

    def forward_conv(self, x: np.ndarray) -> np.ndarray:
        """卷积模式前向传播。

        Args:
            x: (L,) 输入序列

        Returns:
            y: (L,) 输出序列
        """
        L = len(x)
        K = self.compute_kernel(L)
        # 使用 FFT 计算卷积
        X = np.fft.rfft(x, n=2 * L)
        K_f = np.fft.rfft(K, n=2 * L)
        y = np.fft.irfft(X * K_f, n=2 * L)[:L]
        return y

    def forward_recur(self, x: np.ndarray) -> np.ndarray:
        """递推模式前向传播。

        Args:
            x: (L,) 输入序列

        Returns:
            y: (L,) 输出序列
        """
        L = len(x)
        h = np.zeros((self.N, 1))
        y = np.zeros(L)
        for t in range(L):
            h = self.bar_A @ h + self.bar_B * x[t]
            y[t] = (self.C @ h).item()
        return y


def demonstrate_lssl():
    """演示 LSSL 的卷积模式和递推模式。"""
    np.random.seed(42)

    N = 32
    L = 256
    dt = 0.01

    model = SimpleLSSL(N=N, dt=dt)

    # 随机输入
    x = np.random.randn(L)

    # 两种模式
    y_conv = model.forward_conv(x)
    y_recur = model.forward_recur(x)

    # 比较
    diff = np.max(np.abs(y_conv - y_recur))
    print(f"LSSL 双模式验证 (N={N}, L={L})")
    print(f"  卷积模式 vs 递推模式最大差异: {diff:.2e}")
    print(f"  输出均值: {np.mean(y_conv):.4f}")
    print(f"  输出标准差: {np.std(y_conv):.4f}")

    # 展示卷积核
    K = model.compute_kernel(L)
    print(f"\n卷积核统计:")
    print(f"  K[0] = {K[0]:.6f}")
    print(f"  K[L//2] = {K[L//2]:.6f}")
    print(f"  K[L-1] = {K[L-1]:.6f}")
    print(f"  ||K||_1 = {np.sum(np.abs(K)):.4f}")

    # 展示卷积核的频谱
    K_fft = np.abs(np.fft.rfft(K))
    print(f"\n卷积核频谱:")
    print(f"  低频能量 (前10%): {np.sum(K_fft[:len(K_fft)//10]**2) / np.sum(K_fft**2):.4f}")
    print(f"  高频能量 (后10%): {np.sum(K_fft[-len(K_fft)//10:]**2) / np.sum(K_fft**2):.4f}")


if __name__ == "__main__":
    demonstrate_lssl()

相关推荐
甩手网软件5 分钟前
Shopee2026新规:费率重构与履约收紧下,卖家如何破局?
大数据·人工智能
数据库小学妹6 分钟前
AI时代数据库怎么选?多模融合、数据统一存储与选型实战指南
数据库·人工智能·经验分享·ai
lizhihai_9914 分钟前
股市学习心得-AI 产业链核心标的梳理清单
大数据·服务器·人工智能·科技·学习
暮雪倾风17 分钟前
【AI】国内使用Claude Code,配置Claude Code,使用DeepSeek为例
人工智能
FrameNotWork25 分钟前
HarmonyOS6.1 AI 模型管理架构设计与最佳实践
人工智能·harmonyos
没事别瞎琢磨28 分钟前
十、统一 Runner 入口——能力检测与模式回退
人工智能·node.js
装不满的克莱因瓶30 分钟前
了解 LangChain 中的 LLM 与 ChatModel 的差异
人工智能·python·ai·langchain·llm·agent·chatmodel
dingzd9534 分钟前
跨境社媒运营越到后面 越比拼账号的表达稳定性
大数据·人工智能·矩阵·内容营销
云烟成雨TD35 分钟前
Spring AI 1.x 系列【54】Retry 机制分析
java·人工智能·spring
没事别瞎琢磨37 分钟前
八、环境隔离——构建安全的子进程环境
人工智能·node.js