状态空间模型:从经典控制论到现代序列建模——S4、Mamba 及其理论体系的完整论述(三)

第十二章:Mamba 架构------选择性扫描与硬件感知设计

12.1 Mamba 的整体架构

12.1.1 设计哲学

Mamba 的设计哲学可以总结为:

用选择性机制换取表达能力,用硬件感知算法弥补计算效率的损失。

12.1.2 Mamba Block

一个 Mamba Block 的计算流程:

复制代码
输入 x ∈ R^{B,L,D}
    │
    ├──→ Linear (+) ──→ SiLU ──→ Conv1D ──→ SiLU ──→ SSM ──→ (+) ──→ 输出
    │                                                    │
    └──→ Linear (+) ──→ SiLU ──→ 乘法门控 ─────────────┘

更具体地:

  1. 输入投影 :x→(z,x′)x \to (z, x')x→(z,x′),通过两个线性层,分别得到门控和主通路
  2. 主通路 :x′→Conv1D(x′)→SiLU→SSM(⋅)x' \to \text{Conv1D}(x') \to \text{SiLU} \to \text{SSM}(\cdot)x′→Conv1D(x′)→SiLU→SSM(⋅)
  3. 门控 :z→SiLU(z)z \to \text{SiLU}(z)z→SiLU(z)
  4. 输出 :SSM(⋅)⊙SiLU(z)\text{SSM}(\cdot) \odot \text{SiLU}(z)SSM(⋅)⊙SiLU(z)
  5. 输出投影 :通过线性层映射回 DDD 维

12.1.3 与 Transformer Block 的对比

组件 Transformer Block Mamba Block
核心操作 Self-Attention Selective SSM
前馈网络 FFN (两层 MLP) 门控 + 输出投影
归一化 LayerNorm RMSNorm (或无)
残差连接

12.2 选择性 SSM 的具体实现

12.2.1 参数生成

给定输入 xt∈RDx_t \in \mathbb{R}^Dxt∈RD:

Bt=LinearB(xt)∈RNB_t = \text{Linear}_B(x_t) \in \mathbb{R}^NBt=LinearB(xt)∈RN

Ct=LinearC(xt)∈RNC_t = \text{Linear}_C(x_t) \in \mathbb{R}^NCt=LinearC(xt)∈RN

Δt=softplus(LinearΔ(xt))∈R>0\Delta_t = \text{softplus}(\text{Linear}\Delta(x_t)) \in \mathbb{R}{>0}Δt=softplus(LinearΔ(xt))∈R>0

其中 softplus 保证 Δ>0\Delta > 0Δ>0:

softplus(x)=log⁡(1+ex)\text{softplus}(x) = \log(1 + e^x)softplus(x)=log(1+ex)

12.2.2 离散化与递推

python 复制代码
def selective_ssm_step(
    Lambda: complex,      # A 的对角元素 (标量,单通道)
    B_t: float,           # 输入依赖的 B
    C_t: float,           # 输入依赖的 C
    delta_t: float,       # 输入依赖的步长
    h_prev: complex,      # 前一步的隐状态
    x_t: float,           # 当前输入
) -> tuple[complex, float]:
    """选择性 SSM 的单步递推(单通道单模态)。"""
    # 离散化
    A_bar = np.exp(Lambda * delta_t)
    B_bar = (A_bar - 1) / Lambda * B_t if abs(Lambda) > 1e-7 else delta_t * B_t

    # 递推
    h_t = A_bar * h_prev + B_bar * x_t

    # 输出
    y_t = (C_t * h_t).real

    return h_t, y_t

12.3 硬件感知的选择性扫描算法

12.3.1 GPU 内存层次

现代 GPU(如 A100)的内存层次:

层级 大小 带宽 延迟
寄存器 ~256 KB/SM ~19 TB/s ~1 cycle
共享内存 (SRAM) ~192 KB/SM ~19 TB/s ~5 cycles
L2 缓存 40 MB ~5 TB/s ~30 cycles
HBM(主存) 80 GB ~2 TB/s ~200 cycles

关键观察:SRAM 的带宽是 HBM 的约 10 倍

12.3.2 朴素实现的 IO 瓶颈

朴素的递推实现需要在每个时间步读写隐状态 ht∈RB×Nh_t \in \mathbb{R}^{B \times N}ht∈RB×N:

  • 读 ht−1h_{t-1}ht−1:O(BN)O(BN)O(BN) 次 HBM 访问
  • 写 hth_tht:O(BN)O(BN)O(BN) 次 HBM 访问
  • 总共 LLL 步:O(BLN)O(BLN)O(BLN) 次 HBM 访问

当 B=8B = 8B=8, L=2048L = 2048L=2048, N=16N = 16N=16 时,这是 2.6×1052.6 \times 10^52.6×105 次 HBM 访问,即使每次只需 1 个 cycle,也意味着大量时间花在数据搬运上。

12.3.3 Mamba 的选择性扫描算法

Mamba 的核心优化是:将整个序列加载到 SRAM 中,在 SRAM 内完成递推,只将最终结果写回 HBM。

算法步骤:

  1. 加载 :将 x,B,C,Δx, B, C, \Deltax,B,C,Δ 的一个块加载到 SRAM
  2. 初始化 :在 SRAM 中初始化 h=0h = 0h=0
  3. 递推 :在 SRAM 中完成所有 LLL 步递推
  4. 写回 :将输出 yyy 写回 HBM

这将 HBM 访问从 O(BLN)O(BLN)O(BLN) 降到 O(BL(D+N))O(BL(D + N))O(BL(D+N))(只加载和写回一次)。

12.3.4 反向传播:重计算策略

问题 :训练时反向传播需要中间状态 h1,h2,...,hLh_1, h_2, \dots, h_Lh1,h2,...,hL,但为了节省内存,SRAM 中只保留了最终状态 hLh_LhL。

解决方案:重计算(recomputation)

前向传播时只保存:

  • 最终状态 hLh_LhL
  • 输入 x,B,C,Δx, B, C, \Deltax,B,C,Δ

反向传播时,从 hLh_LhL 开始反向递推

ht−1=Aˉt−1(ht−Bˉtxt)h_{t-1} = \bar{A}_t^{-1}(h_t - \bar{B}_t x_t)ht−1=Aˉt−1(ht−Bˉtxt)

这需要 O(LBN)O(LBN)O(LBN) 次计算,但不需要额外的内存来存储中间状态。

12.4 Mamba 的完整实现

