Mamba 状态空间模型:Transformer 的高效替代架构

Mamba 状态空间模型:Transformer 的高效替代架构

1. 引言

Transformer 的自注意力机制计算复杂度为 O(n²),在处理长序列时成为瓶颈。Mamba(2023)基于选择性状态空间模型(Selective SSM),实现了 O(n) 的线性复杂度,同时在语言建模任务上匹配甚至超越 Transformer。

核心创新:

  • 选择性机制:让模型根据输入动态调整状态转移参数
  • 硬件感知算法:利用 GPU 的 SRAM 和并行扫描实现高效推理
  • 线性复杂度:O(n) 时间和 O(1) 推理内存

2. 状态空间模型基础

2.1 SSM 数学形式

复制代码
连续时间 SSM:
  h'(t) = A·h(t) + B·x(t)      # 状态方程
  y(t)  = C·h(t)                # 输出方程

离散化后:
  h_k = Ā·h_{k-1} + B̄·x_k     # 递推
  y_k = C·h_k                   # 输出

其中:
  Ā = exp(Δ·A)                  # 状态转移矩阵
  B̄ = (Δ·A)^{-1}(Ā - I)·Δ·B   # 输入矩阵

2.2 SSM 的两种计算模式

python 复制代码
# 递推模式(Recurrent)- 推理时使用,O(1) 内存
def ssm_recurrent(x, A_bar, B_bar, C):
    """逐 token 递推"""
    h = torch.zeros(batch, d_state)  # 初始状态
    outputs = []
    for x_t in x:
        h = A_bar * h + B_bar * x_t  # 状态更新
        y_t = C @ h                   # 输出
        outputs.append(y_t)
    return torch.stack(outputs)

# 卷积模式(Convolutional)- 训练时使用,可并行
def ssm_convolutional(x, A_bar, B_bar, C):
    """全局卷积计算"""
    kernel = []
    for i in range(seq_len):
        kernel.append(C @ matrix_power(A_bar, i) @ B_bar)
    return conv1d(x, kernel)  # 一次卷积完成

3. Mamba 的选择性机制

3.1 核心思想

传统 SSM 的参数 A、B、Δ 是固定的,无法根据输入内容调整。Mamba 让这些参数依赖于输入

python 复制代码
class SelectiveSSM(nn.Module):
    """Mamba 的选择性 SSM 层"""

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand

        # 输入投影
        self.in_proj = nn.Linear(d_model, self.d_inner * 2)

        # 1D 卷积(局部特征提取)
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner, d_conv,
            padding=d_conv-1, groups=self.d_inner
        )

        # SSM 参数投影(依赖输入!)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1)  # B, C, Δ
        self.dt_proj = nn.Linear(d_state, self.d_inner)

        # A 参数(对数空间初始化)
        A = torch.arange(1, d_state + 1).float().unsqueeze(0).expand(self.d_inner, -1)
        self.A_log = nn.Parameter(torch.log(A))

        # D 参数(跳跃连接)
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # 输出投影
        self.out_proj = nn.Linear(self.d_inner, d_model)

    def forward(self, x):
        """x: (batch, seq_len, d_model)"""
        B, L, _ = x.shape

        # 双分支:x 分支和 z 分支(门控)
        xz = self.in_proj(x)  # (B, L, 2*d_inner)
        x_branch, z = xz.chunk(2, dim=-1)

        # 1D 卷积
        x_branch = x_branch.transpose(1, 2)  # (B, d_inner, L)
        x_branch = self.conv1d(x_branch)[:, :, :L]
        x_branch = x_branch.transpose(1, 2)  # (B, L, d_inner)
        x_branch = F.silu(x_branch)

        # 选择性参数生成
        x_proj = self.x_proj(x_branch)  # (B, L, 2*d_state + 1)
        B_param, C_param, delta = x_proj.split(
            [self.d_state, self.d_state, 1], dim=-1
        )

        # Δ 通过 softplus 确保正值
        delta = F.softplus(self.dt_proj(delta))  # (B, L, d_inner)

        # A 参数(负值确保稳定性)
        A = -torch.exp(self.A_log)  # (d_inner, d_state)

        # 选择性扫描
        y = self.selective_scan(x_branch, delta, A, B_param, C_param)

        # 门控 + 跳跃连接
        y = y * F.silu(z) + self.D * x_branch

        return self.out_proj(y)

    def selective_scan(self, u, delta, A, B, C):
        """硬件感知的选择性扫描"""
        batch, seq_len, d_inner = u.shape
        d_state = A.shape[1]

        # 离散化
        delta_A = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, d_inner, d_state)
        delta_B = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, d_inner, d_state)

        # 递推扫描
        h = torch.zeros(batch, d_inner, d_state, device=u.device)
        ys = []

        for i in range(seq_len):
            h = delta_A[:, i] * h + delta_B[:, i] * u[:, i].unsqueeze(-1)
            y = (h * C[:, i].unsqueeze(1)).sum(dim=-1)  # (B, d_inner)
            ys.append(y)

        return torch.stack(ys, dim=1)  # (B, L, d_inner)

