MoM (Mixture-of-Memories)新型线性序列建模架构
1. 论文讲了什么?
这篇论文旨在解决当前线性序列模型 (如 Linear Attention, Mamba/SSM, Linear RNNs)存在的一个核心缺陷:记忆干扰(Memory Interference) 。
-
背景痛点:
-
传统的 Transformer 效果好,但推理成本高(KV Cache 显存占用大),复杂度是 N 2 N^2 N2
-
现有的 线性模型 (如 Mamba, RWKV, RetNet)虽然训练快、推理省内存,但它们通常将整个输入序列压缩进 一个固定大小的记忆状态(Memory State) 中 。
-
后果:当新信息进入时,旧信息容易被覆盖或干扰,导致模型在需要"回忆"大量细节的任务(Recall-Intensive Tasks)上表现不佳 。
-
MoM 的核心思想:
-
受神经科学 (特别是海马体中 Theta-Gamma 振荡将不同记忆项在时间上分离)的启发 ,论文提出了"记忆混合"架构。
-
MoM 不再只用一个记忆状态,而是维护多个独立的记忆状态 。
-
通过一个路由网络(Router),根据输入内容的重要性,将不同的 Token 导向不同的记忆单元进行存储 。
-
这使得模型既保留了线性模型的高效率,又拥有了接近 Transformer 的记忆容量和召回能力 。
2. 核心创新点是什么?
MoM 的创新点可以总结为以下三个方面:
- 多记忆状态与路由机制 (Mixture-of-Memories Strategy):
-
不同于传统线性模型只有一个 (记忆矩阵),MoM 设有 个记忆单元 。
-
引入了 Top-k Router:对于每个输入 Token,路由网络会计算其权重,只激活最相关的 个记忆单元进行更新,其他单元保持不变 。
-
这种机制有效地将不同类型的信息隔离开,极大减少了记忆干扰 。
- 共享记忆 (Shared Memory):
-
除了动态路由的记忆单元外,MoM 还引入了一个始终激活的共享记忆(Shared Memory) 。
-
作用:共享记忆负责捕捉全局上下文和长程依赖,确保即使某些 Token 被分流,模型依然能把握整体语义 。
- 受脑启发的抗干扰设计:
-
它模拟了大脑"E%-max"赢家通吃机制,只激活部分神经元来处理特定信息 。
-
在推理时,通过加权求和(Weighted Sum)混合这些记忆产生的输出,使得推理复杂度保持在 (常数级内存),没有 Transformer 那样随着序列增长而无限膨胀的 KV Cache 。
3. Python Demo 代码辅助理解
这段代码模拟了 MoM 层在**推理(Recurrent Mode)**时的核心逻辑:路由 -> 选择记忆 -> 更新记忆 -> 混合输出。
(注:为了易读性,这里使用了最基础的 Linear Attention 更新规则 ,论文中实际实验使用了 Gated DeltaNet 等更高级的更新规则,但核心逻辑一致。)
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MixtureOfMemoriesLayer(nn.Module):
def __init__(self, d_model, num_memories=4, top_k=2):
super().__init__()
self.d_model = d_model
self.num_memories = num_memories
self.top_k = top_k
# 1. 投影层:生成 Query, Key, Value
# 注意:MoM中每个记忆单元可能有独立的投影,这里简化为共享投影但独立更新
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
# 2. 路由网络 (Router)
# 决定当前 token 应该存入哪几个记忆单元
self.router = nn.Linear(d_model, num_memories)
# 3. 共享记忆 (Shared Memory) 的投影
self.w_k_shared = nn.Linear(d_model, d_model)
self.w_v_shared = nn.Linear(d_model, d_model)
# 输出层
self.w_o = nn.Linear(d_model, d_model)
def forward_inference_step(self, x_t, memories, shared_memory):
"""
模拟单步推理 (Recurrent Step)
x_t: 当前输入 token [batch_size, d_model]
memories: 列表,包含 num_memories 个矩阵 [batch_size, d_model, d_model]
shared_memory: 共享记忆矩阵 [batch_size, d_model, d_model]
"""
batch_size = x_t.size(0)
# --- A. 基础投影 ---
q_t = self.w_q(x_t) # [B, D]
k_t = self.w_k(x_t) # [B, D]
v_t = self.w_v(x_t) # [B, D]
# --- B. 路由 (Routing) ---
# 计算路由分数
router_logits = self.router(x_t) # [B, num_memories]
router_probs = F.softmax(router_logits, dim=-1)
# 选出 Top-K 的记忆单元索引和权重
topk_weights, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
# 归一化权重
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# --- C. 更新记忆 (Update Memories) ---
# 1. 更新共享记忆 (Shared Memory 总是更新)
# Rule: M_shared = M_shared + K_shared^T * V_shared
k_s = self.w_k_shared(x_t)
v_s = self.w_v_shared(x_t)
# 外积更新 (Outer Product)
update_shared = torch.bmm(k_s.unsqueeze(2), v_s.unsqueeze(1))
new_shared_memory = shared_memory + update_shared
# 2. 更新被激活的局部记忆 (Routed Memories)
new_memories = [m.clone() for m in memories] # 浅拷贝用于演示
# 简化的循环处理(实际中会用 scatter/gather 并行化)
kv_update = torch.bmm(k_t.unsqueeze(2), v_t.unsqueeze(1)) # [B, D, D]
for i in range(self.top_k):
idx = topk_indices[:, i] # 当前 batch 选中的第 i 个记忆单元索引
# 注意:这里简化处理,假设 batch=1 或者所有样本选的一样
# 真实实现需要处理 batch 中不同样本选不同记忆的情况
chosen_mem_idx = idx[0].item()
# 更新该记忆单元:M_new = M_old + K^T * V
new_memories[chosen_mem_idx] = new_memories[chosen_mem_idx] + kv_update
# --- D. 混合输出 (Mixture of Memories) ---
# 混合记忆状态 = 共享记忆 + 加权的局部记忆
mixed_memory = new_shared_memory.clone()
for i in range(self.top_k):
idx = topk_indices[:, i]
weight = topk_weights[:, i].unsqueeze(1).unsqueeze(2) # [B, 1, 1]
chosen_mem_idx = idx[0].item()
# Weighted Sum
mixed_memory += weight * new_memories[chosen_mem_idx]
# --- E. 计算输出 ---
# Output = Query * Mixed_Memory
# [B, 1, D] x [B, D, D] -> [B, 1, D]
o_t = torch.bmm(q_t.unsqueeze(1), mixed_memory).squeeze(1)
output = self.w_o(o_t)
return output, new_memories, new_shared_memory
# === 运行 Demo ===
# 初始化参数
D_MODEL = 16
NUM_MEM = 4
TOP_K = 2
layer = MixtureOfMemoriesLayer(D_MODEL, NUM_MEM, TOP_K)
# 模拟输入
x_input = torch.randn(1, D_MODEL) # Batch=1
# 初始化记忆状态 (全0)
memories = [torch.zeros(1, D_MODEL, D_MODEL) for _ in range(NUM_MEM)]
shared_mem = torch.zeros(1, D_MODEL, D_MODEL)
print("Input shape:", x_input.shape)
print(f"MoM Config: {NUM_MEM} memories, activating Top-{TOP_K}")
# 运行一步推理
out, new_mems, new_shared = layer.forward_inference_step(x_input, memories, shared_mem)
print("Output shape:", out.shape)
print("Shared Memory Updated?", not torch.allclose(shared_mem, new_shared))
print("Specific Memory 0 Updated?", not torch.allclose(memories[0], new_mems[0]))
代码解析
router: 决定了当前的信息x_t到底重要性如何,应该存到哪几个memories列表的槽位中。update: 这里的+ kv_update就是线性 Attention 的核心,把 KV 对写入矩阵。在 MoM 中,只有被 Router 选中的矩阵会被写入,其他的保持静默(从而保护了之前存入的旧信息不被覆盖)。mixed_memory: 最后查询时,不是去查某个单一的记忆,而是把相关的记忆加权融合起来查。这就像大脑提取记忆时,会同时调动多个相关的区域。
这篇论文的实验设计非常扎实,主要通过**对比实验(Comparative Experiments)和消融实验(Ablation Studies)**来验证 MoM 架构在解决"记忆干扰"和提升"长程召回能力"方面的有效性。
以下是具体的实验设置与验证方法:
1. 实验设置 (Experimental Setup)
为了公平比较,作者从零开始训练了所有模型,保持训练数据和计算资源的一致性。
-
模型配置:
-
核心架构 :MoM 使用 Gated DeltaNet 作为基础的记忆更新机制 。
-
具体参数 :设置了 4 个记忆状态(Memory States) ,每次路由激活其中的 2 个 ,并额外配备一个共享记忆(Shared Memory) 。
-
模型规模 :1. 340M 参数版 :在 150B (15B tokens) 数据上训练 。2. 1.3B 参数版:在 100B tokens 数据上训练 。
-
基线模型 (Baselines):
-
对比了主流的线性序列模型:RetNet , GLA (Gated Linear Attention), Gated DeltaNet , HGRN2 。
-
对比了标准 Transformer:Transformer++ (包含 Rotary Pos Emb 和 GLU) 。
-
训练数据 :使用 SlimPajama 数据集,这是 RedPajama 的去重清洗版 。
-
硬件环境:使用 32 张 NVIDIA A800 GPU 进行训练 。
2. 验证方法与任务 (Validation & Benchmarks)
论文通过三个维度的任务来验证 MoM 的性能,重点关注记忆召回能力。
A. 核心验证:召回密集型任务 (Recall-Intensive Tasks)
这是验证 MoM 是否解决了"记忆干扰"问题的关键实验。
-
任务内容:选取了 6 个需要长上下文检索的任务,包括 FDA, SWDE, SQUAD, NQ (Natural Questions), TriviaQA, Drop 。
-
验证逻辑:这些任务强依赖于从长文中精准提取信息。如果 MoM 能显著优于其他线性模型,说明其"多记忆槽位"设计成功减少了信息覆盖。
-
结果:MoM 在所有线性模型中表现最好,并且在 1.3B 参数下,性能已经非常接近 Transformer++,证明了其强大的记忆容量 。
B. 通用能力:常识推理与语言建模 (Commonsense Reasoning)
验证模型作为通用语言模型的基础能力。
-
任务内容:测试了 WikiText (Perplexity), LAMBADA, HellaSwag, PIQA, ARC, WinoGrande 等标准数据集 。
-
结果:MoM 的平均表现优于其他线性模型,甚至在部分指标上超过了 Transformer 。
C. 长上下文能力:LongBench
验证模型处理超长序列(Long Context)的能力。
-
任务内容:使用 LongBench 评测集,涵盖摘要、少样本学习、合成任务和代码补全 。
-
结果:MoM 在平均分上优于 RetNet 和 Gated DeltaNet,证明其能够有效处理长程依赖 。
3. 关键验证:消融实验 (Ablation Studies)
为了证明性能提升是因为"MoM 架构"而不是单纯增加了参数或状态大小,作者设计了非常巧妙的对比实验:
-
验证 1:混合记忆 vs. 单一大记忆 (Mixed Memory vs. Single Memory)
-
假设:有人可能会质疑,MoM 效果好只是因为你有多个记忆矩阵,存的东西变多了。
-
实验 :作者构建了一个基线模型,将其单个记忆矩阵的尺寸扩大,使其总容量等于 MoM 所有激活记忆的总和 。
-
结论 :实验数据显示,MoM (多个独立记忆) 的效果显著优于 Expanded Single Memory (单个大记忆) 。
-
意义 :这直接证明了"将记忆隔离"是减少干扰的关键,单纯增加容量(而不隔离)效果不如 MoM 。
-
验证 2:记忆数量与共享记忆的影响
-
测试了不同的记忆数量(2, 4, 8 个)和是否使用共享记忆。
-
结论 :增加记忆数量通常能提升性能;共享记忆(Shared Memory) 对于维持全局信息非常重要,移除它会导致性能下降 。
4. 效率验证 (Efficiency)
-
训练/推理速度:绘制了不同序列长度下的推理时间和显存占用图。
-
结果:随着序列长度增加,Transformer 的显存和时间呈二次方爆炸增长,而 MoM 保持线性(Linear)增长,推理显存保持恒定(Constant)。
-
训练 Loss 曲线:展示了训练过程中的 Loss 下降情况,MoM 的收敛曲线始终位于其他线性模型下方,说明学习效率更高 。
总结
论文通过控制变量法 (保持训练数据和基础架构一致),在特定的痛点任务 (召回任务)上证明了 MoM 的优越性,并通过与"单一大记忆模型"的对比,有力地回击了"仅仅是增加了参数"的质疑,验证了 隔离记忆(Separating Memories) 是解决线性模型遗忘问题的有效途径。