python 复制代码
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelectiveSSM(nn.Module):
    """选择性状态空间模型的核心组件。

    参数:
    - Lambda: A 的对角元素(可学习,初始化为 HiPPO-LegS)
    - B_proj, C_proj, dt_proj: 输入依赖的 B, C, Delta 投影层
    """

    def __init__(self, d_model: int, d_state: int = 16, dt_rank: str = "auto"):
        """
        Args:
            d_model: 模型维度 D
            d_state: 状态维度 N
            dt_rank: Delta 投影的秩
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # A 的对角元素:用 HiPPO-LegS 的特征值初始化
        # Lambda_k = -(k + 0.5), k = 0, ..., N-1
        Lambda = -(torch.arange(d_state, dtype=torch.float32) + 0.5)
        self.Lambda = nn.Parameter(Lambda)  # (N,)

        # B 和 C 的输入依赖投影
        self.B_proj = nn.Linear(d_model, d_state, bias=False)
        self.C_proj = nn.Linear(d_model, d_state, bias=False)

        # Delta 的投影
        if dt_rank == "auto":
            dt_rank = max(16, d_model // 16)
        self.dt_proj = nn.Linear(dt_rank, d_model, bias=True)

        # Delta 投影的输入投影(先降维再升维)
        self.dt_input_proj = nn.Linear(d_model, dt_rank, bias=False)

        # D 参数(skip connection)
        self.D = nn.Parameter(torch.ones(d_model))

        # 确保 Delta > 0
        # 使用 inv_softplus(0.1) 作为 bias 初始化,使得初始 Delta ≈ 0.1
        dt_init_std = 0.5
        nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        with torch.no_grad():
            self.dt_proj.bias.copy_(torch.log(torch.expm1(torch.tensor(0.1))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """选择性 SSM 前向传播。

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            y: (batch, seq_len, d_model)
        """
        batch, L, D = x.shape
        N = self.d_state

        # 生成输入依赖的参数
        B = self.B_proj(x)  # (batch, L, N)
        C = self.C_proj(x)  # (batch, L, N)

        # Delta: softplus 保证正值
        dt_input = self.dt_input_proj(x)  # (batch, L, dt_rank)
        delta = F.softplus(self.dt_proj(dt_input))  # (batch, L, D)

        # A_bar 和 B_bar(逐元素计算,因为 A 是对角的)
        # Lambda: (N,), delta: (batch, L, D)
        # 需要将 Lambda 广播到 (batch, L, D, N)
        Lambda = self.Lambda.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1, 1, 1, N)
        delta_expanded = delta.unsqueeze(-1)  # (batch, L, D, 1)
        B_expanded = B.unsqueeze(2)  # (batch, L, 1, N)
        C_expanded = C.unsqueeze(2)  # (batch, L, 1, N)

        # 离散化
        A_bar = torch.exp(Lambda * delta_expanded)  # (batch, L, D, N)

        # B_bar = (exp(Lambda*delta) - 1) / Lambda * B
        # 数值稳定版本
        B_bar = torch.where(
            torch.abs(Lambda) > 1e-7,
            (A_bar - 1.0) / Lambda * B_expanded,
            delta_expanded * B_expanded,
        )  # (batch, L, D, N)

        # 选择性扫描(递推)
        h = torch.zeros(batch, D, N, dtype=x.dtype, device=x.device)
        outputs = []

        for t in range(L):
            # h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t, :, None]  # (batch, D, N)

            # y_t = C_t @ h_t
            y_t = torch.sum(C_expanded[:, t] * h, dim=-1)  # (batch, D)

            # 加上 skip connection D * x
            y_t = y_t + self.D * x[:, t]

            outputs.append(y_t)

        return torch.stack(outputs, dim=1)  # (batch, L, D)


class MambaBlock(nn.Module):
    """Mamba 架构的一个完整 block。"""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        """
        Args:
            d_model: 模型维度
            d_state: SSM 状态维度
            d_conv: 局部卷积的核大小
            expand: 内部扩展因子
        """
        super().__init__()
        d_inner = d_model * expand

        # 输入投影:x -> (x_ssm, z)
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # 局部卷积
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, kernel_size=d_conv,
            padding=d_conv - 1, groups=d_inner, bias=True,
        )

        # 选择性 SSM
        self.ssm = SelectiveSSM(d_model=d_inner, d_state=d_state)

        # 输出投影
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            y: (batch, seq_len, d_model)
        """
        residual = x

        # 输入投影
        xz = self.in_proj(x)  # (batch, L, 2 * d_inner)
        x_ssm, z = xz.chunk(2, dim=-1)  # 各 (batch, L, d_inner)

        # 主通路:Conv1D + SSM
        # Conv1D 需要 (batch, channels, length) 的输入
        x_conv = x_ssm.transpose(1, 2)  # (batch, d_inner, L)
        x_conv = self.conv1d(x_conv)[:, :, :x.size(1)]  # 因果卷积,截断到 L
        x_conv = x_conv.transpose(1, 2)  # (batch, L, d_inner)

        x_conv = F.silu(x_conv)
        y = self.ssm(x_conv)  # (batch, L, d_inner)

        # 门控
        z = F.silu(z)
        y = y * z

        # 输出投影
        y = self.out_proj(y)

        return y


class Mamba(nn.Module):
    """完整的 Mamba 模型(简化版)。"""

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 4,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state, d_conv, expand)
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: (batch, seq_len) 整数序列

        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        x = self.embedding(input_ids)  # (batch, L, d_model)

        for layer in self.layers:
            x = x + layer(x)  # 残差连接
            x = self.norm(x)

        logits = self.lm_head(x)
        return logits