4. Mamba Block 完整实现

python 复制代码
class MambaBlock(nn.Module):
    """完整的 Mamba 块"""

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)

    def forward(self, x):
        return x + self.ssm(self.norm(x))  # 残差连接


class MambaModel(nn.Module):
    """Mamba 语言模型"""

    def __init__(self, vocab_size, d_model=768, n_layers=24, d_state=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.lm_head(x)

5. Mamba vs Transformer

特性 Transformer Mamba
计算复杂度 O(n²) O(n)
推理内存 O(n) KV Cache O(1) 固定状态
长序列 (64K+) 显存爆炸 线性增长
并行训练 完全并行 选择性扫描并行
上下文学习 较弱
语言建模 匹配(1.4B以下)

推理速度对比

序列长度 Transformer (7B) Mamba (7B)
1K 45 tokens/s 52 tokens/s
8K 22 tokens/s 48 tokens/s
32K 6 tokens/s 42 tokens/s
128K OOM 38 tokens/s

6. 训练 Mamba 模型

python 复制代码
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast

model = MambaModel(vocab_size=32000, d_model=768, n_layers=24).cuda()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scaler = GradScaler()

for epoch in range(50):
    for batch in dataloader:
        input_ids = batch["input_ids"].cuda()
        labels = batch["labels"].cuda()

        with autocast():
            logits = model(input_ids)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1)
            )

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

7. Mamba 2 与后续发展

7.1 Mamba 2 改进

复制代码
Mamba 2 核心改进:
1. 结构化状态空间对偶性 (SSD):将 SSM 与注意力统一
2. 多头结构:类似多头注意力的多头 SSM
3. 更高效的硬件实现:利用矩阵乘法而非逐元素操作
4. 性能提升 2-8x

7.2 混合架构

python 复制代码
class MambaTransformerBlock(nn.Module):
    """Mamba + Transformer 混合块"""

    def __init__(self, d_model, use_mamba=True):
        super().__init__()
        if use_mamba:
            self.block = MambaBlock(d_model)
        else:
            self.block = TransformerBlock(d_model)

    def forward(self, x):
        return self.block(x)

# 常见策略:底层用 Mamba(高效处理长序列),顶层用 Transformer(强上下文学习)
layers = [MambaTransformerBlock(d_model, use_mamba=(i < 16)) for i in range(24)]

8. 总结

Mamba 的核心贡献:

  1. 选择性机制:让 SSM 根据输入动态调整,解决了传统 SSM 的"内容无关"问题
  2. 线性复杂度:O(n) 计算和 O(1) 推理内存,突破 Transformer 的序列长度限制
  3. 硬件感知设计:充分利用 GPU 内存层级,实现高效训练和推理
  4. 混合架构:Mamba + Transformer 结合两者优势,是未来的重要方向

Mamba 特别适合超长序列场景(128K+ tokens),在这些场景下 Transformer 因显存限制无法工作。