MoM (Mixture-of-Memories)新型线性序列建模架构

MoM (Mixture-of-Memories)新型线性序列建模架构

返回论文目录
返回资料目录

1. 论文讲了什么?

这篇论文旨在解决当前线性序列模型 (如 Linear Attention, Mamba/SSM, Linear RNNs)存在的一个核心缺陷:记忆干扰(Memory Interference)

  • 背景痛点

  • 传统的 Transformer 效果好,但推理成本高(KV Cache 显存占用大),复杂度是 N 2 N^2 N2

  • 现有的 线性模型 (如 Mamba, RWKV, RetNet)虽然训练快、推理省内存,但它们通常将整个输入序列压缩进 一个固定大小的记忆状态(Memory State) 中 。

  • 后果:当新信息进入时,旧信息容易被覆盖或干扰,导致模型在需要"回忆"大量细节的任务(Recall-Intensive Tasks)上表现不佳 。

  • MoM 的核心思想

  • 神经科学 (特别是海马体中 Theta-Gamma 振荡将不同记忆项在时间上分离)的启发 ,论文提出了"记忆混合"架构。

  • MoM 不再只用一个记忆状态,而是维护多个独立的记忆状态

  • 通过一个路由网络(Router),根据输入内容的重要性,将不同的 Token 导向不同的记忆单元进行存储 。

  • 这使得模型既保留了线性模型的高效率,又拥有了接近 Transformer 的记忆容量和召回能力 。

2. 核心创新点是什么?

MoM 的创新点可以总结为以下三个方面:

  1. 多记忆状态与路由机制 (Mixture-of-Memories Strategy)
  • 不同于传统线性模型只有一个 (记忆矩阵),MoM 设有 个记忆单元 。

  • 引入了 Top-k Router:对于每个输入 Token,路由网络会计算其权重,只激活最相关的 个记忆单元进行更新,其他单元保持不变 。

  • 这种机制有效地将不同类型的信息隔离开,极大减少了记忆干扰 。

  1. 共享记忆 (Shared Memory)
  • 除了动态路由的记忆单元外,MoM 还引入了一个始终激活的共享记忆(Shared Memory)

  • 作用:共享记忆负责捕捉全局上下文和长程依赖,确保即使某些 Token 被分流,模型依然能把握整体语义 。

  1. 受脑启发的抗干扰设计
  • 它模拟了大脑"E%-max"赢家通吃机制,只激活部分神经元来处理特定信息 。

  • 在推理时,通过加权求和(Weighted Sum)混合这些记忆产生的输出,使得推理复杂度保持在 (常数级内存),没有 Transformer 那样随着序列增长而无限膨胀的 KV Cache 。

3. Python Demo 代码辅助理解

这段代码模拟了 MoM 层在**推理(Recurrent Mode)**时的核心逻辑:路由 -> 选择记忆 -> 更新记忆 -> 混合输出

