Mamba 状态空间模型:Transformer 的高效替代架构
1. 引言
Transformer 的自注意力机制计算复杂度为 O(n²),在处理长序列时成为瓶颈。Mamba(2023)基于选择性状态空间模型(Selective SSM),实现了 O(n) 的线性复杂度,同时在语言建模任务上匹配甚至超越 Transformer。
核心创新:
- 选择性机制:让模型根据输入动态调整状态转移参数
- 硬件感知算法:利用 GPU 的 SRAM 和并行扫描实现高效推理
- 线性复杂度:O(n) 时间和 O(1) 推理内存
2. 状态空间模型基础
2.1 SSM 数学形式
连续时间 SSM:
h'(t) = A·h(t) + B·x(t) # 状态方程
y(t) = C·h(t) # 输出方程
离散化后:
h_k = Ā·h_{k-1} + B̄·x_k # 递推
y_k = C·h_k # 输出
其中:
Ā = exp(Δ·A) # 状态转移矩阵
B̄ = (Δ·A)^{-1}(Ā - I)·Δ·B # 输入矩阵
2.2 SSM 的两种计算模式
python
# 递推模式(Recurrent)- 推理时使用,O(1) 内存
def ssm_recurrent(x, A_bar, B_bar, C):
"""逐 token 递推"""
h = torch.zeros(batch, d_state) # 初始状态
outputs = []
for x_t in x:
h = A_bar * h + B_bar * x_t # 状态更新
y_t = C @ h # 输出
outputs.append(y_t)
return torch.stack(outputs)
# 卷积模式(Convolutional)- 训练时使用,可并行
def ssm_convolutional(x, A_bar, B_bar, C):
"""全局卷积计算"""
kernel = []
for i in range(seq_len):
kernel.append(C @ matrix_power(A_bar, i) @ B_bar)
return conv1d(x, kernel) # 一次卷积完成
3. Mamba 的选择性机制
3.1 核心思想
传统 SSM 的参数 A、B、Δ 是固定的,无法根据输入内容调整。Mamba 让这些参数依赖于输入:
python
class SelectiveSSM(nn.Module):
"""Mamba 的选择性 SSM 层"""
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = d_model * expand
# 输入投影
self.in_proj = nn.Linear(d_model, self.d_inner * 2)
# 1D 卷积(局部特征提取)
self.conv1d = nn.Conv1d(
self.d_inner, self.d_inner, d_conv,
padding=d_conv-1, groups=self.d_inner
)
# SSM 参数投影(依赖输入!)
self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1) # B, C, Δ
self.dt_proj = nn.Linear(d_state, self.d_inner)
# A 参数(对数空间初始化)
A = torch.arange(1, d_state + 1).float().unsqueeze(0).expand(self.d_inner, -1)
self.A_log = nn.Parameter(torch.log(A))
# D 参数(跳跃连接)
self.D = nn.Parameter(torch.ones(self.d_inner))
# 输出投影
self.out_proj = nn.Linear(self.d_inner, d_model)
def forward(self, x):
"""x: (batch, seq_len, d_model)"""
B, L, _ = x.shape
# 双分支:x 分支和 z 分支(门控)
xz = self.in_proj(x) # (B, L, 2*d_inner)
x_branch, z = xz.chunk(2, dim=-1)
# 1D 卷积
x_branch = x_branch.transpose(1, 2) # (B, d_inner, L)
x_branch = self.conv1d(x_branch)[:, :, :L]
x_branch = x_branch.transpose(1, 2) # (B, L, d_inner)
x_branch = F.silu(x_branch)
# 选择性参数生成
x_proj = self.x_proj(x_branch) # (B, L, 2*d_state + 1)
B_param, C_param, delta = x_proj.split(
[self.d_state, self.d_state, 1], dim=-1
)
# Δ 通过 softplus 确保正值
delta = F.softplus(self.dt_proj(delta)) # (B, L, d_inner)
# A 参数(负值确保稳定性)
A = -torch.exp(self.A_log) # (d_inner, d_state)
# 选择性扫描
y = self.selective_scan(x_branch, delta, A, B_param, C_param)
# 门控 + 跳跃连接
y = y * F.silu(z) + self.D * x_branch
return self.out_proj(y)
def selective_scan(self, u, delta, A, B, C):
"""硬件感知的选择性扫描"""
batch, seq_len, d_inner = u.shape
d_state = A.shape[1]
# 离散化
delta_A = torch.exp(delta.unsqueeze(-1) * A) # (B, L, d_inner, d_state)
delta_B = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, d_inner, d_state)
# 递推扫描
h = torch.zeros(batch, d_inner, d_state, device=u.device)
ys = []
for i in range(seq_len):
h = delta_A[:, i] * h + delta_B[:, i] * u[:, i].unsqueeze(-1)
y = (h * C[:, i].unsqueeze(1)).sum(dim=-1) # (B, d_inner)
ys.append(y)
return torch.stack(ys, dim=1) # (B, L, d_inner)
4. Mamba Block 完整实现
python
class MambaBlock(nn.Module):
"""完整的 Mamba 块"""
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
self.norm = nn.LayerNorm(d_model)
self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
def forward(self, x):
return x + self.ssm(self.norm(x)) # 残差连接
class MambaModel(nn.Module):
"""Mamba 语言模型"""
def __init__(self, vocab_size, d_model=768, n_layers=24, d_state=16):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
MambaBlock(d_model, d_state) for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids):
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
return self.lm_head(x)
5. Mamba vs Transformer
| 特性 | Transformer | Mamba |
|---|---|---|
| 计算复杂度 | O(n²) | O(n) |
| 推理内存 | O(n) KV Cache | O(1) 固定状态 |
| 长序列 (64K+) | 显存爆炸 | 线性增长 |
| 并行训练 | 完全并行 | 选择性扫描并行 |
| 上下文学习 | 强 | 较弱 |
| 语言建模 | 强 | 匹配(1.4B以下) |
推理速度对比
| 序列长度 | Transformer (7B) | Mamba (7B) |
|---|---|---|
| 1K | 45 tokens/s | 52 tokens/s |
| 8K | 22 tokens/s | 48 tokens/s |
| 32K | 6 tokens/s | 42 tokens/s |
| 128K | OOM | 38 tokens/s |
6. 训练 Mamba 模型
python
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
model = MambaModel(vocab_size=32000, d_model=768, n_layers=24).cuda()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scaler = GradScaler()
for epoch in range(50):
for batch in dataloader:
input_ids = batch["input_ids"].cuda()
labels = batch["labels"].cuda()
with autocast():
logits = model(input_ids)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1)
)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
7. Mamba 2 与后续发展
7.1 Mamba 2 改进
Mamba 2 核心改进:
1. 结构化状态空间对偶性 (SSD):将 SSM 与注意力统一
2. 多头结构:类似多头注意力的多头 SSM
3. 更高效的硬件实现:利用矩阵乘法而非逐元素操作
4. 性能提升 2-8x
7.2 混合架构
python
class MambaTransformerBlock(nn.Module):
"""Mamba + Transformer 混合块"""
def __init__(self, d_model, use_mamba=True):
super().__init__()
if use_mamba:
self.block = MambaBlock(d_model)
else:
self.block = TransformerBlock(d_model)
def forward(self, x):
return self.block(x)
# 常见策略:底层用 Mamba(高效处理长序列),顶层用 Transformer(强上下文学习)
layers = [MambaTransformerBlock(d_model, use_mamba=(i < 16)) for i in range(24)]
8. 总结
Mamba 的核心贡献:
- 选择性机制:让 SSM 根据输入动态调整,解决了传统 SSM 的"内容无关"问题
- 线性复杂度:O(n) 计算和 O(1) 推理内存,突破 Transformer 的序列长度限制
- 硬件感知设计:充分利用 GPU 内存层级,实现高效训练和推理
- 混合架构:Mamba + Transformer 结合两者优势,是未来的重要方向
Mamba 特别适合超长序列场景(128K+ tokens),在这些场景下 Transformer 因显存限制无法工作。