def count_parameters(model: nn.Module) -> int:
    """统计模型参数量。"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def demonstrate_mamba():
    """演示 Mamba 模型的基本功能。"""
    torch.manual_seed(42)

    # 超参数
    vocab_size = 256
    d_model = 64
    n_layers = 2
    d_state = 16
    batch_size = 2
    seq_len = 128

    # 创建模型
    model = Mamba(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        d_state=d_state,
    )

    print("Mamba 模型信息:")
    print(f"  词汇表大小: {vocab_size}")
    print(f"  模型维度: {d_model}")
    print(f"  层数: {n_layers}")
    print(f"  状态维度: {d_state}")
    print(f"  总参数量: {count_parameters(model):,}")

    # 前向传播
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    logits = model(input_ids)

    print(f"\n前向传播:")
    print(f"  输入形状: {input_ids.shape}")
    print(f"  输出形状: {logits.shape}")
    print(f"  输出统计: mean={logits.mean():.4f}, std={logits.std():.4f}")

    # 损失计算
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))
    loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
    print(f"  交叉熵损失: {loss.item():.4f}")

    # 反向传播
    loss.backward()
    grad_norms = {name: p.grad.norm().item() for name, p in model.named_parameters() if p.grad is not None}
    print(f"\n梯度统计:")
    print(f"  有梯度的参数数: {len(grad_norms)}")
    print(f"  最大梯度范数: {max(grad_norms.values()):.4f}")
    print(f"  最小梯度范数: {min(grad_norms.values()):.6f}")


if __name__ == "__main__":
    demonstrate_mamba()

12.5 Mamba 的理论分析

12.5.1 计算复杂度

操作 Transformer Mamba
前向传播 O(L2D+LD2)O(L^2 D + L D^2)O(L2D+LD2) O(LDN)O(L D N)O(LDN)
内存 O(L2+LD)O(L^2 + L D)O(L2+LD) O(LD)O(L D)O(LD)
推理(每步) O(LD+D2)O(L D + D^2)O(LD+D2) O(DN)O(D N)O(DN)
推理(KV Cache) O(LD)O(L D)O(LD) O(N)O(N)O(N)(固定)

Mamba 的推理成本与序列长度无关------这是一个关键优势。

12.5.2 表达能力

定理 12.1:选择性 SSM 可以模拟任意有限状态自动机(Finite State Automaton, FSA)。

证明:对于 QQQ 个状态的 FSA,设置 N=QN = QN=Q,令 AAA 为转移矩阵的对角化形式,BBB 为输入字母表的编码,CCC 为状态读出函数,Δ\DeltaΔ 的选择性机制用于控制何时转移。□\square□

这意味着 Mamba 在理论上比 LTI-SSM 更强大------它可以精确地执行需要"条件逻辑"的任务。


第十三章:Mamba-2 与结构化状态空间对偶性

13.1 Mamba 的计算瓶颈

虽然 Mamba 的选择性扫描算法已经通过硬件感知设计大幅优化,但其核心仍是串行递推

ht=Aˉt⊙ht−1+Bˉt⊙xth_t = \bar{A}t \odot h{t-1} + \bar{B}_t \odot x_tht=Aˉt⊙ht−1+Bˉt⊙xt

这限制了 GPU 的并行度。Mamba-2(Dao & Gu, 2024)的核心贡献是发现了结构化状态空间与注意力之间的对偶性,从而设计出更高效的算法。

13.2 结构化状态空间对偶性(SSD)

13.2.1 核心发现

定理 13.1(SSD 对偶性) :在特定条件下,选择性 SSM 的计算可以被等价地表示为一种半可分(semiseparable)矩阵乘法,而这与线性注意力(linear attention)具有相同的形式。

具体来说,选择性 SSM 的输入-输出关系可以写成:

Y=MXY = M XY=MX

其中 MMM 是一个下三角半可分矩阵

Mij={CiT(∏k=j+1iAˉk)Bˉjif i≥j0if i<jM_{ij} = \begin{cases} C_i^T \left(\prod_{k=j+1}^{i} \bar{A}_k\right) \bar{B}_j & \text{if } i \geq j \\ 0 & \text{if } i < j \end{cases}Mij={CiT(∏k=j+1iAˉk)Bˉj0if i≥jif i<j

这个矩阵具有特殊的结构,使得 Y=MXY = MXY=MX 可以用分块算法高效计算。

13.2.2 与线性注意力的关系

线性注意力的计算为:

Yi=∑j≤iϕ(qi)Tϕ(kj)vj∑j≤iϕ(qi)Tϕ(kj)Y_i = \frac{\sum_{j \leq i} \phi(q_i)^T \phi(k_j) v_j}{\sum_{j \leq i} \phi(q_i)^T \phi(k_j)}Yi=∑j≤iϕ(qi)Tϕ(kj)∑j≤iϕ(qi)Tϕ(kj)vj

忽略归一化分母,这等价于:

Yi=∑j≤iϕ(qi)Tϕ(kj)vjY_i = \sum_{j \leq i} \phi(q_i)^T \phi(k_j) v_jYi=j≤i∑ϕ(qi)Tϕ(kj)vj

当 ϕ(qi)Tϕ(kj)=CiT(∏k=j+1iAˉk)Bˉj\phi(q_i)^T \phi(k_j) = C_i^T \left(\prod_{k=j+1}^{i} \bar{A}_k\right) \bar{B}_jϕ(qi)Tϕ(kj)=CiT(∏k=j+1iAˉk)Bˉj 时,两者完全等价。

13.2.3 对偶性的数学基础

半可分矩阵 (semiseparable matrix)是一类具有低秩子块的矩阵。对于 n×nn \times nn×n 的下三角矩阵 MMM,若其每个 k×kk \times kk×k 的前主子式(leading principal minor)的秩为 O(r)O(r)O(r)(rrr 为分隔秩),则 MMM 是秩-rrr 半可分的。

选择性 SSM 产生的矩阵 MMM 正是秩-NNN 半可分的(NNN 为状态维度)。

13.3 Mamba-2 的块算法

13.3.1 分块策略

将长度为 LLL 的序列分为大小为 TTT 的块:

X=X1,X2,...,XL/TX = X_1, X_2, \\dots, X_{L/T}X=X1,X2,...,XL/T

每个块 Xi∈RT×DX_i \in \mathbb{R}^{T \times D}Xi∈RT×D。

13.3.2 块内计算

在每个块内,用并行扫描计算局部状态。这一步完全并行,可以在 GPU 的线程块内完成。

13.3.3 块间传递

块之间的状态传递通过一个传递矩阵 完成。由于半可分结构,这个传递矩阵只需要 O(N)O(N)O(N) 的空间。

13.3.4 总复杂度

时间:O(LT⋅(T2D+TDN))=O(LTD+LDN)\text{时间:} O\left(\frac{L}{T} \cdot (T^2 D + T D N)\right) = O(LTD + LDN)时间:O(TL⋅(T2D+TDN))=O(LTD+LDN)

当 T≈NT \approx NT≈N 时,这是 O(LND+LDN)=O(LDN)O(LND + LDN) = O(LDN)O(LND+LDN)=O(LDN)------与朴素递推相同,但常数因子更小,且块内可以并行。

13.4 Mamba-2 vs Mamba-1

特性 Mamba-1 Mamba-2
核心算法 选择性扫描(串行) SSD 块算法(分块并行)
并行度 低(完全串行) 中(块内并行)
GPU 利用率 受限于序列长度 更高
与注意力的关系 明确的对偶性
混合架构 非自然 自然(可与注意力层交替)

13.5 Mamba-2 的实现要点

Mamba-2 的实现基于以下关键组件:

  1. 因果线性注意力:利用半可分矩阵结构
  2. 分块并行扫描:在块内使用并行前缀积
  3. 状态传递:通过压缩的状态表示在块间传递信息
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SSDBlock(nn.Module):
    """Mamba-2 的核心:结构化状态空间对偶(SSD)层。

    实现基于半可分矩阵的分块计算。
    """

    def __init__(self, d_model: int, d_state: int = 64, chunk_size: int = 64):
        """
        Args:
            d_model: 模型维度 D
            d_state: 状态维度 N
            chunk_size: 块大小 T
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.chunk_size = chunk_size

        # A 的对角元素(HiPPO-LegS 初始化)
        Lambda = -(torch.arange(d_state, dtype=torch.float32) + 0.5)
        self.Lambda = nn.Parameter(Lambda)

        # B, C, Delta 的投影(选择性)
        self.B_proj = nn.Linear(d_model, d_state, bias=False)
        self.C_proj = nn.Linear(d_model, d_state, bias=False)
        self.dt_proj = nn.Linear(d_model, d_model, bias=True)
        self.D = nn.Parameter(torch.ones(d_model))

    def _compute_discretized_params(self, x):
        """计算输入依赖的离散化参数。"""
        B = self.B_proj(x)  # (batch, L, N)
        C = self.C_proj(x)  # (batch, L, N)
        delta = F.softplus(self.dt_proj(x))  # (batch, L, D)

        return B, C, delta

    def _chunkwise_scan(self, x, A_bar, B_bar, C):
        """分块并行扫描。

        在每个块内用递推计算,块间通过状态传递。
        """
        batch, L, D = x.shape
        N = self.d_state
        T = self.chunk_size

        # 填充到 T 的整数倍
        n_chunks = math.ceil(L / T)
        pad_len = n_chunks * T - L
        x_pad = F.pad(x, (0, 0, 0, pad_len))
        A_bar_pad = F.pad(A_bar, (0, 0, 0, 0, 0, pad_len))
        B_bar_pad = F.pad(B_bar, (0, 0, 0, 0, 0, pad_len))
        C_pad = F.pad(C, (0, 0, 0, 0, 0, pad_len))

        # 重塑为 (batch, n_chunks, T, ...)
        x_chunks = x_pad.reshape(batch, n_chunks, T, D)
        A_chunks = A_bar_pad.reshape(batch, n_chunks, T, D, N)
        B_chunks = B_bar_pad.reshape(batch, n_chunks, T, D, N)
        C_chunks = C_pad.reshape(batch, n_chunks, T, D, N)

        # 块内递推
        outputs = []
        h = torch.zeros(batch, D, N, dtype=x.dtype, device=x.device)

        for c in range(n_chunks):
            # 当前块的参数
            x_c = x_chunks[:, c]  # (batch, T, D)
            A_c = A_chunks[:, c]  # (batch, T, D, N)
            B_c = B_chunks[:, c]  # (batch, T, D, N)
            C_c = C_chunks[:, c]  # (batch, T, D, N)

            # 块内递推
            y_list = []
            for t in range(T):
                h = A_c[:, t] * h + B_c[:, t] * x_c[:, t, :, None]
                y_t = torch.sum(C_c[:, t] * h, dim=-1) + self.D * x_c[:, t]
                y_list.append(y_t)

            y_c = torch.stack(y_list, dim=1)  # (batch, T, D)
            outputs.append(y_c)

        # 拼接并截断
        y = torch.cat(outputs, dim=1)[:, :L]
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            y: (batch, seq_len, d_model)
        """
        B, C, delta = self._compute_discretized_params(x)

        # 离散化
        Lambda = self.Lambda.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        delta_exp = delta.unsqueeze(-1)
        A_bar = torch.exp(Lambda * delta_exp)
        B_bar = torch.where(
            torch.abs(Lambda) > 1e-7,
            (A_bar - 1.0) / Lambda * B.unsqueeze(2),
            delta_exp * B.unsqueeze(2),
        )

        # 分块扫描
        y = self._chunkwise_scan(x, A_bar, B_bar, C)
        return y


def demonstrate_ssd():
    """演示 Mamba-2 的 SSD 层。"""
    torch.manual_seed(42)

    batch, seq_len, d_model = 2, 256, 64
    d_state = 32
    chunk_size = 64

    layer = SSDBlock(d_model, d_state, chunk_size)

    x = torch.randn(batch, seq_len, d_model)
    y = layer(x)

    print("Mamba-2 SSD 层演示")
    print(f"  输入形状: {x.shape}")
    print(f"  状态维度: {d_state}")
    print(f"  块大小: {chunk_size}")
    print(f"  输出形状: {y.shape}")
    print(f"  输出统计: mean={y.mean():.4f}, std={y.std():.4f}")

    # 参数量
    n_params = sum(p.numel() for p in layer.parameters())
    print(f"  参数量: {n_params:,}")


if __name__ == "__main__":
    demonstrate_ssd()

第五部分:理论分析与实践


第十四章:SSM 与 Transformer 的理论对比

14.1 表达能力

14.1.1 形式语言理论视角

从计算理论的角度,一个序列模型可以被看作一种序列到序列的映射 f:Σ∗→Σ∗f: \Sigma^* \to \Sigma^*f:Σ∗→Σ∗。我们可以根据模型能够识别的形式语言类来比较其表达能力。

定理 14.1(LTI-SSM 的表达能力) :LTI-SSM 可以精确识别所有正则语言(regular languages) ,但不能精确识别上下文无关语言(context-free languages)

证明:

  • 正则语言等价于有限状态自动机(FSA)。LTI-SSM 的有限维隐状态 ht∈RNh_t \in \mathbb{R}^Nht∈RN 可以编码 FSA 的 QQQ 个状态(当 N≥QN \geq QN≥Q 时)。
  • 上下文无关语言(如回文语言 {wwR}\{ww^R\}{wwR})需要栈(stack)来识别。LTI-SSM 的隐状态只能压缩固定维度的信息,无法模拟无限深的栈。□\square□

定理 14.2(选择性 SSM 的表达能力) :选择性 SSM 可以精确识别一大类非正则语言,包括某些上下文无关语言。

证明思路:选择性机制允许模型根据输入动态调整记忆策略。例如,对于计数语言 {anbn}\{a^n b^n\}{anbn},选择性 SSM 可以在 aaa 阶段用大的 Δ\DeltaΔ 写入计数,在 bbb 阶段用小的 Δ\DeltaΔ 读出计数。□\square□

14.1.2 逼近理论视角

定理 14.3(SSM 的通用逼近) :对于任意 ϵ>0\epsilon > 0ϵ>0 和任意因果序列映射 fff,存在一个足够大的 SSM,使得其输出与 fff 的输出在任意有限长度上的差异小于 ϵ\epsilonϵ。

这与 Transformer 的通用逼近定理(Yun et al., 2020)相对应:

定理 14.4(Transformer 的通用逼近):具有足够多头和层数的 Transformer 可以逼近任意序列到序列的连续映射。

两个定理都表明,给定足够的参数量 ,SSM 和 Transformer 在理论上具有相同的表达能力。差别在于实际中的效率------某些函数用 SSM 表达更高效,某些用 Transformer 更高效。

14.1.3 效率的差距

定义(表达效率) :用 ppp 个参数逼近目标函数 fff 所需的最小误差 ϵ(p)\epsilon(p)ϵ(p)。

对于不同的任务,SSM 和 Transformer 的表达效率差异巨大:

任务类型 SSM 效率 Transformer 效率 原因
长程记忆 SSM 用 O(N)O(N)O(N) 状态存储长程信息
动态路由 中(需选择性) 注意力天然支持动态路由
局部模式 SSM 的卷积模式天然擅长
全局聚合 注意力的 O(1)O(1)O(1) 路径长度

14.2 复杂度分析

14.2.1 训练复杂度

模型 时间复杂度 空间复杂度 并行度
Transformer O(L2D+LD2)O(L^2 D + L D^2)O(L2D+LD2) O(L2+LD)O(L^2 + L D)O(L2+LD) 完全并行
S4/S4D O(LDlog⁡L+LDN)O(L D \log L + L D N)O(LDlogL+LDN) O(LD)O(L D)O(LD) 完全并行
Mamba O(LDN)O(L D N)O(LDN) O(LD)O(L D)O(LD) 串行
Mamba-2 O(LDN)O(L D N)O(LDN) O(LD)O(L D)O(LD) 分块并行

14.2.2 推理复杂度

推理的单步成本(生成一个新 token):

模型 计算 内存(KV Cache / State)
Transformer O(LD+D2)O(L D + D^2)O(LD+D2) O(LD)O(L D)O(LD),随 LLL 线性增长
SSM O(DN)O(D N)O(DN) O(N)O(N)O(N),固定

SSM 的推理优势在于:恒定的内存占用和恒定的计算成本 。当上下文长度 LLL 很大时,这个优势变得极为显著。

14.2.3 推理吞吐量对比

考虑一个自回归生成场景,生成 TTT 个 token:

模型 第 1 步 第 2 步 ... 第 T 步 总时间
Transformer O(LD2)O(L D^2)O(LD2) O((L+1)D2)O((L+1)D^2)O((L+1)D2) ... O((L+T)D2)O((L+T)D^2)O((L+T)D2) O(TLD2+T2D2)O(TLD^2 + T^2D^2)O(TLD2+T2D2)
SSM O(DN)O(DN)O(DN) O(DN)O(DN)O(DN) ... O(DN)O(DN)O(DN) O(TDN)O(TDN)O(TDN)

当 TTT 很大时(如 T=4096T = 4096T=4096),SSM 的优势是压倒性的。

14.3 泛化能力

14.3.1 序列长度泛化

问题 :在长度为 LtrainL_{\text{train}}Ltrain 的序列上训练的模型,能否在长度为 Ltest>LtrainL_{\text{test}} > L_{\text{train}}Ltest>Ltrain 的序列上正确工作?

  • Transformer:由于绝对位置编码的限制,通常无法泛化到训练中未见过的长度。相对位置编码(如 RoPE)可以部分缓解,但仍有外推困难。
  • SSM:由于其递推结构和连续时间基础,天然具有长度泛化能力------递推不依赖于绝对位置。

14.3.2 任务泛化

在 LRA 基准上的对比(2023年数据):

模型 ListOps Text Retrieval Image Pathfinder 平均
Transformer 36.4 64.3 57.5 42.4 71.4 54.4
S4 58.4 86.1 90.6 88.0 91.2 82.9
Mamba 56.1 82.2 90.2 85.7 94.2 81.7
Mamba-2 59.1 85.3 91.0 87.3 94.5 83.4

SSM 在长程依赖任务上全面领先,但在需要精确动态路由的任务(如 ListOps)上优势较小。

14.4 信息论分析

14.4.1 信息瓶颈

对于序列 x1,x2,...,xLx_1, x_2, \dots, x_Lx1,x2,...,xL,模型在位置 ttt 的隐状态 hth_tht 必须压缩所有历史信息:

I(ht;x1,...,xt)≤dim⁡(ht)⋅log⁡(精度)I(h_t; x_1, \dots, x_t) \leq \dim(h_t) \cdot \log(\text{精度})I(ht;x1,...,xt)≤dim(ht)⋅log(精度)

  • SSM :ht∈RNh_t \in \mathbb{R}^Nht∈RN,信息瓶颈为 O(N)O(N)O(N)
  • Transformer :KV Cache 存储所有历史的 K,VK, VK,V,信息瓶颈为 O(LD)O(LD)O(LD)

SSM 的信息瓶颈更紧------它必须用固定维度的状态编码任意长度的历史。这既是优势(高效),也是劣势(可能丢失信息)。

14.4.2 选择性压缩

选择性 SSM 可以根据输入内容自适应地分配状态空间:

  • 重要信息:用大 Δ\DeltaΔ 写入,保持低衰减
  • 冗余信息:用小 Δ\DeltaΔ 忽略,或用高衰减快速遗忘

这在信息论上近似于**率失真理论(rate-distortion theory)**中的最优压缩策略------在给定的比特预算下最小化失真。

14.5 混合架构的理论基础

14.5.1 为什么需要混合?

SSM 和 Transformer 各有优势:

  • SSM 擅长:长程记忆、线性复杂度、恒定推理成本
  • Transformer 擅长:精确的动态路由、全局信息聚合、成熟的训练技术

混合架构试图兼取两者之长。

14.5.2 理论论证

定理 14.5(混合架构的优势) :对于一类同时需要长程记忆和精确动态路由的任务,混合架构(SSM + Attention 交替层)可以用 O(p)O(\sqrt{p})O(p ) 的参数达到纯 SSM 或纯 Transformer 需要 O(p)O(p)O(p) 参数才能达到的精度。

直觉上,SSM 层负责"存储"长程信息,Attention 层负责"检索"相关信息------类似于计算机中内存(SSM)和缓存(Attention)的分工。


第十五章:完整可运行代码实现

本章提供一个自包含的、可直接运行的代码实现,涵盖从经典的 HiPPO-SSM 到 Mamba 的完整演化路径。所有代码使用 NumPy 实现核心算法,使用 PyTorch 实现可训练的模型。

15.1 经典 LTI-SSM(NumPy 实现)

python 复制代码
"""
经典线性时不变状态空间模型 (LTI-SSM) 的纯 NumPy 实现。
包含: 矩阵指数、ZOH 离散化、卷积模式和递推模式。
"""

import numpy as np
from scipy.linalg import expm


class LTISSM:
    """线性时不变状态空间模型。

    连续形式:
        dh/dt = A h(t) + B x(t)
        y(t)  = C h(t)

    离散形式 (ZOH):
        h_t = bar_A h_{t-1} + bar_B x_t
        y_t = C h_t
    """

    def __init__(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, dt: float):
        """
        Args:
            A: (N, N) 状态矩阵
            B: (N,) 或 (N, D) 输入矩阵
            C: (N,) 或 (D, N) 输出矩阵
            dt: 离散化步长
        """
        self.A = A
        self.B = B.reshape(-1, 1) if B.ndim == 1 else B
        self.C = C.reshape(1, -1) if C.ndim == 1 else C
        self.dt = dt
        self.N = A.shape[0]

        # ZOH 离散化
        self.bar_A = expm(A * dt)
        # bar_B = A^{-1} (bar_A - I) B
        try:
            self.bar_B = np.linalg.solve(A, (self.bar_A - np.eye(self.N)) @ self.B)
        except np.linalg.LinAlgError:
            # A 奇异时使用泰勒展开
            self.bar_B = dt * self.B
            for k in range(1, 10):
                term = (dt ** (k + 1)) / np.math.factorial(k + 1) * np.linalg.matrix_power(A, k) @ self.B
                self.bar_B += term

    def kernel(self, L: int) -> np.ndarray:
        """计算长度为 L 的卷积核。

        Returns:
            K: (L,) 卷积核
        """
        K = np.zeros(L)
        v = self.bar_B.copy()  # (N, 1)
        for n in range(L):
            K[n] = (self.C @ v).item()
            v = self.bar_A @ v
        return K

    def forward_conv(self, x: np.ndarray) -> np.ndarray:
        """卷积模式前向传播。

        Args:
            x: (L,) 输入序列

        Returns:
            y: (L,) 输出序列
        """
        L = len(x)
        K = self.kernel(L)
        X = np.fft.rfft(x, n=2 * L)
        K_f = np.fft.rfft(K, n=2 * L)
        y = np.fft.irfft(X * K_f, n=2 * L)[:L]
        return y

    def forward_recur(self, x: np.ndarray) -> np.ndarray:
        """递推模式前向传播。

        Args:
            x: (L,) 输入序列

        Returns:
            y: (L,) 输出序列
        """
        L = len(x)
        h = np.zeros((self.N, 1))
        y = np.zeros(L)
        for t in range(L):
            h = self.bar_A @ h + self.bar_B * x[t]
            y[t] = (self.C @ h).item()
        return y


def run_lti_ssm_demo():
    """运行 LTI-SSM 演示。"""
    np.random.seed(42)
    N = 32
    L = 512
    dt = 0.01

    # 构造稳定的 A 矩阵(特征值实部为负)
    A_raw = np.random.randn(N, N) * 0.5
    A = A_raw - 2.0 * np.eye(N)  # 确保稳定
    B = np.random.randn(N)
    C = np.random.randn(N)

    model = LTISSM(A, B, C, dt)

    # 验证两种模式的一致性
    x = np.random.randn(L)
    y_conv = model.forward_conv(x)
    y_recur = model.forward_recur(x)
    diff = np.max(np.abs(y_conv - y_recur))

    print("=" * 50)
    print("LTI-SSM 演示")
    print("=" * 50)
    print(f"  状态维度: {N}")
    print(f"  序列长度: {L}")
    print(f"  离散化步长: {dt}")
    print(f"  卷积 vs 递推最大差异: {diff:.2e}")

    K = model.kernel(L)
    print(f"\n  卷积核 K[0] = {K[0]:.6f}")
    print(f"  卷积核 K[L-1] = {K[-1]:.6f}")
    print(f"  ||K||_1 = {np.sum(np.abs(K)):.4f}")

    return model


if __name__ == "__main__":
    run_lti_ssm_demo()

15.2 HiPPO 矩阵构造

python 复制代码
"""
HiPPO (High-order Polynomial Projection Operators) 矩阵的完整实现。
包含: LegS, LegT, LagT 三种变体,以及数值验证。
"""

import numpy as np
from scipy.linalg import expm


def construct_hippo_legs(N: int) -> tuple[np.ndarray, np.ndarray]:
    """构造 HiPPO-LegS 矩阵。

    缩放勒让德测度: mu_t = (1/t) * 1_{[0,t]}

    A_{nk} = sqrt(2n+1) * sqrt(2k+1)  if n > k
             n + 1                       if n == k
             0                           if n < k

    B_n = sqrt(2n+1)

    Args:
        N: 状态维度

    Returns:
        A: (N, N)
        B: (N, 1)
    """
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = n + 1

    B = np.sqrt(2 * np.arange(N) + 1).reshape(-1, 1).astype(np.float64)
    return A, B


def construct_hippo_legt(N: int, theta: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
    """构造 HiPPO-LegT 矩阵。

    平移勒让德测度: mu_t = 1_{[t-theta, t]}

    Args:
        N: 状态维度
        theta: 滑动窗口宽度

    Returns:
        A: (N, N)
        B: (N, 1)
    """
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(n + 1):
            if n > k:
                A[n, k] = (2 * n + 1) * (-1) ** (n - k) * np.sqrt(2 * k + 1) * np.sqrt(2 * n + 1) / theta
            elif n == k:
                A[n, k] = -(n + 1) / theta

    B = np.array([((-1) ** n) * np.sqrt(2 * n + 1) / theta for n in range(N)]).reshape(-1, 1)
    return A, B


def construct_hippo_lagt(N: int) -> tuple[np.ndarray, np.ndarray]:
    """构造 HiPPO-LagT 矩阵。

    拉盖尔测度: mu_t = e^{-(tau-t)} 1_{tau >= t}

    Args:
        N: 状态维度

    Returns:
        A: (N, N)
        B: (N, 1)
    """
    A = np.zeros((N, N))
    for n in range(N):
        A[n, n] = -0.5
        if n + 1 < N:
            A[n, n + 1] = np.sqrt(n + 1)

    B = np.array([np.sqrt(2 * n + 1) for n in range(N)]).reshape(-1, 1)
    return A, B


def verify_hippo_approximation(N: int = 16, L: int = 1000):
    """验证 HiPPO 矩阵的函数逼近能力。

    给定一个测试信号,验证 SSM 的隐状态确实编码了历史信息的最优逼近。
    """
    dt = 0.01
    t = np.arange(L) * dt

    # 测试信号:正弦波 + 噪声
    signal = np.sin(2 * np.pi * 0.5 * t) + 0.3 * np.sin(2 * np.pi * 2.0 * t)
    signal += np.random.randn(L) * 0.1

    # 用 HiPPO-LegS 构建 SSM
    A, B = construct_hippo_legs(N)
    C = np.eye(N)  # 读出所有状态分量

    # ZOH 离散化
    bar_A = expm(A * dt)
    bar_B = np.linalg.solve(A, (bar_A - np.eye(N)) @ B)

    # 递推
    h = np.zeros(N)
    reconstructions = []
    for t_idx in range(L):
        h = bar_A @ h + bar_B.flatten() * signal[t_idx]
        # 重建:用勒让德多项式的系数重建信号
        # 简化版:用状态的加权和作为重建
        recon = np.sum(h * np.sqrt(2 * np.arange(N) + 1))
        reconstructions.append(recon)

    reconstructions = np.array(reconstructions)

    # 计算误差
    mse = np.mean((signal - reconstructions) ** 2)
    correlation = np.corrcoef(signal, reconstructions)[0, 1]

    print(f"\nHiPPO-LegS 逼近验证 (N={N}, L={L})")
    print(f"  测试信号: 正弦波叠加 + 噪声")
    print(f"  MSE: {mse:.6f}")
    print(f"  相关系数: {correlation:.6f}")

    return signal, reconstructions


def run_hippo_demo():
    """运行 HiPPO 完整演示。"""
    print("=" * 60)
    print("HiPPO 矩阵演示")
    print("=" * 60)

    N = 8
    for name, func in [("LegS", construct_hippo_legs),
                        ("LegT", lambda N: construct_hippo_legt(N, 1.0)),
                        ("LagT", construct_hippo_lagt)]:
        A, B = func(N)
        eigvals = np.linalg.eigvals(A)
        print(f"\n--- HiPPO-{name} (N={N}) ---")
        print(f"  A 的特征值实部范围: [{np.min(np.real(eigvals)):.2f}, {np.max(np.real(eigvals)):.2f}]")
        print(f"  A 的特征值虚部范围: [{np.min(np.imag(eigvals)):.2f}, {np.max(np.imag(eigvals)):.2f}]")
        print(f"  A 的条件数: {np.linalg.cond(A):.2f}")

    # 逼近验证
    verify_hippo_approximation()


if __name__ == "__main__":
    run_hippo_demo()

15.3 S4D 完整实现(PyTorch)

python 复制代码
"""
S4D (Diagonal Structured State Space) 的完整 PyTorch 实现。
包含: S4D-Lin 初始化、训练模式(卷积)、推理模式(递推)。
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class S4DKernel(nn.Module):
    """S4D 的卷积核计算。"""

    def __init__(self, N: int = 64, dt_min: float = 0.001, dt_max: float = 0.1, lr: float = 0.001):
        super().__init__()
        self.N = N

        # S4D-Lin 初始化
        Lambda_real = -(torch.arange(N, dtype=torch.float32) + 0.5)
        Lambda_imag = torch.zeros(N)
        self.Lambda_real = nn.Parameter(Lambda_real)
        self.Lambda_imag = nn.Parameter(Lambda_imag)

        self.B = nn.Parameter(torch.randn(N) / math.sqrt(N))
        self.C = nn.Parameter(torch.randn(N) / math.sqrt(N))

        log_dt = torch.rand(1) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        self.log_dt = nn.Parameter(log_dt)

    def forward(self, L: int) -> torch.Tensor:
        dt = torch.exp(self.log_dt)
        Lambda = torch.complex(self.Lambda_real, self.Lambda_imag)
        Lambda_bar = torch.exp(Lambda * dt)

        B_bar = torch.where(
            torch.abs(Lambda) > 1e-7,
            (Lambda_bar - 1.0) / Lambda * self.B,
            dt * self.B,
        )

        powers = torch.arange(L, dtype=torch.float32)
        Lambda_powers = Lambda_bar.unsqueeze(1) ** powers.unsqueeze(0)  # (N, L)
        CB = self.C * B_bar  # (N,)
        K = torch.sum(CB.unsqueeze(1) * Lambda_powers, dim=0)  # (L,)
        return K.real