(注:为了易读性,这里使用了最基础的 Linear Attention 更新规则 ,论文中实际实验使用了 Gated DeltaNet 等更高级的更新规则,但核心逻辑一致。)

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureOfMemoriesLayer(nn.Module):
    def __init__(self, d_model, num_memories=4, top_k=2):
        super().__init__()
        self.d_model = d_model
        self.num_memories = num_memories
        self.top_k = top_k
        
        # 1. 投影层:生成 Query, Key, Value
        # 注意:MoM中每个记忆单元可能有独立的投影,这里简化为共享投影但独立更新
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        
        # 2. 路由网络 (Router)
        # 决定当前 token 应该存入哪几个记忆单元
        self.router = nn.Linear(d_model, num_memories)
        
        # 3. 共享记忆 (Shared Memory) 的投影
        self.w_k_shared = nn.Linear(d_model, d_model)
        self.w_v_shared = nn.Linear(d_model, d_model)

        # 输出层
        self.w_o = nn.Linear(d_model, d_model)

    def forward_inference_step(self, x_t, memories, shared_memory):
        """
        模拟单步推理 (Recurrent Step)
        x_t: 当前输入 token [batch_size, d_model]
        memories: 列表,包含 num_memories 个矩阵 [batch_size, d_model, d_model]
        shared_memory: 共享记忆矩阵 [batch_size, d_model, d_model]
        """
        batch_size = x_t.size(0)
        
        # --- A. 基础投影 ---
        q_t = self.w_q(x_t) # [B, D]
        k_t = self.w_k(x_t) # [B, D]
        v_t = self.w_v(x_t) # [B, D]
        
        # --- B. 路由 (Routing) ---
        # 计算路由分数
        router_logits = self.router(x_t) # [B, num_memories]
        router_probs = F.softmax(router_logits, dim=-1)
        
        # 选出 Top-K 的记忆单元索引和权重
        topk_weights, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # 归一化权重
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        
        # --- C. 更新记忆 (Update Memories) ---
        # 1. 更新共享记忆 (Shared Memory 总是更新)
        # Rule: M_shared = M_shared + K_shared^T * V_shared
        k_s = self.w_k_shared(x_t)
        v_s = self.w_v_shared(x_t)
        # 外积更新 (Outer Product)
        update_shared = torch.bmm(k_s.unsqueeze(2), v_s.unsqueeze(1)) 
        new_shared_memory = shared_memory + update_shared
        
        # 2. 更新被激活的局部记忆 (Routed Memories)
        new_memories = [m.clone() for m in memories] # 浅拷贝用于演示
        
        # 简化的循环处理(实际中会用 scatter/gather 并行化)
        kv_update = torch.bmm(k_t.unsqueeze(2), v_t.unsqueeze(1)) # [B, D, D]
        
        for i in range(self.top_k):
            idx = topk_indices[:, i] # 当前 batch 选中的第 i 个记忆单元索引
            # 注意:这里简化处理,假设 batch=1 或者所有样本选的一样
            # 真实实现需要处理 batch 中不同样本选不同记忆的情况
            chosen_mem_idx = idx[0].item() 
            
            # 更新该记忆单元:M_new = M_old + K^T * V
            new_memories[chosen_mem_idx] = new_memories[chosen_mem_idx] + kv_update

        # --- D. 混合输出 (Mixture of Memories) ---
        # 混合记忆状态 = 共享记忆 + 加权的局部记忆
        mixed_memory = new_shared_memory.clone()
        
        for i in range(self.top_k):
            idx = topk_indices[:, i]
            weight = topk_weights[:, i].unsqueeze(1).unsqueeze(2) # [B, 1, 1]
            chosen_mem_idx = idx[0].item()
            
            # Weighted Sum
            mixed_memory += weight * new_memories[chosen_mem_idx]
            
        # --- E. 计算输出 ---
        # Output = Query * Mixed_Memory
        # [B, 1, D] x [B, D, D] -> [B, 1, D]
        o_t = torch.bmm(q_t.unsqueeze(1), mixed_memory).squeeze(1)
        output = self.w_o(o_t)
        
        return output, new_memories, new_shared_memory

# === 运行 Demo ===
# 初始化参数
D_MODEL = 16
NUM_MEM = 4
TOP_K = 2
layer = MixtureOfMemoriesLayer(D_MODEL, NUM_MEM, TOP_K)

# 模拟输入
x_input = torch.randn(1, D_MODEL) # Batch=1

# 初始化记忆状态 (全0)
memories = [torch.zeros(1, D_MODEL, D_MODEL) for _ in range(NUM_MEM)]
shared_mem = torch.zeros(1, D_MODEL, D_MODEL)

print("Input shape:", x_input.shape)
print(f"MoM Config: {NUM_MEM} memories, activating Top-{TOP_K}")

# 运行一步推理
out, new_mems, new_shared = layer.forward_inference_step(x_input, memories, shared_mem)

print("Output shape:", out.shape)
print("Shared Memory Updated?", not torch.allclose(shared_mem, new_shared))
print("Specific Memory 0 Updated?", not torch.allclose(memories[0], new_mems[0]))

代码解析

  1. router : 决定了当前的信息 x_t 到底重要性如何,应该存到哪几个 memories 列表的槽位中。
  2. update : 这里的 + kv_update 就是线性 Attention 的核心,把 KV 对写入矩阵。在 MoM 中,只有被 Router 选中的矩阵会被写入,其他的保持静默(从而保护了之前存入的旧信息不被覆盖)。
  3. mixed_memory: 最后查询时,不是去查某个单一的记忆,而是把相关的记忆加权融合起来查。这就像大脑提取记忆时,会同时调动多个相关的区域。

这篇论文的实验设计非常扎实,主要通过**对比实验(Comparative Experiments)消融实验(Ablation Studies)**来验证 MoM 架构在解决"记忆干扰"和提升"长程召回能力"方面的有效性。

以下是具体的实验设置与验证方法:

1. 实验设置 (Experimental Setup)

为了公平比较,作者从零开始训练了所有模型,保持训练数据和计算资源的一致性。

  • 模型配置

  • 核心架构 :MoM 使用 Gated DeltaNet 作为基础的记忆更新机制 。

  • 具体参数 :设置了 4 个记忆状态(Memory States) ,每次路由激活其中的 2 个 ,并额外配备一个共享记忆(Shared Memory)

  • 模型规模 :1. 340M 参数版 :在 150B (15B tokens) 数据上训练 。2. 1.3B 参数版:在 100B tokens 数据上训练 。

  • 基线模型 (Baselines)

  • 对比了主流的线性序列模型:RetNet , GLA (Gated Linear Attention), Gated DeltaNet , HGRN2

  • 对比了标准 Transformer:Transformer++ (包含 Rotary Pos Emb 和 GLU) 。

  • 训练数据 :使用 SlimPajama 数据集,这是 RedPajama 的去重清洗版 。

  • 硬件环境:使用 32 张 NVIDIA A800 GPU 进行训练 。

2. 验证方法与任务 (Validation & Benchmarks)

论文通过三个维度的任务来验证 MoM 的性能,重点关注记忆召回能力

A. 核心验证:召回密集型任务 (Recall-Intensive Tasks)

这是验证 MoM 是否解决了"记忆干扰"问题的关键实验。

  • 任务内容:选取了 6 个需要长上下文检索的任务,包括 FDA, SWDE, SQUAD, NQ (Natural Questions), TriviaQA, Drop 。

  • 验证逻辑:这些任务强依赖于从长文中精准提取信息。如果 MoM 能显著优于其他线性模型,说明其"多记忆槽位"设计成功减少了信息覆盖。

  • 结果:MoM 在所有线性模型中表现最好,并且在 1.3B 参数下,性能已经非常接近 Transformer++,证明了其强大的记忆容量 。

B. 通用能力:常识推理与语言建模 (Commonsense Reasoning)

验证模型作为通用语言模型的基础能力。

  • 任务内容:测试了 WikiText (Perplexity), LAMBADA, HellaSwag, PIQA, ARC, WinoGrande 等标准数据集 。

  • 结果:MoM 的平均表现优于其他线性模型,甚至在部分指标上超过了 Transformer 。

C. 长上下文能力:LongBench

验证模型处理超长序列(Long Context)的能力。

  • 任务内容:使用 LongBench 评测集,涵盖摘要、少样本学习、合成任务和代码补全 。

  • 结果:MoM 在平均分上优于 RetNet 和 Gated DeltaNet,证明其能够有效处理长程依赖 。

3. 关键验证:消融实验 (Ablation Studies)

为了证明性能提升是因为"MoM 架构"而不是单纯增加了参数或状态大小,作者设计了非常巧妙的对比实验:

  • 验证 1:混合记忆 vs. 单一大记忆 (Mixed Memory vs. Single Memory)

  • 假设:有人可能会质疑,MoM 效果好只是因为你有多个记忆矩阵,存的东西变多了。

  • 实验 :作者构建了一个基线模型,将其单个记忆矩阵的尺寸扩大,使其总容量等于 MoM 所有激活记忆的总和 。

  • 结论 :实验数据显示,MoM (多个独立记忆) 的效果显著优于 Expanded Single Memory (单个大记忆)

  • 意义 :这直接证明了"将记忆隔离"是减少干扰的关键,单纯增加容量(而不隔离)效果不如 MoM 。

  • 验证 2:记忆数量与共享记忆的影响

  • 测试了不同的记忆数量(2, 4, 8 个)和是否使用共享记忆。

  • 结论 :增加记忆数量通常能提升性能;共享记忆(Shared Memory) 对于维持全局信息非常重要,移除它会导致性能下降 。

4. 效率验证 (Efficiency)

  • 训练/推理速度:绘制了不同序列长度下的推理时间和显存占用图。

  • 结果:随着序列长度增加,Transformer 的显存和时间呈二次方爆炸增长,而 MoM 保持线性(Linear)增长,推理显存保持恒定(Constant)。

  • 训练 Loss 曲线:展示了训练过程中的 Loss 下降情况,MoM 的收敛曲线始终位于其他线性模型下方,说明学习效率更高 。

总结

论文通过控制变量法 (保持训练数据和基础架构一致),在特定的痛点任务 (召回任务)上证明了 MoM 的优越性,并通过与"单一大记忆模型"的对比,有力地回击了"仅仅是增加了参数"的质疑,验证了 隔离记忆(Separating Memories) 是解决线性模型遗忘问题的有效途径。

相关推荐
视觉人机器视觉2 小时前
ROS2安装步骤总结
人工智能
非著名架构师2 小时前
空间计算的“环境校准器”:高精度AI气象如何为AR导航与自动驾驶提供厘米级实时大气修正?
人工智能·ar·空间计算
Java后端的Ai之路2 小时前
【机器学习】-超参数(模型“调音师”的魔法)
人工智能·机器学习
雨大王5122 小时前
汽车制造的智能化升级:工业AI平台如何重构生产线?
人工智能·汽车·制造
retrofit2 小时前
基于PyTorch的深度学习基础课程之十二:卷积神经网络
pytorch·深度学习·cnn·卷积神经网络
AKAMAI10 小时前
Akamai Cloud客户案例 | Avesha 在 Akamai 云上扩展 Kubernetes 解决方案
人工智能·云计算
小股虫10 小时前
数据一致性保障:从理论深度到架构实践的十年沉淀
架构·wpf
wasp52010 小时前
AgentScope Java 核心架构深度解析
java·开发语言·人工智能·架构·agentscope
智算菩萨10 小时前
高效多模态大语言模型:从统一框架到训练与推理效率的系统化理论梳理
大数据·人工智能·多模态