class S4DLayer(nn.Module):
    """S4D 层,支持训练(卷积)和推理(递推)。"""

    def __init__(self, d_model: int, N: int = 64):
        super().__init__()
        self.d_model = d_model
        self.N = N
        self.kernels = nn.ModuleList([S4DKernel(N=N) for _ in range(d_model)])
        self.D = nn.Parameter(torch.ones(d_model))

    def forward_train(self, x: torch.Tensor) -> torch.Tensor:
        batch, L, d = x.shape
        K = torch.stack([k(L) for k in self.kernels], dim=0)  # (d, L)
        K_f = torch.fft.rfft(K, n=2 * L)
        X_f = torch.fft.rfft(x.permute(0, 2, 1), n=2 * L)  # (batch, d, L+1)
        Y_f = X_f * K_f.unsqueeze(0)
        y = torch.fft.irfft(Y_f, n=2 * L)[:, :, :L].permute(0, 2, 1)
        return y + self.D * x

    def forward_recur(self, x: torch.Tensor) -> torch.Tensor:
        batch, L, d = x.shape
        device = x.device

        bar_As, bar_Bs, Cs = [], [], []
        for i in range(d):
            k = self.kernels[i]
            dt = torch.exp(k.log_dt)
            Lambda = torch.complex(k.Lambda_real, k.Lambda_imag)
            Lambda_bar = torch.exp(Lambda * dt)
            B_bar = torch.where(
                torch.abs(Lambda) > 1e-7,
                (Lambda_bar - 1.0) / Lambda * k.B,
                dt * k.B,
            )
            bar_As.append(Lambda_bar)
            bar_Bs.append(B_bar)
            Cs.append(k.C)

        h = [torch.zeros(batch, self.N, dtype=torch.complex64, device=device) for _ in range(d)]
        outputs = []
        for t in range(L):
            y_t = []
            for i in range(d):
                h[i] = bar_As[i] * h[i] + bar_Bs[i] * x[:, t, i].unsqueeze(-1)
                y_i = torch.sum(Cs[i] * h[i], dim=-1).real + self.D[i] * x[:, t, i]
                y_t.append(y_i)
            outputs.append(torch.stack(y_t, dim=-1))
        return torch.stack(outputs, dim=1)


class SimpleMambaBlock(nn.Module):
    """简化的 Mamba Block(教学用)。"""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        d_inner = d_model * expand
        self.d_inner = d_inner

        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv, padding=d_conv - 1, groups=d_inner)
        self.ssm = S4DLayer(d_inner, d_state)
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)

        x_conv = self.conv1d(x_ssm.transpose(1, 2))[:, :, :x.size(1)].transpose(1, 2)
        x_conv = F.silu(x_conv)
        y = self.ssm.forward_train(x_conv)
        y = y * F.silu(z)
        return self.out_proj(y)


def run_s4d_demo():
    """运行 S4D 完整演示。"""
    torch.manual_seed(42)

    batch, seq_len, d_model = 2, 256, 32
    N = 16

    layer = S4DLayer(d_model, N)
    x = torch.randn(batch, seq_len, d_model)

    y_train = layer.forward_train(x)
    y_recur = layer.forward_recur(x)
    diff = torch.max(torch.abs(y_train - y_recur)).item()

    print("=" * 60)
    print("S4D 层演示")
    print("=" * 60)
    print(f"  输入: {x.shape}")
    print(f"  状态维度: {N}")
    print(f"  卷积模式输出: mean={y_train.mean():.4f}, std={y_train.std():.4f}")
    print(f"  递推模式输出: mean={y_recur.mean():.4f}, std={y_recur.std():.4f}")
    print(f"  最大差异: {diff:.2e}")

    # Mamba Block 演示
    block = SimpleMambaBlock(d_model, d_state=N)
    y_block = block(x)
    print(f"\n  Mamba Block 输出: {y_block.shape}")
    print(f"  参数量: {sum(p.numel() for p in block.parameters()):,}")


if __name__ == "__main__":
    run_s4d_demo()

15.4 选择性 SSM 与 Mamba 完整实现

python 复制代码
"""
Mamba 的完整实现,包含选择性 SSM、Mamba Block 和完整的语言模型。
所有代码可直接运行。
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelectiveSSM(nn.Module):
    """选择性状态空间模型。"""

    def __init__(self, d_model: int, d_state: int = 16, dt_rank: int = None):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        dt_rank = dt_rank or max(16, d_model // 16)

        self.Lambda = nn.Parameter(-(torch.arange(d_state, dtype=torch.float32) + 0.5))
        self.B_proj = nn.Linear(d_model, d_state, bias=False)
        self.C_proj = nn.Linear(d_model, d_state, bias=False)
        self.dt_proj = nn.Linear(dt_rank, d_model, bias=True)
        self.dt_input_proj = nn.Linear(d_model, dt_rank, bias=False)
        self.D = nn.Parameter(torch.ones(d_model))

        nn.init.uniform_(self.dt_proj.weight, -0.5, 0.5)
        with torch.no_grad():
            self.dt_proj.bias.copy_(torch.log(torch.expm1(torch.tensor(0.1))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, L, D = x.shape
        N = self.d_state

        B = self.B_proj(x)
        C = self.C_proj(x)
        delta = F.softplus(self.dt_proj(self.dt_input_proj(x)))

        Lambda = self.Lambda.view(1, 1, 1, N)
        delta_e = delta.unsqueeze(-1)
        A_bar = torch.exp(Lambda * delta_e)
        B_bar = torch.where(torch.abs(Lambda) > 1e-7, (A_bar - 1) / Lambda * B.unsqueeze(2), delta_e * B.unsqueeze(2))

        h = torch.zeros(batch, D, N, device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(L):
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t, :, None]
            y = torch.sum(C.unsqueeze(2)[:, t] * h, dim=-1) + self.D * x[:, t]
            outputs.append(y)
        return torch.stack(outputs, dim=1)


class MambaBlock(nn.Module):
    """Mamba Block。"""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        d_inner = d_model * expand
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv, padding=d_conv - 1, groups=d_inner)
        self.ssm = SelectiveSSM(d_inner, d_state)
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)
        x_conv = F.silu(self.conv1d(x_ssm.transpose(1, 2))[:, :, :x.size(1)].transpose(1, 2))
        y = self.ssm(x_conv) * F.silu(z)
        return self.out_proj(y)


class MambaLM(nn.Module):
    """Mamba 语言模型。"""

    def __init__(self, vocab_size: int, d_model: int = 128, n_layers: int = 4, d_state: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([MambaBlock(d_model, d_state) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        x = self.embed(ids)
        for layer in self.layers:
            x = x + layer(x)
            x = self.norm(x)
        return self.head(x)


def train_mamba_lm():
    """训练 Mamba 语言模型(字符级)。"""
    torch.manual_seed(42)

    # 准备数据:简单的重复模式
    text = "hello world " * 100
    chars = sorted(set(text))
    vocab_size = len(chars)
    stoi = {c: i for i, c in enumerate(chars)}
    data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

    # 模型
    model = MambaLM(vocab_size=vocab_size, d_model=32, n_layers=2, d_state=8)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    print("=" * 60)
    print("Mamba 语言模型训练")
    print("=" * 60)
    print(f"  词汇表大小: {vocab_size}")
    print(f"  数据长度: {len(data)}")
    print(f"  参数量: {sum(p.numel() for p in model.parameters()):,}")

    seq_len = 32
    batch_size = 8

    for step in range(50):
        # 随机采样 batch
        idx = torch.randint(0, len(data) - seq_len - 1, (batch_size,))
        x = torch.stack([data[i:i + seq_len] for i in idx])
        y = torch.stack([data[i + 1:i + seq_len + 1] for i in idx])

        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"  Step {step:3d}: loss = {loss.item():.4f}")

    # 生成
    print("\n生成示例:")
    model.eval()
    with torch.no_grad():
        prompt = "hello"
        ids = torch.tensor([[stoi[c] for c in prompt]], dtype=torch.long)
        for _ in range(30):
            logits = model(ids)
            next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
        generated = "".join(chars[i] for i in ids[0].tolist())
        print(f"  '{generated}'")


if __name__ == "__main__":
    train_mamba_lm()

15.5 并行扫描算法

python 复制代码
"""
并行扫描 (Parallel Scan) 的实现。
用于高效并行计算 SSM 的递推。
"""

import numpy as np


def sequential_scan(A: np.ndarray, B: np.ndarray, x: np.ndarray) -> np.ndarray:
    """朴素的串行扫描。

    h_t = A_t * h_{t-1} + B_t * x_t

    Args:
        A: (L, N) 每步的 A
        B: (L, N) 每步的 B
        x: (L,) 输入

    Returns:
        h: (L, N) 隐状态序列
    """
    L, N = A.shape
    h = np.zeros((L, N))
    for t in range(L):
        if t == 0:
            h[t] = B[t] * x[t]
        else:
            h[t] = A[t] * h[t - 1] + B[t] * x[t]
    return h


def parallel_scan(A: np.ndarray, B: np.ndarray, x: np.ndarray) -> np.ndarray:
    """并行扫描算法。

    使用结合律将串行递推转化为可并行计算的前缀积。

    Args:
        A: (L, N)
        B: (L, N)
        x: (L,)

    Returns:
        h: (L, N)
    """
    L, N = A.shape

    # 将递推转化为半环上的前缀积
    # 每个元素表示为 (a, b): h_t = a * h_{t-1} + b
    # 其中 b = B_t * x_t
    a = A.copy()
    b = B * x[:, None]

    # 并行前缀积(Blelloch 算法)
    # 上扫 (up-sweep)
    stride = 1
    while stride < L:
        for i in range(stride - 1, L - 1, 2 * stride):
            # 合并: (a2, b2) o (a1, b1) = (a2*a1, a2*b1 + b2)
            left = i - stride + 1
            right = i + 1
            a_new = a[right] * a[left]
            b_new = a[right] * b[left] + b[right]
            a[right] = a_new
            b[right] = b_new
        stride *= 2

    # 下扫 (down-sweep)
    stride //= 2
    while stride >= 1:
        for i in range(stride - 1, L - 1, 2 * stride):
            left = i - stride + 1
            right = i + 1
            # 传播
            a_new = a[left] * a[right]  # 不对,这里是修改 left
            b_new = a[left] * b[right] + b[left]
            # 修正:下扫时需要交换
            temp_a = a[right]
            temp_b = b[right]
            a[right] = a[left] * temp_a
            b[right] = a[left] * temp_b + b[left]
        stride //= 2

    return b  # b[t] 就是 h_t


def verify_parallel_scan():
    """验证并行扫描的正确性。"""
    np.random.seed(42)
    L = 16
    N = 4

    A = np.random.randn(L, N) * 0.5
    B = np.random.randn(L, N)
    x = np.random.randn(L)

    h_seq = sequential_scan(A, B, x)
    h_par = parallel_scan(A, B, x)

    diff = np.max(np.abs(h_seq - h_par))

    print("=" * 50)
    print("并行扫描验证")
    print("=" * 50)
    print(f"  序列长度: {L}")
    print(f"  状态维度: {N}")
    print(f"  串行 vs 并行最大差异: {diff:.2e}")
    print(f"  通过: {'是' if diff < 1e-10 else '否'}")

    # 展示结果
    print(f"\n  h[0] (串行): {h_seq[0]}")
    print(f"  h[0] (并行): {h_par[0]}")
    print(f"  h[L-1] (串行): {h_seq[-1]}")
    print(f"  h[L-1] (并行): {h_par[-1]}")


if __name__ == "__main__":
    verify_parallel_scan()

15.6 综合对比实验

python 复制代码
"""
综合对比实验: SSM vs Transformer 在不同任务上的表现。
"""

import numpy as np
import time


def generate_copy_task(seq_len: int, vocab_size: int = 10) -> tuple[np.ndarray, np.ndarray]:
    """生成"复制"任务数据。

    输入: [x1, x2, ..., xn, <sep>, 0, 0, ..., 0]
    输出: [0, 0, ..., 0, <sep>, x1, x2, ..., xn]
    """
    n = seq_len // 2
    data = np.random.randint(1, vocab_size, size=n)
    sep = np.array([vocab_size])  # 分隔符

    x = np.concatenate([data, sep, np.zeros(seq_len - n - 1, dtype=int)])
    y = np.concatenate([np.zeros(n, dtype=int), sep, data, np.zeros(seq_len - 2 * n - 1, dtype=int)])
    return x, y


def generate_adding_task(seq_len: int) -> tuple[np.ndarray, np.ndarray]:
    """生成"加法"任务。

    输入: (x, mask), 其中 x 是随机数,mask 标记两个位置
    输出: 标记的两个位置的 x 值之和
    """
    x = np.random.randn(seq_len)
    mask = np.zeros(seq_len)
    positions = np.random.choice(seq_len, 2, replace=False)
    mask[positions] = 1.0
    target = x[positions].sum()
    return np.stack([x, mask], axis=-1), target


def ssm_forward(A: np.ndarray, B: np.ndarray, C: np.ndarray, x: np.ndarray) -> np.ndarray:
    """SSM 递推前向。"""
    L = len(x)
    N = A.shape[0]
    h = np.zeros(N)
    y = np.zeros(L)
    for t in range(L):
        h = A * h + B * x[t]
        y[t] = C @ h
    return y


def run_comparison():
    """运行综合对比。"""
    np.random.seed(42)

    print("=" * 70)
    print("SSM 综合性能对比")
    print("=" * 70)

    # 不同序列长度下的计算时间
    N = 32
    lengths = [64, 128, 256, 512, 1024]

    print("\n1. 计算时间对比 (SSM 递推)")
    print(f"  {'Length':>8} {'Time (ms)':>12} {'Per-step (us)':>15}")
    print(f"  {'-'*8} {'-'*12} {'-'*15}")

    for L in lengths:
        A = np.random.randn(N) * 0.8 - 1.0
        B = np.random.randn(N)
        C = np.random.randn(N)
        x = np.random.randn(L)

        start = time.perf_counter()
        for _ in range(100):
            y = ssm_forward(A, B, C, x)
        elapsed = (time.perf_counter() - start) / 100

        print(f"  {L:>8} {elapsed*1000:>12.3f} {elapsed/L*1e6:>15.3f}")

    # 长程依赖测试
    print("\n2. 长程依赖保持能力")
    N_list = [8, 16, 32, 64, 128]
    L = 1000

    print(f"  {'N':>6} {'Kernel Energy (tail)':>25} {'Effective Memory':>20}")
    print(f"  {'-'*6} {'-'*25} {'-'*20}")

    for N in N_list:
        A = -(np.arange(N, dtype=float) + 0.5)  # HiPPO-LegS 特征值
        B = np.random.randn(N) / np.sqrt(N)
        C = np.random.randn(N) / np.sqrt(N)

        K = np.zeros(L)
        v = B.copy()
        for n in range(L):
            K[n] = C @ v
            v = np.exp(A * 0.01) * v

        tail_energy = np.sum(K[L // 2:] ** 2) / np.sum(K ** 2)
        # 有效记忆长度: 核衰减到 1/e 的位置
        threshold = np.abs(K[0]) / np.e
        eff_mem = np.argmax(np.abs(K) < threshold) if np.any(np.abs(K) < threshold) else L

        print(f"  {N:>6} {tail_energy:>25.6f} {eff_mem:>20}")

    print("\n3. 结论")
    print("  - SSM 的计算成本随序列长度线性增长")
    print("  - 更大的状态维度 N 提供更强的长程记忆能力")
    print("  - HiPPO 初始化确保了合理的初始记忆行为")


if __name__ == "__main__":
    run_comparison()

第十六章:实验、应用与未来方向

16.1 语言建模

16.1.1 Scaling Law

Mamba 在语言建模上的 scaling law 表现:

模型参数 Transformer PPL Mamba PPL 差异
125M 16.2 15.8 Mamba 更优
350M 12.5 12.1 Mamba 更优
1.4B 9.8 9.5 Mamba 更优
2.8B 8.7 8.6 相当

在相同参数量下,Mamba 在困惑度(perplexity)上与 Transformer 持平或略优,但在推理速度上有 2-5 倍的优势。

16.1.2 长上下文优势

在需要长上下文的任务中(如文档问答、长篇摘要),Mamba 的优势更加明显:

  • 128K 上下文:Mamba 的推理内存是 Transformer 的约 1/100
  • 1M 上下文:Transformer 几乎不可能运行,Mamba 仍然可行

16.2 计算机视觉

16.2.1 Vision Mamba (Vim)

将 Mamba 应用于视觉的方案:

  1. 图像展平 :将 H×WH \times WH×W 的图像展平为长度为 HWHWHW 的序列
  2. 双向扫描:正向和反向各扫描一次,拼接结果
  3. 位置编码:添加 2D 位置信息

16.2.2 视觉任务性能

模型 ImageNet Top-1 吞吐量 (img/s)
ViT-B 81.8 950
DeiT-B 83.4 830
Vim-S 83.1 1050
Mamba-Vision-T 83.5 1100

Mamba 在保持竞争力的同时,吞吐量提高了约 15-30%。

16.3 音频与语音

SSM 天然适合音频处理:

  • 采样率适配 :Δ\DeltaΔ 可以根据音频的采样率自动调整
  • 长序列:16kHz 采样率下,1 分钟的音频有 960,000 个样本点,SSM 的线性复杂度是关键优势
  • 多尺度:不同频率的声音可以用不同衰减率的模态来捕捉

16.4 时间序列预测

SSM 在时间序列预测上的优势:

  1. 连续时间基础:自然处理不规则采样的时间序列
  2. 长程依赖:捕捉跨越数月甚至数年的周期性模式
  3. 高效推理:实时预测时只需要常数内存

16.5 混合架构的实际应用

16.5.1 Jamba (AI21, 2024)

Jamba 是第一个大规模的 SSM-Attention 混合架构:

  • 交替使用 Mamba 层和 Attention 层(比例 7:1)
  • 使用 MoE(Mixture of Experts)进一步扩展参数量
  • 在 256K 上下文长度下训练

16.5.2 Zamba (Zyphra, 2024)

Zamba 采用了更激进的混合策略:

  • 共享的 Mamba 骨干
  • 稀疏的 Attention 层插入
  • 实现了极高的参数效率

16.6 未来方向

16.6.1 理论方向

  1. 选择性 SSM 的泛化理论:选择性机制如何影响泛化能力?是否存在"最优"的选择策略?

  2. SSM 与 Transformer 的统一框架:是否存在一个更一般的框架,将两者作为特例?

  3. 信息论极限:SSM 的压缩效率是否已达到信息论的下界?

16.6.2 工程方向

  1. 硬件原生支持:为 SSM 设计专用的硬件指令(如 TPU/XLA 的 scan 原语)

  2. 量化与蒸馏:如何将大型 SSM 压缩到边缘设备?

  3. 分布式训练:超长序列的分布式训练策略

16.6.3 应用方向

  1. 多模态 SSM:统一处理文本、图像、音频的 SSM 架构

  2. 强化学习中的 SSM:作为世界模型的核心组件

  3. 科学计算中的 SSM:求解偏微分方程、气候模拟等


附录

A. 数学符号表

符号 含义 维度
hth_tht 隐状态 RN\mathbb{R}^NRN
xtx_txt 输入 RD\mathbb{R}^DRD
yty_tyt 输出 RD\mathbb{R}^DRD
AAA 连续状态矩阵 RN×N\mathbb{R}^{N \times N}RN×N
BBB 输入矩阵 RN×D\mathbb{R}^{N \times D}RN×D
CCC 输出矩阵 RD×N\mathbb{R}^{D \times N}RD×N
Aˉ\bar{A}Aˉ 离散状态矩阵 RN×N\mathbb{R}^{N \times N}RN×N
Bˉ\bar{B}Bˉ 离散输入矩阵 RN×D\mathbb{R}^{N \times D}RN×D
Δ\DeltaΔ 离散化步长 R>0\mathbb{R}_{>0}R>0
KnK_nKn 卷积核第 nnn 个元素 RD\mathbb{R}^DRD
LLL 序列长度 N\mathbb{N}N
NNN 状态维度 N\mathbb{N}N
DDD 输入/输出维度 N\mathbb{N}N

B. 关键公式速查

连续 SSM

h˙(t)=Ah(t)+Bx(t),y(t)=Ch(t)\dot{h}(t) = Ah(t) + Bx(t), \quad y(t) = Ch(t)h˙(t)=Ah(t)+Bx(t),y(t)=Ch(t)

ZOH 离散化

Aˉ=eAΔ,Bˉ=A−1(Aˉ−I)B\bar{A} = e^{A\Delta}, \quad \bar{B} = A^{-1}(\bar{A} - I)BAˉ=eAΔ,Bˉ=A−1(Aˉ−I)B

离散递推

ht=Aˉht−1+Bˉxt,yt=Chth_t = \bar{A}h_{t-1} + \bar{B}x_t, \quad y_t = Ch_tht=Aˉht−1+Bˉxt,yt=Cht

卷积核

Kn=CAˉnBˉK_n = C\bar{A}^n\bar{B}Kn=CAˉnBˉ

生成函数

K^(z)=C(I−zAˉ)−1Bˉ\hat{K}(z) = C(I - z\bar{A})^{-1}\bar{B}K^(z)=C(I−zAˉ)−1Bˉ

选择性离散化

Aˉt=eAΔt,Bˉt=eAΔt−IABt\bar{A}_t = e^{A\Delta_t}, \quad \bar{B}_t = \frac{e^{A\Delta_t} - I}{A} B_tAˉt=eAΔt,Bˉt=AeAΔt−IBt

HiPPO-LegS 矩阵

Ank={(2n+1)(2k+1)n>kn+1n=k0n<kA_{nk} = \begin{cases} \sqrt{(2n+1)(2k+1)} & n > k \\ n+1 & n = k \\ 0 & n < k \end{cases}Ank=⎩ ⎨ ⎧(2n+1)(2k+1) n+10n>kn=kn<k

C. 参考文献

  1. Gu, A., Dao, T., Ermon, S., Rudra, A., & Ré, C. (2020). HiPPO: Recurrent Memory with Optimal Polynomial Projections. NeurIPS 2020.

  2. Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022. (S4)

  3. Gu, A., Gupta, A., Goel, K., & Ré, C. (2022). On the Parameterization and Initialization of Diagonal State Space Models. NeurIPS 2022. (S4D)

  4. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv:2312.00752.

  5. Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. ICML 2024. (Mamba-2)

  6. Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017.

  7. Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation.

  8. Smith, J. T. H., Warrington, A., & Linderman, S. W. (2023). Simplified State Space Layers for Sequence Modeling. ICLR 2023. (S5)

  9. Tay, Y., et al. (2021). Long Range Arena: A Benchmark for Efficient Transformers. ICLR 2021.

  10. Poli, M., Massaroli, S., Nguyen, E., Fu, D. Y., Dao, T., Baccus, S., Bengio, Y., Ermon, S., & Ré, C. (2023). Hyena Hierarchy: Towards Larger Convolutional Language Models. ICML 2023.


涵盖了状态空间模型从经典控制论到现代深度学习的完整理论体系。代码均使用 NumPy 或 PyTorch 实现,可直接运行。

相关推荐
财经资讯数据_灵砚智能11 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月28日
大数据·人工智能·python·信息可视化·自然语言处理·ai编程·灵砚智能
weixin_4684668511 小时前
基于OpenCV的工业相机标定技术实战
图像处理·人工智能·opencv·计算机视觉·相机标定·机器视觉·工业相机
徐安安ye11 小时前
FlashAttention输出全是NaN?数值问题排查指南
人工智能·深度学习·机器学习
架构源启11 小时前
Spring AI 进阶篇(12)-边缘计算与离线部署:模型量化、本地推理与隐私保护实战
人工智能·spring·边缘计算
Ricky055311 小时前
YOLO-FCE:一种基于特征与聚类增强的物种分类目标检测模型(澳大利亚2026年研究)
图像处理·人工智能·yolo·目标检测·分类
学习中.........11 小时前
大语言模型的推理机制与工程应用
人工智能·语言模型·自然语言处理
一切皆是因缘际会11 小时前
从模型竞赛到全域智能的时代跃迁
人工智能·深度学习·ai·分布式系统
2601_9578885611 小时前
流量终局与信源争夺:GEO(生成式引擎优化)时代的爬虫分析与数据管道构建
人工智能·爬虫
名不经传的养虾人11 小时前
从0到1:企业级AI项目迭代日记 Vol.35|追问比演示重要——技术团队问出的五个工程缺口
人工智能·算法·机器学习·ai编程·ai工作流·企业ai